matmul.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. from sympy.assumptions.ask import ask, Q
  2. from sympy.assumptions.refine import handlers_dict
  3. from sympy.core import Basic, sympify, S
  4. from sympy.core.mul import mul, Mul
  5. from sympy.core.numbers import Number, Integer
  6. from sympy.core.symbol import Dummy
  7. from sympy.functions import adjoint
  8. from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
  9. do_one, new)
  10. from sympy.matrices.common import ShapeError, NonInvertibleMatrixError
  11. from sympy.matrices.matrices import MatrixBase
  12. from .inverse import Inverse
  13. from .matexpr import MatrixExpr
  14. from .matpow import MatPow
  15. from .transpose import transpose
  16. from .permutation import PermutationMatrix
  17. from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix
# XXX: MatMul should perhaps not subclass directly from Mul
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Scalar (non-matrix) factors may appear among the arguments next to the
    matrix factors; ``as_coeff_matrices`` separates the two groups.

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    # Neutral element of the product: returned for an empty call and
    # filtered out of the arguments in __new__.
    identity = GenericIdentity()

    def __new__(cls, *args, evaluate=False, check=True, _sympify=True):
        """Create a matrix product from ``args``.

        Parameters
        ==========

        evaluate : bool
            If True, the new product is passed through ``canonicalize``.
        check : bool
            If True, ``validate`` checks that consecutive matrix factors
            have aligned shapes (raises ``ShapeError`` otherwise).
        _sympify : bool
            If True, every argument is sympified first.

        Returns
        =======

        ``cls.identity`` when called without arguments; the scalar
        coefficient alone when no matrix factor remains; otherwise the
        ``MatMul`` (canonicalized when ``evaluate`` is True).
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        args = list(filter(lambda i: cls.identity != i, args))
        if _sympify:
            args = list(map(sympify, args))
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()

        if check:
            validate(*matrices)

        if not matrices:
            # Should it be
            #
            # return Basic.__new__(cls, factor, GenericIdentity()) ?
            return factor

        if evaluate:
            return canonicalize(obj)

        return obj

    @property
    def shape(self):
        # Rows of the first matrix factor by columns of the last one; scalar
        # arguments are skipped.
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def could_extract_minus_sign(self):
        # Delegates to the first argument only. NOTE(review): this relies on
        # a scalar coefficient (if any) being stored first -- confirm.
        return self.args[0].could_extract_minus_sign()

    def _entry(self, i, j, expand=True, **kwargs):
        """Return entry ``(i, j)`` of the product as an explicit sum.

        For ``coeff*M_1*M_2*...*M_k`` this builds
        ``coeff*Sum(M_1[i, d_1]*M_2[d_1, d_2]*...*M_k[d_{k-1}, j])`` where
        each dummy ``d_m`` runs from 0 to ``M_m.shape[1] - 1``.

        ``expand`` requests ``doit()`` on the resulting ``Sum``. It is forced
        on when an entry contains an ``ImmutableMatrix`` and forced off when
        none of the summation bounds is an integer.
        ``dummy_generator`` (keyword) is an iterator that supplies the Dummy
        summation indices. It is passed down to the factors' ``_entry`` so
        that nested products draw from the same source. It defaults to fresh
        ``i_1, i_2, ...``.
        """
        # Avoid cyclic imports
        from sympy.concrete.summations import Sum
        from sympy.matrices.immutable import ImmutableMatrix

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        # indices = [i, d_1, ..., d_{k-1}, j]; ind_ranges[m] is the upper
        # bound of the m-th inner dummy index.
        indices = [None]*(len(matrices) + 1)
        ind_ranges = [None]*(len(matrices) - 1)
        indices[0] = i
        indices[-1] = j

        def f():
            counter = 1
            while True:
                yield Dummy("i_%i" % counter)
                counter += 1

        dummy_generator = kwargs.get("dummy_generator", f())

        # NOTE: ``i`` is reused as a loop variable below; the row index has
        # already been stored in indices[0].
        for i in range(1, len(matrices)):
            indices[i] = next(dummy_generator)

        for i, arg in enumerate(matrices[:-1]):
            ind_ranges[i] = arg.shape[1] - 1
        # From here on ``matrices`` holds the scalar entries of each factor.
        matrices = [arg._entry(indices[i], indices[i+1], dummy_generator=dummy_generator) for i, arg in enumerate(matrices)]
        expr_in_sum = Mul.fromiter(matrices)
        if any(v.has(ImmutableMatrix) for v in matrices):
            expand = True
        result = coeff*Sum(
                expr_in_sum,
                *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
            )

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the arguments into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of all non-matrix arguments; ``matrices``
        lists the matrix arguments in their original order.

        Raises ``NotImplementedError`` if the coefficient is known to be
        noncommutative.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        """Like ``as_coeff_matrices`` but with the matrices wrapped in a MatMul."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        return MatMul(
            coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()

    def _eval_adjoint(self):
        # (A B)^H = B^H A^H: reverse every argument (scalars included) and
        # take the adjoint of each one.
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        # Only a scalar coefficient can be pulled out of the trace; without
        # one there is nothing to simplify here.
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())
        else:
            raise NotImplementedError("Can't simplify any further")

    def _eval_determinant(self):
        # det(c*M) = c**n * det(M) for an n x n product. The matrix factors
        # are grouped into square sub-products by ``only_squares``, and the
        # result is the product of their determinants.
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        # (A B)^-1 = B^-1 A^-1; scalar arguments are inverted with **-1.
        # If any factor cannot be inverted because of its shape, fall back to
        # an unevaluated Inverse of the whole product.
        try:
            return MatMul(*[
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                    for arg in self.args[::-1]]).doit()
        except ShapeError:
            return Inverse(self)

    def doit(self, **kwargs):
        """Evaluate the arguments (unless ``deep=False``) and canonicalize."""
        deep = kwargs.get('deep', True)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args

        # Rebuild the product unevaluated, then apply the canonicalization
        # rules (see ``canonicalize`` / ``rules``).
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, **kwargs):
        # Returns [commutative args, noncommutative args] as two lists, in
        # argument order. Extra keyword arguments are accepted and ignored.
        coeff_c = [x for x in self.args if x.is_commutative]
        coeff_nc = [x for x in self.args if not x.is_commutative]
        return [coeff_c, coeff_nc]

    def _eval_derivative_matrix_lines(self, x):
        """Derivative "lines" of the product with respect to ``x``.

        Applies the product rule. For each argument that contains ``x``, it
        takes that argument's own derivative lines and attaches:

        * the transposed, reversed factors to its left (``append_first``);
        * the factors to its right (``append_second``).

        An identity matrix is used when a side is empty.
        NOTE(review): the line objects come from the arguments'
        ``_eval_derivative_matrix_lines``; their class is not visible here.
        """
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind+1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines
# Register MatMul as the handler class for the (Mul, MatMul) class pair in
# the ``mul`` dispatcher -- presumably so that multiplying a Mul with a
# MatMul builds a MatMul; see sympy.core.mul to confirm.
mul.register_handlerclass((Mul, MatMul), MatMul)
  175. def validate(*matrices):
  176. """ Checks for valid shapes for args of MatMul """
  177. for i in range(len(matrices)-1):
  178. A, B = matrices[i:i+2]
  179. if A.cols != B.rows:
  180. raise ShapeError("Matrices %s and %s are not aligned"%(A, B))
  181. # Rules
  182. def newmul(*args):
  183. if args[0] == 1:
  184. args = args[1:]
  185. return new(MatMul, *args)
  186. def any_zeros(mul):
  187. if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)
  188. for arg in mul.args):
  189. matrices = [arg for arg in mul.args if arg.is_Matrix]
  190. return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
  191. return mul
  192. def merge_explicit(matmul):
  193. """ Merge explicit MatrixBase arguments
  194. >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
  195. >>> from sympy.matrices.expressions.matmul import merge_explicit
  196. >>> A = MatrixSymbol('A', 2, 2)
  197. >>> B = Matrix([[1, 1], [1, 1]])
  198. >>> C = Matrix([[1, 2], [3, 4]])
  199. >>> X = MatMul(A, B, C)
  200. >>> pprint(X)
  201. [1 1] [1 2]
  202. A*[ ]*[ ]
  203. [1 1] [3 4]
  204. >>> pprint(merge_explicit(X))
  205. [4 6]
  206. A*[ ]
  207. [4 6]
  208. >>> X = MatMul(B, A, C)
  209. >>> pprint(X)
  210. [1 1] [1 2]
  211. [ ]*A*[ ]
  212. [1 1] [3 4]
  213. >>> pprint(merge_explicit(X))
  214. [1 1] [1 2]
  215. [ ]*A*[ ]
  216. [1 1] [3 4]
  217. """
  218. if not any(isinstance(arg, MatrixBase) for arg in matmul.args):
  219. return matmul
  220. newargs = []
  221. last = matmul.args[0]
  222. for arg in matmul.args[1:]:
  223. if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)):
  224. last = last * arg
  225. else:
  226. newargs.append(last)
  227. last = arg
  228. newargs.append(last)
  229. return MatMul(*newargs)
  230. def remove_ids(mul):
  231. """ Remove Identities from a MatMul
  232. This is a modified version of sympy.strategies.rm_id.
  233. This is necesssary because MatMul may contain both MatrixExprs and Exprs
  234. as args.
  235. See Also
  236. ========
  237. sympy.strategies.rm_id
  238. """
  239. # Separate Exprs from MatrixExprs in args
  240. factor, mmul = mul.as_coeff_mmul()
  241. # Apply standard rm_id for MatMuls
  242. result = rm_id(lambda x: x.is_Identity is True)(mmul)
  243. if result != mmul:
  244. return newmul(factor, *result.args) # Recombine and return
  245. else:
  246. return mul
  247. def factor_in_front(mul):
  248. factor, matrices = mul.as_coeff_matrices()
  249. if factor != 1:
  250. return newmul(factor, *matrices)
  251. return mul
  252. def combine_powers(mul):
  253. r"""Combine consecutive powers with the same base into one, e.g.
  254. $$A \times A^2 \Rightarrow A^3$$
  255. This also cancels out the possible matrix inverses using the
  256. knowledgebase of :class:`~.Inverse`, e.g.,
  257. $$ Y \times X \times X^{-1} \Rightarrow Y $$
  258. """
  259. factor, args = mul.as_coeff_matrices()
  260. new_args = [args[0]]
  261. for B in args[1:]:
  262. A = new_args[-1]
  263. if A.is_square == False or B.is_square == False:
  264. new_args.append(B)
  265. continue
  266. if isinstance(A, MatPow):
  267. A_base, A_exp = A.args
  268. else:
  269. A_base, A_exp = A, S.One
  270. if isinstance(B, MatPow):
  271. B_base, B_exp = B.args
  272. else:
  273. B_base, B_exp = B, S.One
  274. if A_base == B_base:
  275. new_exp = A_exp + B_exp
  276. new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
  277. continue
  278. elif not isinstance(B_base, MatrixBase):
  279. try:
  280. B_base_inv = B_base.inverse()
  281. except NonInvertibleMatrixError:
  282. B_base_inv = None
  283. if B_base_inv is not None and A_base == B_base_inv:
  284. new_exp = A_exp - B_exp
  285. new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
  286. continue
  287. new_args.append(B)
  288. return newmul(factor, *new_args)
  289. def combine_permutations(mul):
  290. """Refine products of permutation matrices as the products of cycles.
  291. """
  292. args = mul.args
  293. l = len(args)
  294. if l < 2:
  295. return mul
  296. result = [args[0]]
  297. for i in range(1, l):
  298. A = result[-1]
  299. B = args[i]
  300. if isinstance(A, PermutationMatrix) and \
  301. isinstance(B, PermutationMatrix):
  302. cycle_1 = A.args[0]
  303. cycle_2 = B.args[0]
  304. result[-1] = PermutationMatrix(cycle_1 * cycle_2)
  305. else:
  306. result.append(B)
  307. return MatMul(*result)
  308. def combine_one_matrices(mul):
  309. """
  310. Combine products of OneMatrix
  311. e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)
  312. """
  313. factor, args = mul.as_coeff_matrices()
  314. new_args = [args[0]]
  315. for B in args[1:]:
  316. A = new_args[-1]
  317. if not isinstance(A, OneMatrix) or not isinstance(B, OneMatrix):
  318. new_args.append(B)
  319. continue
  320. new_args.pop()
  321. new_args.append(OneMatrix(A.shape[0], B.shape[1]))
  322. factor *= A.shape[1]
  323. return newmul(factor, *new_args)
  324. def distribute_monom(mul):
  325. """
  326. Simplify MatMul expressions but distributing
  327. rational term to MatMul.
  328. e.g. 2*(A+B) -> 2*A + 2*B
  329. """
  330. args = mul.args
  331. if len(args) == 2:
  332. from .matadd import MatAdd
  333. if args[0].is_MatAdd and args[1].is_Rational:
  334. return MatAdd(*[MatMul(mat, args[1]).doit() for mat in args[0].args])
  335. if args[1].is_MatAdd and args[0].is_Rational:
  336. return MatAdd(*[MatMul(args[0], mat).doit() for mat in args[1].args])
  337. return mul
  338. rules = (
  339. distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1),
  340. merge_explicit, factor_in_front, flatten, combine_permutations)
  341. canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
  342. def only_squares(*matrices):
  343. """factor matrices only if they are square"""
  344. if matrices[0].rows != matrices[-1].cols:
  345. raise RuntimeError("Invalid matrices being multiplied")
  346. out = []
  347. start = 0
  348. for i, M in enumerate(matrices):
  349. if M.cols == matrices[start].rows:
  350. out.append(MatMul(*matrices[start:i+1]).doit())
  351. start = i+1
  352. return out
  353. def refine_MatMul(expr, assumptions):
  354. """
  355. >>> from sympy import MatrixSymbol, Q, assuming, refine
  356. >>> X = MatrixSymbol('X', 2, 2)
  357. >>> expr = X * X.T
  358. >>> print(expr)
  359. X*X.T
  360. >>> with assuming(Q.orthogonal(X)):
  361. ... print(refine(expr))
  362. I
  363. """
  364. newargs = []
  365. exprargs = []
  366. for args in expr.args:
  367. if args.is_Matrix:
  368. exprargs.append(args)
  369. else:
  370. newargs.append(args)
  371. last = exprargs[0]
  372. for arg in exprargs[1:]:
  373. if arg == last.T and ask(Q.orthogonal(arg), assumptions):
  374. last = Identity(arg.shape[0])
  375. elif arg == last.conjugate() and ask(Q.unitary(arg), assumptions):
  376. last = Identity(arg.shape[0])
  377. else:
  378. newargs.append(last)
  379. last = arg
  380. newargs.append(last)
  381. return MatMul(*newargs)
# Register refine_MatMul under the 'MatMul' key of the refine handler
# table (presumably consulted by ``refine`` for MatMul expressions -- see
# sympy.assumptions.refine to confirm).
handlers_dict['MatMul'] = refine_MatMul