trace.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from sympy.core.add import Add
  2. from sympy.core.containers import Tuple
  3. from sympy.core.expr import Expr
  4. from sympy.core.mul import Mul
  5. from sympy.core.power import Pow
  6. from sympy.core.sorting import default_sort_key
  7. from sympy.core.sympify import sympify
  8. from sympy.matrices import Matrix
  9. def _is_scalar(e):
  10. """ Helper method used in Tr"""
  11. # sympify to set proper attributes
  12. e = sympify(e)
  13. if isinstance(e, Expr):
  14. if (e.is_Integer or e.is_Float or
  15. e.is_Rational or e.is_Number or
  16. (e.is_Symbol and e.is_commutative)
  17. ):
  18. return True
  19. return False
  20. def _cycle_permute(l):
  21. """ Cyclic permutations based on canonical ordering
  22. Explanation
  23. ===========
  24. This method does the sort based ascii values while
  25. a better approach would be to used lexicographic sort.
  26. TODO: Handle condition such as symbols have subscripts/superscripts
  27. in case of lexicographic sort
  28. """
  29. if len(l) == 1:
  30. return l
  31. min_item = min(l, key=default_sort_key)
  32. indices = [i for i, x in enumerate(l) if x == min_item]
  33. le = list(l)
  34. le.extend(l) # duplicate and extend string for easy processing
  35. # adding the first min_item index back for easier looping
  36. indices.append(len(l) + indices[0])
  37. # create sublist of items with first item as min_item and last_item
  38. # in each of the sublist is item just before the next occurrence of
  39. # minitem in the cycle formed.
  40. sublist = [[le[indices[i]:indices[i + 1]]] for i in
  41. range(len(indices) - 1)]
  42. # we do comparison of strings by comparing elements
  43. # in each sublist
  44. idx = sublist.index(min(sublist))
  45. ordered_l = le[indices[idx]:indices[idx] + len(l)]
  46. return ordered_l
  47. def _rearrange_args(l):
  48. """ this just moves the last arg to first position
  49. to enable expansion of args
  50. A,B,A ==> A**2,B
  51. """
  52. if len(l) == 1:
  53. return l
  54. x = list(l[-1:])
  55. x.extend(l[0:-1])
  56. return Mul(*x).args
  57. class Tr(Expr):
  58. """ Generic Trace operation than can trace over:
  59. a) SymPy matrix
  60. b) operators
  61. c) outer products
  62. Parameters
  63. ==========
  64. o : operator, matrix, expr
  65. i : tuple/list indices (optional)
  66. Examples
  67. ========
  68. # TODO: Need to handle printing
  69. a) Trace(A+B) = Tr(A) + Tr(B)
  70. b) Trace(scalar*Operator) = scalar*Trace(Operator)
  71. >>> from sympy.physics.quantum.trace import Tr
  72. >>> from sympy import symbols, Matrix
  73. >>> a, b = symbols('a b', commutative=True)
  74. >>> A, B = symbols('A B', commutative=False)
  75. >>> Tr(a*A,[2])
  76. a*Tr(A)
  77. >>> m = Matrix([[1,2],[1,1]])
  78. >>> Tr(m)
  79. 2
  80. """
  81. def __new__(cls, *args):
  82. """ Construct a Trace object.
  83. Parameters
  84. ==========
  85. args = SymPy expression
  86. indices = tuple/list if indices, optional
  87. """
  88. # expect no indices,int or a tuple/list/Tuple
  89. if (len(args) == 2):
  90. if not isinstance(args[1], (list, Tuple, tuple)):
  91. indices = Tuple(args[1])
  92. else:
  93. indices = Tuple(*args[1])
  94. expr = args[0]
  95. elif (len(args) == 1):
  96. indices = Tuple()
  97. expr = args[0]
  98. else:
  99. raise ValueError("Arguments to Tr should be of form "
  100. "(expr[, [indices]])")
  101. if isinstance(expr, Matrix):
  102. return expr.trace()
  103. elif hasattr(expr, 'trace') and callable(expr.trace):
  104. #for any objects that have trace() defined e.g numpy
  105. return expr.trace()
  106. elif isinstance(expr, Add):
  107. return Add(*[Tr(arg, indices) for arg in expr.args])
  108. elif isinstance(expr, Mul):
  109. c_part, nc_part = expr.args_cnc()
  110. if len(nc_part) == 0:
  111. return Mul(*c_part)
  112. else:
  113. obj = Expr.__new__(cls, Mul(*nc_part), indices )
  114. #this check is needed to prevent cached instances
  115. #being returned even if len(c_part)==0
  116. return Mul(*c_part)*obj if len(c_part) > 0 else obj
  117. elif isinstance(expr, Pow):
  118. if (_is_scalar(expr.args[0]) and
  119. _is_scalar(expr.args[1])):
  120. return expr
  121. else:
  122. return Expr.__new__(cls, expr, indices)
  123. else:
  124. if (_is_scalar(expr)):
  125. return expr
  126. return Expr.__new__(cls, expr, indices)
  127. @property
  128. def kind(self):
  129. expr = self.args[0]
  130. expr_kind = expr.kind
  131. return expr_kind.element_kind
  132. def doit(self, **kwargs):
  133. """ Perform the trace operation.
  134. #TODO: Current version ignores the indices set for partial trace.
  135. >>> from sympy.physics.quantum.trace import Tr
  136. >>> from sympy.physics.quantum.operator import OuterProduct
  137. >>> from sympy.physics.quantum.spin import JzKet, JzBra
  138. >>> t = Tr(OuterProduct(JzKet(1,1), JzBra(1,1)))
  139. >>> t.doit()
  140. 1
  141. """
  142. if hasattr(self.args[0], '_eval_trace'):
  143. return self.args[0]._eval_trace(indices=self.args[1])
  144. return self
  145. @property
  146. def is_number(self):
  147. # TODO : improve this implementation
  148. return True
  149. #TODO: Review if the permute method is needed
  150. # and if it needs to return a new instance
  151. def permute(self, pos):
  152. """ Permute the arguments cyclically.
  153. Parameters
  154. ==========
  155. pos : integer, if positive, shift-right, else shift-left
  156. Examples
  157. ========
  158. >>> from sympy.physics.quantum.trace import Tr
  159. >>> from sympy import symbols
  160. >>> A, B, C, D = symbols('A B C D', commutative=False)
  161. >>> t = Tr(A*B*C*D)
  162. >>> t.permute(2)
  163. Tr(C*D*A*B)
  164. >>> t.permute(-2)
  165. Tr(C*D*A*B)
  166. """
  167. if pos > 0:
  168. pos = pos % len(self.args[0].args)
  169. else:
  170. pos = -(abs(pos) % len(self.args[0].args))
  171. args = list(self.args[0].args[-pos:] + self.args[0].args[0:-pos])
  172. return Tr(Mul(*(args)))
  173. def _hashable_content(self):
  174. if isinstance(self.args[0], Mul):
  175. args = _cycle_permute(_rearrange_args(self.args[0].args))
  176. else:
  177. args = [self.args[0]]
  178. return tuple(args) + (self.args[1], )