123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467 |
- from sympy.assumptions.ask import ask, Q
- from sympy.assumptions.refine import handlers_dict
- from sympy.core import Basic, sympify, S
- from sympy.core.mul import mul, Mul
- from sympy.core.numbers import Number, Integer
- from sympy.core.symbol import Dummy
- from sympy.functions import adjoint
- from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
- do_one, new)
- from sympy.matrices.common import ShapeError, NonInvertibleMatrixError
- from sympy.matrices.matrices import MatrixBase
- from .inverse import Inverse
- from .matexpr import MatrixExpr
- from .matpow import MatPow
- from .transpose import transpose
- from .permutation import PermutationMatrix
- from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix
# XXX: MatMul should perhaps not subclass directly from Mul
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Non-matrix (scalar) arguments are allowed and together form the scalar
    coefficient of the product.  With ``check=True`` (the default) the
    shapes of consecutive matrix factors are validated; with
    ``evaluate=True`` the product is put into canonical form.

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    identity = GenericIdentity()

    def __new__(cls, *args, evaluate=False, check=True, _sympify=True):
        if not args:
            return cls.identity

        # Identities have to be dropped eagerly: left in place, a
        # GenericIdentity would make GenericIdentity().shape raise TypeError.
        args = [a for a in args if cls.identity != a]
        if _sympify:
            args = [sympify(a) for a in args]

        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()

        if check:
            validate(*matrices)

        if not matrices:
            # No matrix factors left: the product is just its coefficient.
            # (Open question: should this rather be
            # ``Basic.__new__(cls, factor, GenericIdentity())``?)
            return factor

        return canonicalize(obj) if evaluate else obj

    @property
    def shape(self):
        mats = [a for a in self.args if a.is_Matrix]
        first, last = mats[0], mats[-1]
        return (first.rows, last.cols)

    def could_extract_minus_sign(self):
        leading = self.args[0]
        return leading.could_extract_minus_sign()

    def _entry(self, i, j, expand=True, **kwargs):
        # Imported locally to avoid cyclic imports.
        from sympy.concrete.summations import Sum
        from sympy.matrices.immutable import ImmutableMatrix

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # e.g. 2*X: the entry is coeff * X[i, j]
            return coeff * matrices[0][i, j]

        def fresh_dummies():
            n = 1
            while True:
                yield Dummy("i_%i" % n)
                n += 1

        dummy_generator = kwargs.get("dummy_generator", fresh_dummies())

        # The outer indices are the requested (i, j); each link between two
        # consecutive factors gets a fresh summation dummy.
        links = [next(dummy_generator) for _ in range(len(matrices) - 1)]
        indices = [i, *links, j]
        upper_bounds = [m.shape[1] - 1 for m in matrices[:-1]]

        entries = [
            m._entry(indices[k], indices[k + 1], dummy_generator=dummy_generator)
            for k, m in enumerate(matrices)
        ]
        expr_in_sum = Mul.fromiter(entries)

        if any(e.has(ImmutableMatrix) for e in entries):
            expand = True

        limits = [(idx, 0, hi) for idx, hi in zip(links, upper_bounds)]
        result = coeff * Sum(expr_in_sum, *limits)

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(hi, (Integer, int)) for hi in upper_bounds):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        scalars = []
        matrices = []
        for a in self.args:
            if a.is_Matrix:
                matrices.append(a)
            else:
                scalars.append(a)

        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        coeff, matrices = self.as_coeff_matrices()
        mmul = MatMul(*matrices)
        return coeff, mmul

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        transposed = [transpose(m) for m in reversed(matrices)]
        return MatMul(coeff, *transposed).doit()

    def _eval_adjoint(self):
        adjoints = [adjoint(a) for a in reversed(self.args)]
        return MatMul(*adjoints).doit()

    def _eval_trace(self):
        factor, mmul = self.as_coeff_mmul()
        if factor == 1:
            raise NotImplementedError("Can't simplify any further")
        from .trace import trace
        return factor * trace(mmul.doit())

    def _eval_determinant(self):
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        dets = [Determinant(block) for block in only_squares(*matrices)]
        return factor**self.rows * Mul(*dets)

    def _eval_inverse(self):
        def invert(a):
            return a.inverse() if isinstance(a, MatrixExpr) else a**-1

        try:
            return MatMul(*[invert(a) for a in reversed(self.args)]).doit()
        except ShapeError:
            return Inverse(self)

    def doit(self, **kwargs):
        deep = kwargs.get('deep', True)
        args = [a.doit(**kwargs) for a in self.args] if deep else self.args
        # treat scalar*MatrixSymbol or scalar*MatPow separately
        return canonicalize(MatMul(*args))

    # Needed for partial compatibility with Mul
    def args_cnc(self, **kwargs):
        commutative = [a for a in self.args if a.is_commutative]
        noncommutative = [a for a in self.args if not a.is_commutative]
        return [commutative, noncommutative]

    def _eval_derivative_matrix_lines(self, x):
        from .transpose import Transpose
        lines = []
        for ind, factor in enumerate(self.args):
            if not factor.has(x):
                continue

            left_args = self.args[:ind]
            right_args = self.args[ind + 1:]

            right_mat = (MatMul.fromiter(right_args) if right_args
                         else Identity(self.shape[1]))

            if left_args:
                left_rev = MatMul.fromiter([
                    Transpose(a).doit() if a.is_Matrix else a
                    for a in reversed(left_args)
                ])
            else:
                left_rev = Identity(self.shape[0])

            for line in factor._eval_derivative_matrix_lines(x):
                line.append_first(left_rev)
                line.append_second(right_mat)
                lines.append(line)

        return lines


mul.register_handlerclass((Mul, MatMul), MatMul)
def validate(*matrices):
    """ Checks for valid shapes for args of MatMul """
    # Each adjacent pair must agree on the shared (inner) dimension.
    for left, right in zip(matrices, matrices[1:]):
        if left.cols != right.rows:
            raise ShapeError("Matrices %s and %s are not aligned" % (left, right))
# Rules


def newmul(*args):
    """Build a ``MatMul`` from *args*, dropping a leading coefficient of 1.

    The rules in this module call ``newmul(factor, *matrices)``; when
    ``factor`` equals 1 it is omitted so the rebuilt product carries no
    redundant coefficient.  An empty call is tolerated and simply forwards
    no arguments instead of failing with an IndexError.
    """
    if args and args[0] == 1:
        args = args[1:]
    return new(MatMul, *args)
def any_zeros(mul):
    """Return a ZeroMatrix of the product's shape if any argument is zero.

    An argument counts as zero when its ``is_zero`` is truthy or when it is
    a matrix whose ``is_ZeroMatrix`` is truthy.  Otherwise *mul* is returned
    unchanged.
    """
    def is_zero_factor(a):
        return a.is_zero or (a.is_Matrix and a.is_ZeroMatrix)

    if not any(is_zero_factor(a) for a in mul.args):
        return mul

    mats = [a for a in mul.args if a.is_Matrix]
    return ZeroMatrix(mats[0].rows, mats[-1].cols)
def merge_explicit(matmul):
    """ Merge explicit MatrixBase arguments

    Runs of adjacent explicit matrices (and numbers) are multiplied out;
    symbolic factors interrupt a run.

    Examples
    ========

    >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
    >>> from sympy.matrices.expressions.matmul import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X))
      [4  6]
    A*[    ]
      [4  6]

    >>> X = MatMul(B, A, C)
    >>> pprint(X)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X))
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    """
    if all(not isinstance(a, MatrixBase) for a in matmul.args):
        return matmul

    mergeable = (MatrixBase, Number)
    merged = []
    current = matmul.args[0]
    for nxt in matmul.args[1:]:
        if isinstance(current, mergeable) and isinstance(nxt, mergeable):
            # Both sides are concrete: multiply them out right away.
            current = current * nxt
        else:
            merged.append(current)
            current = nxt
    merged.append(current)

    return MatMul(*merged)
def remove_ids(mul):
    """ Remove Identities from a MatMul

    This is a modified version of sympy.strategies.rm_id.
    This is necessary because MatMul may contain both MatrixExprs and Exprs
    as args.

    See Also
    ========

    sympy.strategies.rm_id
    """
    # Split off the scalar coefficient so rm_id only sees the matrix factors.
    factor, mmul = mul.as_coeff_mmul()
    strip_identities = rm_id(lambda x: x.is_Identity is True)
    without_ids = strip_identities(mmul)
    if without_ids == mmul:
        return mul
    # Re-attach the coefficient to the surviving matrix factors.
    return newmul(factor, *without_ids.args)
def factor_in_front(mul):
    """Rebuild the product with its scalar coefficient as the first argument.

    When the coefficient is 1, *mul* is returned untouched.
    """
    factor, matrices = mul.as_coeff_matrices()
    if factor == 1:
        return mul
    return newmul(factor, *matrices)
def combine_powers(mul):
    r"""Combine consecutive powers with the same base into one, e.g.
    $$A \times A^2 \Rightarrow A^3$$

    This also cancels out the possible matrix inverses using the
    knowledgebase of :class:`~.Inverse`, e.g.,
    $$ Y \times X \times X^{-1} \Rightarrow Y $$
    """
    def split_power(M):
        # (base, exponent) of M; anything that is not a MatPow is M**1.
        if isinstance(M, MatPow):
            return M.args
        return M, S.One

    def fuse(A, B):
        # Return A*B collapsed into a single power, or None if not possible.
        if A.is_square == False or B.is_square == False:
            return None

        a_base, a_exp = split_power(A)
        b_base, b_exp = split_power(B)

        if a_base == b_base:
            return MatPow(a_base, a_exp + b_exp).doit(deep=False)

        if isinstance(b_base, MatrixBase):
            return None

        try:
            b_inv = b_base.inverse()
        except NonInvertibleMatrixError:
            return None

        # A**a * B**b with A == B**-1 is A**(a - b).
        if b_inv is not None and a_base == b_inv:
            return MatPow(a_base, a_exp - b_exp).doit(deep=False)
        return None

    factor, matrices = mul.as_coeff_matrices()
    combined = [matrices[0]]
    for B in matrices[1:]:
        fused = fuse(combined[-1], B)
        if fused is None:
            combined.append(B)
        else:
            combined[-1] = fused

    return newmul(factor, *combined)
def combine_permutations(mul):
    """Refine products of permutation matrices as the products of cycles.
    """
    factors = mul.args
    if len(factors) < 2:
        return mul

    merged = [factors[0]]
    for right in factors[1:]:
        left = merged[-1]
        if isinstance(left, PermutationMatrix) and isinstance(right, PermutationMatrix):
            # Compose the underlying permutations into one PermutationMatrix.
            merged[-1] = PermutationMatrix(left.args[0] * right.args[0])
        else:
            merged.append(right)

    return MatMul(*merged)
def combine_one_matrices(mul):
    """
    Combine products of OneMatrix

    e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)
    """
    factor, matrices = mul.as_coeff_matrices()
    merged = [matrices[0]]
    for right in matrices[1:]:
        left = merged[-1]
        if isinstance(left, OneMatrix) and isinstance(right, OneMatrix):
            # ones(m, k) * ones(k, n) == k * ones(m, n)
            merged[-1] = OneMatrix(left.shape[0], right.shape[1])
            factor = factor * left.shape[1]
        else:
            merged.append(right)

    return newmul(factor, *merged)
def distribute_monom(mul):
    """
    Simplify MatMul expressions by distributing
    rational term to MatMul.

    e.g. 2*(A+B) -> 2*A + 2*B
    """
    args = mul.args
    if len(args) != 2:
        return mul

    from .matadd import MatAdd
    first, second = args
    if first.is_MatAdd and second.is_Rational:
        return MatAdd(*[MatMul(term, second).doit() for term in first.args])
    if second.is_MatAdd and first.is_Rational:
        return MatAdd(*[MatMul(first, term).doit() for term in second.args])
    return mul
# Rewrite rules for MatMul nodes, combined by ``canonicalize`` below using
# the sympy.strategies combinators ``do_one``, ``typed`` and ``exhaust``.
rules = (
    distribute_monom,
    any_zeros,
    remove_ids,
    combine_one_matrices,
    combine_powers,
    unpack,
    rm_id(lambda x: x == 1),
    merge_explicit,
    factor_in_front,
    flatten,
    combine_permutations,
)

canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
def only_squares(*matrices):
    """factor matrices only if they are square"""
    if matrices[0].rows != matrices[-1].cols:
        raise RuntimeError("Invalid matrices being multiplied")

    # Greedily close a group as soon as its running product is square.
    blocks = []
    start = 0
    for end, M in enumerate(matrices, start=1):
        if M.cols == matrices[start].rows:
            blocks.append(MatMul(*matrices[start:end]).doit())
            start = end
    return blocks
def refine_MatMul(expr, assumptions):
    """Refine a MatMul by cancelling adjacent factors whose product is I.

    Two neighbouring matrix factors ``M, N`` are replaced by an Identity when

    * ``N == M.T`` and ``N`` is orthogonal under *assumptions*, or
    * ``N == adjoint(M)`` (conjugate transpose) and ``N`` is unitary.

    Scalar arguments are kept and placed in front of the matrix factors.

    >>> from sympy import MatrixSymbol, Q, assuming, refine
    >>> X = MatrixSymbol('X', 2, 2)
    >>> expr = X * X.T
    >>> print(expr)
    X*X.T
    >>> with assuming(Q.orthogonal(X)):
    ...     print(refine(expr))
    I
    """
    newargs = []    # scalars first, then the (possibly cancelled) matrices
    exprargs = []   # matrix factors in their original order
    for arg in expr.args:
        if arg.is_Matrix:
            exprargs.append(arg)
        else:
            newargs.append(arg)

    last = exprargs[0]
    for arg in exprargs[1:]:
        if arg == last.T and ask(Q.orthogonal(arg), assumptions):
            last = Identity(arg.shape[0])
        # A unitary matrix is cancelled by its conjugate *transpose*, not by
        # its elementwise conjugate: for the real unitary [[0, 1], [-1, 0]],
        # M*conjugate(M) == -I, while M*adjoint(M) == I.
        elif arg == adjoint(last) and ask(Q.unitary(arg), assumptions):
            last = Identity(arg.shape[0])
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)


handlers_dict['MatMul'] = refine_MatMul
|