matpow.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from sympy.matrices.common import NonSquareMatrixError
  2. from .matexpr import MatrixExpr
  3. from .special import Identity
  4. from sympy.core import S
  5. from sympy.core.expr import ExprBuilder
  6. from sympy.core.cache import cacheit
  7. from sympy.core.power import Pow
  8. from sympy.core.sympify import _sympify
  9. from sympy.matrices import MatrixBase
  10. class MatPow(MatrixExpr):
  11. def __new__(cls, base, exp, evaluate=False, **options):
  12. base = _sympify(base)
  13. if not base.is_Matrix:
  14. raise TypeError("MatPow base should be a matrix")
  15. if not base.is_square:
  16. raise NonSquareMatrixError("Power of non-square matrix %s" % base)
  17. exp = _sympify(exp)
  18. obj = super().__new__(cls, base, exp)
  19. if evaluate:
  20. obj = obj.doit(deep=False)
  21. return obj
  22. @property
  23. def base(self):
  24. return self.args[0]
  25. @property
  26. def exp(self):
  27. return self.args[1]
  28. @property
  29. def shape(self):
  30. return self.base.shape
  31. @cacheit
  32. def _get_explicit_matrix(self):
  33. return self.base.as_explicit()**self.exp
  34. def _entry(self, i, j, **kwargs):
  35. from sympy.matrices.expressions import MatMul
  36. A = self.doit()
  37. if isinstance(A, MatPow):
  38. # We still have a MatPow, make an explicit MatMul out of it.
  39. if A.exp.is_Integer and A.exp.is_positive:
  40. A = MatMul(*[A.base for k in range(A.exp)])
  41. elif not self._is_shape_symbolic():
  42. return A._get_explicit_matrix()[i, j]
  43. else:
  44. # Leave the expression unevaluated:
  45. from sympy.matrices.expressions.matexpr import MatrixElement
  46. return MatrixElement(self, i, j)
  47. return A[i, j]
  48. def doit(self, **kwargs):
  49. if kwargs.get('deep', True):
  50. base, exp = [arg.doit(**kwargs) for arg in self.args]
  51. else:
  52. base, exp = self.args
  53. # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6
  54. while isinstance(base, MatPow):
  55. exp *= base.args[1]
  56. base = base.args[0]
  57. if isinstance(base, MatrixBase):
  58. # Delegate
  59. return base ** exp
  60. # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them
  61. if exp == S.One:
  62. return base
  63. if exp == S.Zero:
  64. return Identity(base.rows)
  65. if exp == S.NegativeOne:
  66. from sympy.matrices.expressions import Inverse
  67. return Inverse(base).doit(**kwargs)
  68. eval_power = getattr(base, '_eval_power', None)
  69. if eval_power is not None:
  70. return eval_power(exp)
  71. return MatPow(base, exp)
  72. def _eval_transpose(self):
  73. base, exp = self.args
  74. return MatPow(base.T, exp)
  75. def _eval_derivative(self, x):
  76. return Pow._eval_derivative(self, x)
  77. def _eval_derivative_matrix_lines(self, x):
  78. from sympy.tensor.array.expressions.array_expressions import ArrayContraction
  79. from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
  80. from .matmul import MatMul
  81. from .inverse import Inverse
  82. exp = self.exp
  83. if self.base.shape == (1, 1) and not exp.has(x):
  84. lr = self.base._eval_derivative_matrix_lines(x)
  85. for i in lr:
  86. subexpr = ExprBuilder(
  87. ArrayContraction,
  88. [
  89. ExprBuilder(
  90. ArrayTensorProduct,
  91. [
  92. Identity(1),
  93. i._lines[0],
  94. exp*self.base**(exp-1),
  95. i._lines[1],
  96. Identity(1),
  97. ]
  98. ),
  99. (0, 3, 4), (5, 7, 8)
  100. ],
  101. validator=ArrayContraction._validate
  102. )
  103. i._first_pointer_parent = subexpr.args[0].args
  104. i._first_pointer_index = 0
  105. i._second_pointer_parent = subexpr.args[0].args
  106. i._second_pointer_index = 4
  107. i._lines = [subexpr]
  108. return lr
  109. if (exp > 0) == True:
  110. newexpr = MatMul.fromiter([self.base for i in range(exp)])
  111. elif (exp == -1) == True:
  112. return Inverse(self.base)._eval_derivative_matrix_lines(x)
  113. elif (exp < 0) == True:
  114. newexpr = MatMul.fromiter([Inverse(self.base) for i in range(-exp)])
  115. elif (exp == 0) == True:
  116. return self.doit()._eval_derivative_matrix_lines(x)
  117. else:
  118. raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x))
  119. return newexpr._eval_derivative_matrix_lines(x)
  120. def _eval_inverse(self):
  121. return MatPow(self.base, -self.exp)