123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- from sympy.core.basic import Basic
- from sympy.core.expr import Expr, ExprBuilder
- from sympy.core.singleton import S
- from sympy.core.sorting import default_sort_key
- from sympy.core.symbol import Dummy
- from sympy.core.sympify import sympify
- from sympy.matrices.matrices import MatrixBase
- from sympy.matrices.common import NonSquareMatrixError
class Trace(Expr):
    """Matrix Trace

    Represents the trace of a matrix expression.

    The argument must be a (possibly symbolic) square matrix.  The trace is
    kept unevaluated until :meth:`doit` is called.

    Examples
    ========

    >>> from sympy import MatrixSymbol, Trace, eye
    >>> A = MatrixSymbol('A', 3, 3)
    >>> Trace(A)
    Trace(A)
    >>> Trace(eye(3))
    Trace(Matrix([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1]]))
    >>> Trace(eye(3)).simplify()
    3
    """
    is_Trace = True
    is_commutative = True

    def __new__(cls, mat):
        """Build an unevaluated trace of ``mat``.

        Raises
        ======

        TypeError
            If ``mat`` (after sympification) is not a matrix.
        NonSquareMatrixError
            If ``mat`` is not square.
        """
        mat = sympify(mat)

        if not mat.is_Matrix:
            raise TypeError("input to Trace, %s, is not a matrix" % str(mat))

        if not mat.is_square:
            raise NonSquareMatrixError("Trace of a non-square matrix")

        return Basic.__new__(cls, mat)

    def _eval_transpose(self):
        # A trace is a scalar, so it is its own transpose.
        return self

    def _eval_derivative(self, v):
        """Differentiate the trace with respect to ``v``.

        If ``v`` is a single matrix entry, the trace is rewritten as an
        explicit ``Sum`` of diagonal entries and that sum is differentiated.
        Otherwise the trace is evaluated with :meth:`doit` and the derivative
        is delegated to the evaluated expression.

        Raises
        ======

        NotImplementedError
            If :meth:`doit` cannot evaluate the trace any further.
        """
        from sympy.concrete.summations import Sum
        from .matexpr import MatrixElement
        if isinstance(v, MatrixElement):
            return self.rewrite(Sum).diff(v)
        expr = self.doit()
        if isinstance(expr, Trace):
            # doit() left the trace unevaluated; delegating to
            # expr._eval_derivative would re-enter this method forever.
            raise NotImplementedError
        return expr._eval_derivative(v)

    def _eval_derivative_matrix_lines(self, x):
        """Express the trace of the argument's derivative lines as array contractions.

        Each line returned by the argument's own
        ``_eval_derivative_matrix_lines`` is mutated in place.
        ``lr.higher`` becomes an ``ArrayContraction`` over the tensor product
        of the line's two endpoints (plus the previous ``higher`` part, when
        there is one), and the endpoints are then reset to ``S.One``.
        """
        from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction
        r = self.args[0]._eval_derivative_matrix_lines(x)
        for lr in r:
            if lr.higher == 1:
                # Contract index 1 of the first line with index 1 of the
                # second line (index 3 of the product). This presumes both
                # lines are rank 2; confirm against the line objects.
                lr.higher = ExprBuilder(
                    ArrayContraction,
                    [
                        ExprBuilder(
                            ArrayTensorProduct,
                            [
                                lr._lines[0],
                                lr._lines[1],
                            ]
                        ),
                        (1, 3),
                    ],
                    validator=ArrayContraction._validate
                )
            else:
                # This is not a matrix line:
                lr.higher = ExprBuilder(
                    ArrayContraction,
                    [
                        ExprBuilder(
                            ArrayTensorProduct,
                            [
                                lr._lines[0],
                                lr._lines[1],
                                lr.higher,
                            ]
                        ),
                        (1, 3), (0, 2)
                    ]
                )
            # The endpoints are now absorbed into ``higher``; reset the line
            # to scalar ones and repoint both pointers at the new list.
            lr._lines = [S.One, S.One]
            lr._first_pointer_parent = lr._lines
            lr._second_pointer_parent = lr._lines
            lr._first_pointer_index = 0
            lr._second_pointer_index = 1
        return r

    @property
    def arg(self):
        """The matrix whose trace is represented."""
        return self.args[0]

    def doit(self, **kwargs):
        """Evaluate the trace as far as possible.

        With ``deep=True`` (the default), the argument is evaluated first and
        then asked for its trace through ``_eval_trace``. An unevaluated
        ``Trace`` of the evaluated argument is returned in three cases:

        * the argument has no ``_eval_trace``;
        * ``_eval_trace`` raises ``NotImplementedError``;
        * ``_eval_trace`` returns ``None``.

        With ``deep=False``, only explicit matrices (``MatrixBase``) are
        traced; any other argument is left inside an unevaluated ``Trace``.
        """
        if kwargs.get('deep', True):
            arg = self.arg.doit(**kwargs)
            try:
                result = arg._eval_trace()
            except (AttributeError, NotImplementedError):
                result = None
            # An _eval_trace that returns None instead of raising must not
            # make doit() itself return None; fall back to an unevaluated
            # Trace in that case too.
            if result is None:
                return Trace(arg)
            return result
        # _eval_trace would go too deep here
        if isinstance(self.arg, MatrixBase):
            return trace(self.arg)
        return Trace(self.arg)

    def as_explicit(self):
        """Return the evaluated trace of the argument's explicit matrix form."""
        return Trace(self.arg.as_explicit()).doit()

    def _normalize(self):
        """Return a canonical form of the trace of a matrix product.

        Non-``MatMul`` arguments are returned unchanged (``self``).
        """
        # Normalization of trace of matrix products. Use transposition and
        # cyclic properties of traces to make sure the arguments of the matrix
        # product are sorted and the first argument is not a transposition.
        from sympy.matrices.expressions.matmul import MatMul
        from sympy.matrices.expressions.transpose import Transpose
        trace_arg = self.arg
        if isinstance(trace_arg, MatMul):

            def get_arg_key(x):
                # Sort factors by their untransposed form, so that A and A.T
                # compete for the same position.
                a = trace_arg.args[x]
                if isinstance(a, Transpose):
                    a = a.arg
                return default_sort_key(a)

            indmin = min(range(len(trace_arg.args)), key=get_arg_key)
            if isinstance(trace_arg.args[indmin], Transpose):
                # tr(M) == tr(M.T): transpose the whole product so the
                # smallest factor is no longer a transposition.
                trace_arg = Transpose(trace_arg).doit()
                indmin = min(range(len(trace_arg.args)), key=lambda x: default_sort_key(trace_arg.args[x]))
            # Cyclic invariance: rotate so the smallest factor comes first.
            trace_arg = MatMul.fromiter(trace_arg.args[indmin:] + trace_arg.args[:indmin])
            return Trace(trace_arg)
        return self

    def _eval_rewrite_as_Sum(self, expr, **kwargs):
        """Rewrite the trace as the sum of diagonal entries, with ``doit()`` applied to the ``Sum``."""
        from sympy.concrete.summations import Sum
        i = Dummy('i')
        return Sum(self.arg[i, i], (i, 0, self.arg.rows-1)).doit()
def trace(expr):
    """Trace of a Matrix. Sum of the diagonal elements.

    This is the evaluating counterpart of ``Trace``: it wraps ``expr`` in an
    unevaluated ``Trace`` and immediately calls ``doit()`` on it. Any error
    raised by ``Trace`` for non-matrix or non-square input propagates
    unchanged.

    Examples
    ========

    >>> from sympy import trace, Symbol, MatrixSymbol, eye
    >>> n = Symbol('n')
    >>> X = MatrixSymbol('X', n, n)  # A square matrix
    >>> trace(2*X)
    2*Trace(X)
    >>> trace(eye(3))
    3
    """
    unevaluated = Trace(expr)
    return unevaluated.doit()
|