123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- from sympy.matrices.common import NonSquareMatrixError
- from .matexpr import MatrixExpr
- from .special import Identity
- from sympy.core import S
- from sympy.core.expr import ExprBuilder
- from sympy.core.cache import cacheit
- from sympy.core.power import Pow
- from sympy.core.sympify import _sympify
- from sympy.matrices import MatrixBase
class MatPow(MatrixExpr):
    """Unevaluated power ``base**exp`` of a square matrix expression.

    ``args`` is ``(base, exp)``; both are sympified on construction.  The
    expression stays unevaluated unless ``evaluate=True`` is passed to the
    constructor or :meth:`doit` is called.
    """

    def __new__(cls, base, exp, evaluate=False, **options):
        """Build ``MatPow(base, exp)``.

        Parameters
        ==========

        base : matrix
            Sympified; must have ``is_Matrix`` set.
        exp : expression
            Sympified exponent.
        evaluate : bool, optional
            When True, return ``doit(deep=False)`` of the new object instead
            of the bare ``MatPow``.
        **options
            Accepted for signature compatibility and ignored.

        Raises
        ======

        TypeError
            If ``base`` is not a matrix.
        NonSquareMatrixError
            If ``base`` is known to be non-square.
        """
        base = _sympify(base)
        if not base.is_Matrix:
            raise TypeError("MatPow base should be a matrix")
        # Reject only bases that are definitely non-square.  ``== False``
        # (rather than ``not``) keeps an undecidable ``is_square`` of None
        # from raising, while still matching both ``False`` and ``S.false``.
        if base.is_square == False:  # noqa: E712
            raise NonSquareMatrixError("Power of non-square matrix %s" % base)

        exp = _sympify(exp)
        obj = super().__new__(cls, base, exp)

        if evaluate:
            obj = obj.doit(deep=False)

        return obj

    @property
    def base(self):
        """The matrix being raised to a power (``args[0]``)."""
        return self.args[0]

    @property
    def exp(self):
        """The exponent (``args[1]``)."""
        return self.args[1]

    @property
    def shape(self):
        """Shape of the power, taken from :attr:`base`."""
        return self.base.shape

    @cacheit
    def _get_explicit_matrix(self):
        """Return ``base.as_explicit()**exp`` (memoized by ``cacheit``)."""
        return self.base.as_explicit()**self.exp

    def _entry(self, i, j, **kwargs):
        """Return the ``(i, j)`` element of this power.

        The power is first evaluated with :meth:`doit`.  If a ``MatPow``
        still remains, a positive integer exponent is expanded into a
        ``MatMul`` of repeated bases.  Otherwise, a concrete shape reads the
        entry from the explicit matrix power.  A symbolic shape returns an
        unevaluated ``MatrixElement`` of ``self``.
        """
        from sympy.matrices.expressions import MatMul
        A = self.doit()
        if isinstance(A, MatPow):
            # We still have a MatPow, make an explicit MatMul out of it.
            if A.exp.is_Integer and A.exp.is_positive:
                A = MatMul(*[A.base for _ in range(A.exp)])
            elif not self._is_shape_symbolic():
                return A._get_explicit_matrix()[i, j]
            else:
                # Leave the expression unevaluated:
                from sympy.matrices.expressions.matexpr import MatrixElement
                return MatrixElement(self, i, j)
        return A[i, j]

    def doit(self, **kwargs):
        """Evaluate the power as far as possible.

        With ``deep=True`` (the default), ``base`` and ``exp`` are evaluated
        first.  Nested powers are then folded, e.g. ``(A**2)**3 -> A**6``.

        - An explicit ``MatrixBase`` base delegates to its own ``**``.
        - Exponents 1, 0 and -1 give ``base``, ``Identity(base.rows)`` and
          ``Inverse(base).doit()`` respectively.
        - Otherwise the base's ``_eval_power`` hook is used if it exists.
        - If none of these apply, a fresh ``MatPow`` is returned.
        """
        if kwargs.get('deep', True):
            base, exp = [arg.doit(**kwargs) for arg in self.args]
        else:
            base, exp = self.args

        # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6
        # NOTE(review): folding is unconditional, but (B**a)**b == B**(a*b)
        # is only guaranteed for integer b -- confirm that folding
        # non-integer outer exponents (e.g. (A**2)**(1/2) -> A) is intended.
        while isinstance(base, MatPow):
            exp *= base.args[1]
            base = base.args[0]

        if isinstance(base, MatrixBase):
            # Delegate
            return base ** exp

        # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them
        if exp == S.One:
            return base
        if exp == S.Zero:
            return Identity(base.rows)
        if exp == S.NegativeOne:
            from sympy.matrices.expressions import Inverse
            return Inverse(base).doit(**kwargs)

        eval_power = getattr(base, '_eval_power', None)
        if eval_power is not None:
            return eval_power(exp)

        return MatPow(base, exp)

    def _eval_transpose(self):
        """Return ``(base.T)**exp``, the transpose of ``base**exp``."""
        base, exp = self.args
        return MatPow(base.T, exp)

    def _eval_derivative(self, x):
        """Differentiate with respect to ``x`` by reusing the scalar ``Pow`` rule.

        NOTE(review): this presumably relies on ``Pow._eval_derivative``
        only reading ``self.base`` / ``self.exp``, which ``MatPow`` provides
        -- confirm against the ``Pow`` implementation.
        """
        return Pow._eval_derivative(self, x)

    def _eval_derivative_matrix_lines(self, x):
        """Return the derivative "matrix lines" of this power w.r.t. ``x``.

        - A 1x1 base with an ``x``-free exponent is handled directly by the
          scalar rule ``exp*base**(exp - 1)``.
        - An integer ``exp > 0`` is expanded into a ``MatMul`` of repeated
          ``base`` factors, and that product is differentiated.
        - An integer ``exp < 0`` is expanded the same way, using
          ``Inverse(base)`` factors.
        - ``exp == -1`` is delegated to ``Inverse(base)``.
        - ``exp == 0`` is delegated to the evaluated expression.

        Raises
        ======

        NotImplementedError
            For any other exponent: symbolic, or a non-integer number.
        """
        from sympy.tensor.array.expressions.array_expressions import ArrayContraction
        from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
        from .matmul import MatMul
        from .inverse import Inverse
        exp = self.exp
        if self.base.shape == (1, 1) and not exp.has(x):
            lr = self.base._eval_derivative_matrix_lines(x)
            for i in lr:
                subexpr = ExprBuilder(
                    ArrayContraction,
                    [
                        ExprBuilder(
                            ArrayTensorProduct,
                            [
                                Identity(1),
                                i._lines[0],
                                exp*self.base**(exp-1),
                                i._lines[1],
                                Identity(1),
                            ]
                        ),
                        (0, 3, 4), (5, 7, 8)
                    ],
                    validator=ArrayContraction._validate
                )
                # Re-point both ends of the line at the two Identity(1)
                # factors (positions 0 and 4 of the tensor-product args).
                i._first_pointer_parent = subexpr.args[0].args
                i._first_pointer_index = 0
                i._second_pointer_parent = subexpr.args[0].args
                i._second_pointer_index = 4
                i._lines = [subexpr]
            return lr
        # Expansion into repeated factors needs a concrete integer count:
        # range() rejects non-integer numbers and symbols, which previously
        # surfaced as a TypeError instead of the NotImplementedError below.
        if exp.is_Integer and exp.is_positive:
            newexpr = MatMul.fromiter([self.base for _ in range(exp)])
        elif (exp == -1) == True:
            return Inverse(self.base)._eval_derivative_matrix_lines(x)
        elif exp.is_Integer and exp.is_negative:
            newexpr = MatMul.fromiter([Inverse(self.base) for _ in range(-exp)])
        elif (exp == 0) == True:
            return self.doit()._eval_derivative_matrix_lines(x)
        else:
            raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x))
        return newexpr._eval_derivative_matrix_lines(x)

    def _eval_inverse(self):
        """Return ``base**(-exp)`` as the inverse of ``base**exp``.

        NOTE(review): the result is only meaningful when ``base`` is
        invertible; no check is made here.
        """
        return MatPow(self.base, -self.exp)
|