trace.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from sympy.core.basic import Basic
  2. from sympy.core.expr import Expr, ExprBuilder
  3. from sympy.core.singleton import S
  4. from sympy.core.sorting import default_sort_key
  5. from sympy.core.symbol import Dummy
  6. from sympy.core.sympify import sympify
  7. from sympy.matrices.matrices import MatrixBase
  8. from sympy.matrices.common import NonSquareMatrixError
  9. class Trace(Expr):
  10. """Matrix Trace
  11. Represents the trace of a matrix expression.
  12. Examples
  13. ========
  14. >>> from sympy import MatrixSymbol, Trace, eye
  15. >>> A = MatrixSymbol('A', 3, 3)
  16. >>> Trace(A)
  17. Trace(A)
  18. >>> Trace(eye(3))
  19. Trace(Matrix([
  20. [1, 0, 0],
  21. [0, 1, 0],
  22. [0, 0, 1]]))
  23. >>> Trace(eye(3)).simplify()
  24. 3
  25. """
  26. is_Trace = True
  27. is_commutative = True
  28. def __new__(cls, mat):
  29. mat = sympify(mat)
  30. if not mat.is_Matrix:
  31. raise TypeError("input to Trace, %s, is not a matrix" % str(mat))
  32. if not mat.is_square:
  33. raise NonSquareMatrixError("Trace of a non-square matrix")
  34. return Basic.__new__(cls, mat)
  35. def _eval_transpose(self):
  36. return self
  37. def _eval_derivative(self, v):
  38. from sympy.concrete.summations import Sum
  39. from .matexpr import MatrixElement
  40. if isinstance(v, MatrixElement):
  41. return self.rewrite(Sum).diff(v)
  42. expr = self.doit()
  43. if isinstance(expr, Trace):
  44. # Avoid looping infinitely:
  45. raise NotImplementedError
  46. return expr._eval_derivative(v)
  47. def _eval_derivative_matrix_lines(self, x):
  48. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction
  49. r = self.args[0]._eval_derivative_matrix_lines(x)
  50. for lr in r:
  51. if lr.higher == 1:
  52. lr.higher = ExprBuilder(
  53. ArrayContraction,
  54. [
  55. ExprBuilder(
  56. ArrayTensorProduct,
  57. [
  58. lr._lines[0],
  59. lr._lines[1],
  60. ]
  61. ),
  62. (1, 3),
  63. ],
  64. validator=ArrayContraction._validate
  65. )
  66. else:
  67. # This is not a matrix line:
  68. lr.higher = ExprBuilder(
  69. ArrayContraction,
  70. [
  71. ExprBuilder(
  72. ArrayTensorProduct,
  73. [
  74. lr._lines[0],
  75. lr._lines[1],
  76. lr.higher,
  77. ]
  78. ),
  79. (1, 3), (0, 2)
  80. ]
  81. )
  82. lr._lines = [S.One, S.One]
  83. lr._first_pointer_parent = lr._lines
  84. lr._second_pointer_parent = lr._lines
  85. lr._first_pointer_index = 0
  86. lr._second_pointer_index = 1
  87. return r
  88. @property
  89. def arg(self):
  90. return self.args[0]
  91. def doit(self, **kwargs):
  92. if kwargs.get('deep', True):
  93. arg = self.arg.doit(**kwargs)
  94. try:
  95. return arg._eval_trace()
  96. except (AttributeError, NotImplementedError):
  97. return Trace(arg)
  98. else:
  99. # _eval_trace would go too deep here
  100. if isinstance(self.arg, MatrixBase):
  101. return trace(self.arg)
  102. else:
  103. return Trace(self.arg)
  104. def as_explicit(self):
  105. return Trace(self.arg.as_explicit()).doit()
  106. def _normalize(self):
  107. # Normalization of trace of matrix products. Use transposition and
  108. # cyclic properties of traces to make sure the arguments of the matrix
  109. # product are sorted and the first argument is not a trasposition.
  110. from sympy.matrices.expressions.matmul import MatMul
  111. from sympy.matrices.expressions.transpose import Transpose
  112. trace_arg = self.arg
  113. if isinstance(trace_arg, MatMul):
  114. def get_arg_key(x):
  115. a = trace_arg.args[x]
  116. if isinstance(a, Transpose):
  117. a = a.arg
  118. return default_sort_key(a)
  119. indmin = min(range(len(trace_arg.args)), key=get_arg_key)
  120. if isinstance(trace_arg.args[indmin], Transpose):
  121. trace_arg = Transpose(trace_arg).doit()
  122. indmin = min(range(len(trace_arg.args)), key=lambda x: default_sort_key(trace_arg.args[x]))
  123. trace_arg = MatMul.fromiter(trace_arg.args[indmin:] + trace_arg.args[:indmin])
  124. return Trace(trace_arg)
  125. return self
  126. def _eval_rewrite_as_Sum(self, expr, **kwargs):
  127. from sympy.concrete.summations import Sum
  128. i = Dummy('i')
  129. return Sum(self.arg[i, i], (i, 0, self.arg.rows-1)).doit()
  130. def trace(expr):
  131. """Trace of a Matrix. Sum of the diagonal elements.
  132. Examples
  133. ========
  134. >>> from sympy import trace, Symbol, MatrixSymbol, eye
  135. >>> n = Symbol('n')
  136. >>> X = MatrixSymbol('X', n, n) # A square matrix
  137. >>> trace(2*X)
  138. 2*Trace(X)
  139. >>> trace(eye(3))
  140. 3
  141. """
  142. return Trace(expr).doit()