qexpr.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. from sympy.core.expr import Expr
  2. from sympy.core.symbol import Symbol
  3. from sympy.core.sympify import sympify
  4. from sympy.matrices.dense import Matrix
  5. from sympy.printing.pretty.stringpict import prettyForm
  6. from sympy.core.containers import Tuple
  7. from sympy.utilities.iterables import is_sequence
  8. from sympy.physics.quantum.dagger import Dagger
  9. from sympy.physics.quantum.matrixutils import (
  10. numpy_ndarray, scipy_sparse_matrix,
  11. to_sympy, to_numpy, to_scipy_sparse
  12. )
  13. __all__ = [
  14. 'QuantumError',
  15. 'QExpr'
  16. ]
  17. #-----------------------------------------------------------------------------
  18. # Error handling
  19. #-----------------------------------------------------------------------------
  20. class QuantumError(Exception):
  21. pass
  22. def _qsympify_sequence(seq):
  23. """Convert elements of a sequence to standard form.
  24. This is like sympify, but it performs special logic for arguments passed
  25. to QExpr. The following conversions are done:
  26. * (list, tuple, Tuple) => _qsympify_sequence each element and convert
  27. sequence to a Tuple.
  28. * basestring => Symbol
  29. * Matrix => Matrix
  30. * other => sympify
  31. Strings are passed to Symbol, not sympify to make sure that variables like
  32. 'pi' are kept as Symbols, not the SymPy built-in number subclasses.
  33. Examples
  34. ========
  35. >>> from sympy.physics.quantum.qexpr import _qsympify_sequence
  36. >>> _qsympify_sequence((1,2,[3,4,[1,]]))
  37. (1, 2, (3, 4, (1,)))
  38. """
  39. return tuple(__qsympify_sequence_helper(seq))
  40. def __qsympify_sequence_helper(seq):
  41. """
  42. Helper function for _qsympify_sequence
  43. This function does the actual work.
  44. """
  45. #base case. If not a list, do Sympification
  46. if not is_sequence(seq):
  47. if isinstance(seq, Matrix):
  48. return seq
  49. elif isinstance(seq, str):
  50. return Symbol(seq)
  51. else:
  52. return sympify(seq)
  53. # base condition, when seq is QExpr and also
  54. # is iterable.
  55. if isinstance(seq, QExpr):
  56. return seq
  57. #if list, recurse on each item in the list
  58. result = [__qsympify_sequence_helper(item) for item in seq]
  59. return Tuple(*result)
  60. #-----------------------------------------------------------------------------
  61. # Basic Quantum Expression from which all objects descend
  62. #-----------------------------------------------------------------------------
  63. class QExpr(Expr):
  64. """A base class for all quantum object like operators and states."""
  65. # In sympy, slots are for instance attributes that are computed
  66. # dynamically by the __new__ method. They are not part of args, but they
  67. # derive from args.
  68. # The Hilbert space a quantum Object belongs to.
  69. __slots__ = ('hilbert_space')
  70. is_commutative = False
  71. # The separator used in printing the label.
  72. _label_separator = ''
  73. @property
  74. def free_symbols(self):
  75. return {self}
  76. def __new__(cls, *args, **kwargs):
  77. """Construct a new quantum object.
  78. Parameters
  79. ==========
  80. args : tuple
  81. The list of numbers or parameters that uniquely specify the
  82. quantum object. For a state, this will be its symbol or its
  83. set of quantum numbers.
  84. Examples
  85. ========
  86. >>> from sympy.physics.quantum.qexpr import QExpr
  87. >>> q = QExpr(0)
  88. >>> q
  89. 0
  90. >>> q.label
  91. (0,)
  92. >>> q.hilbert_space
  93. H
  94. >>> q.args
  95. (0,)
  96. >>> q.is_commutative
  97. False
  98. """
  99. # First compute args and call Expr.__new__ to create the instance
  100. args = cls._eval_args(args, **kwargs)
  101. if len(args) == 0:
  102. args = cls._eval_args(tuple(cls.default_args()), **kwargs)
  103. inst = Expr.__new__(cls, *args)
  104. # Now set the slots on the instance
  105. inst.hilbert_space = cls._eval_hilbert_space(args)
  106. return inst
  107. @classmethod
  108. def _new_rawargs(cls, hilbert_space, *args, **old_assumptions):
  109. """Create new instance of this class with hilbert_space and args.
  110. This is used to bypass the more complex logic in the ``__new__``
  111. method in cases where you already have the exact ``hilbert_space``
  112. and ``args``. This should be used when you are positive these
  113. arguments are valid, in their final, proper form and want to optimize
  114. the creation of the object.
  115. """
  116. obj = Expr.__new__(cls, *args, **old_assumptions)
  117. obj.hilbert_space = hilbert_space
  118. return obj
  119. #-------------------------------------------------------------------------
  120. # Properties
  121. #-------------------------------------------------------------------------
  122. @property
  123. def label(self):
  124. """The label is the unique set of identifiers for the object.
  125. Usually, this will include all of the information about the state
  126. *except* the time (in the case of time-dependent objects).
  127. This must be a tuple, rather than a Tuple.
  128. """
  129. if len(self.args) == 0: # If there is no label specified, return the default
  130. return self._eval_args(list(self.default_args()))
  131. else:
  132. return self.args
  133. @property
  134. def is_symbolic(self):
  135. return True
  136. @classmethod
  137. def default_args(self):
  138. """If no arguments are specified, then this will return a default set
  139. of arguments to be run through the constructor.
  140. NOTE: Any classes that override this MUST return a tuple of arguments.
  141. Should be overridden by subclasses to specify the default arguments for kets and operators
  142. """
  143. raise NotImplementedError("No default arguments for this class!")
  144. #-------------------------------------------------------------------------
  145. # _eval_* methods
  146. #-------------------------------------------------------------------------
  147. def _eval_adjoint(self):
  148. obj = Expr._eval_adjoint(self)
  149. if obj is None:
  150. obj = Expr.__new__(Dagger, self)
  151. if isinstance(obj, QExpr):
  152. obj.hilbert_space = self.hilbert_space
  153. return obj
  154. @classmethod
  155. def _eval_args(cls, args):
  156. """Process the args passed to the __new__ method.
  157. This simply runs args through _qsympify_sequence.
  158. """
  159. return _qsympify_sequence(args)
  160. @classmethod
  161. def _eval_hilbert_space(cls, args):
  162. """Compute the Hilbert space instance from the args.
  163. """
  164. from sympy.physics.quantum.hilbert import HilbertSpace
  165. return HilbertSpace()
  166. #-------------------------------------------------------------------------
  167. # Printing
  168. #-------------------------------------------------------------------------
  169. # Utilities for printing: these operate on raw SymPy objects
  170. def _print_sequence(self, seq, sep, printer, *args):
  171. result = []
  172. for item in seq:
  173. result.append(printer._print(item, *args))
  174. return sep.join(result)
  175. def _print_sequence_pretty(self, seq, sep, printer, *args):
  176. pform = printer._print(seq[0], *args)
  177. for item in seq[1:]:
  178. pform = prettyForm(*pform.right(sep))
  179. pform = prettyForm(*pform.right(printer._print(item, *args)))
  180. return pform
  181. # Utilities for printing: these operate prettyForm objects
  182. def _print_subscript_pretty(self, a, b):
  183. top = prettyForm(*b.left(' '*a.width()))
  184. bot = prettyForm(*a.right(' '*b.width()))
  185. return prettyForm(binding=prettyForm.POW, *bot.below(top))
  186. def _print_superscript_pretty(self, a, b):
  187. return a**b
  188. def _print_parens_pretty(self, pform, left='(', right=')'):
  189. return prettyForm(*pform.parens(left=left, right=right))
  190. # Printing of labels (i.e. args)
  191. def _print_label(self, printer, *args):
  192. """Prints the label of the QExpr
  193. This method prints self.label, using self._label_separator to separate
  194. the elements. This method should not be overridden, instead, override
  195. _print_contents to change printing behavior.
  196. """
  197. return self._print_sequence(
  198. self.label, self._label_separator, printer, *args
  199. )
  200. def _print_label_repr(self, printer, *args):
  201. return self._print_sequence(
  202. self.label, ',', printer, *args
  203. )
  204. def _print_label_pretty(self, printer, *args):
  205. return self._print_sequence_pretty(
  206. self.label, self._label_separator, printer, *args
  207. )
  208. def _print_label_latex(self, printer, *args):
  209. return self._print_sequence(
  210. self.label, self._label_separator, printer, *args
  211. )
  212. # Printing of contents (default to label)
  213. def _print_contents(self, printer, *args):
  214. """Printer for contents of QExpr
  215. Handles the printing of any unique identifying contents of a QExpr to
  216. print as its contents, such as any variables or quantum numbers. The
  217. default is to print the label, which is almost always the args. This
  218. should not include printing of any brackets or parenteses.
  219. """
  220. return self._print_label(printer, *args)
  221. def _print_contents_pretty(self, printer, *args):
  222. return self._print_label_pretty(printer, *args)
  223. def _print_contents_latex(self, printer, *args):
  224. return self._print_label_latex(printer, *args)
  225. # Main printing methods
  226. def _sympystr(self, printer, *args):
  227. """Default printing behavior of QExpr objects
  228. Handles the default printing of a QExpr. To add other things to the
  229. printing of the object, such as an operator name to operators or
  230. brackets to states, the class should override the _print/_pretty/_latex
  231. functions directly and make calls to _print_contents where appropriate.
  232. This allows things like InnerProduct to easily control its printing the
  233. printing of contents.
  234. """
  235. return self._print_contents(printer, *args)
  236. def _sympyrepr(self, printer, *args):
  237. classname = self.__class__.__name__
  238. label = self._print_label_repr(printer, *args)
  239. return '%s(%s)' % (classname, label)
  240. def _pretty(self, printer, *args):
  241. pform = self._print_contents_pretty(printer, *args)
  242. return pform
  243. def _latex(self, printer, *args):
  244. return self._print_contents_latex(printer, *args)
  245. #-------------------------------------------------------------------------
  246. # Methods from Basic and Expr
  247. #-------------------------------------------------------------------------
  248. def doit(self, **kw_args):
  249. return self
  250. #-------------------------------------------------------------------------
  251. # Represent
  252. #-------------------------------------------------------------------------
  253. def _represent_default_basis(self, **options):
  254. raise NotImplementedError('This object does not have a default basis')
  255. def _represent(self, *, basis=None, **options):
  256. """Represent this object in a given basis.
  257. This method dispatches to the actual methods that perform the
  258. representation. Subclases of QExpr should define various methods to
  259. determine how the object will be represented in various bases. The
  260. format of these methods is::
  261. def _represent_BasisName(self, basis, **options):
  262. Thus to define how a quantum object is represented in the basis of
  263. the operator Position, you would define::
  264. def _represent_Position(self, basis, **options):
  265. Usually, basis object will be instances of Operator subclasses, but
  266. there is a chance we will relax this in the future to accommodate other
  267. types of basis sets that are not associated with an operator.
  268. If the ``format`` option is given it can be ("sympy", "numpy",
  269. "scipy.sparse"). This will ensure that any matrices that result from
  270. representing the object are returned in the appropriate matrix format.
  271. Parameters
  272. ==========
  273. basis : Operator
  274. The Operator whose basis functions will be used as the basis for
  275. representation.
  276. options : dict
  277. A dictionary of key/value pairs that give options and hints for
  278. the representation, such as the number of basis functions to
  279. be used.
  280. """
  281. if basis is None:
  282. result = self._represent_default_basis(**options)
  283. else:
  284. result = dispatch_method(self, '_represent', basis, **options)
  285. # If we get a matrix representation, convert it to the right format.
  286. format = options.get('format', 'sympy')
  287. result = self._format_represent(result, format)
  288. return result
  289. def _format_represent(self, result, format):
  290. if format == 'sympy' and not isinstance(result, Matrix):
  291. return to_sympy(result)
  292. elif format == 'numpy' and not isinstance(result, numpy_ndarray):
  293. return to_numpy(result)
  294. elif format == 'scipy.sparse' and \
  295. not isinstance(result, scipy_sparse_matrix):
  296. return to_scipy_sparse(result)
  297. return result
  298. def split_commutative_parts(e):
  299. """Split into commutative and non-commutative parts."""
  300. c_part, nc_part = e.args_cnc()
  301. c_part = list(c_part)
  302. return c_part, nc_part
  303. def split_qexpr_parts(e):
  304. """Split an expression into Expr and noncommutative QExpr parts."""
  305. expr_part = []
  306. qexpr_part = []
  307. for arg in e.args:
  308. if not isinstance(arg, QExpr):
  309. expr_part.append(arg)
  310. else:
  311. qexpr_part.append(arg)
  312. return expr_part, qexpr_part
  313. def dispatch_method(self, basename, arg, **options):
  314. """Dispatch a method to the proper handlers."""
  315. method_name = '%s_%s' % (basename, arg.__class__.__name__)
  316. if hasattr(self, method_name):
  317. f = getattr(self, method_name)
  318. # This can raise and we will allow it to propagate.
  319. result = f(arg, **options)
  320. if result is not None:
  321. return result
  322. raise NotImplementedError(
  323. "%s.%s cannot handle: %r" %
  324. (self.__class__.__name__, basename, arg)
  325. )