matadd.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from functools import reduce
  2. import operator
  3. from sympy.core import Basic, sympify
  4. from sympy.core.add import add, Add, _could_extract_minus_sign
  5. from sympy.core.sorting import default_sort_key
  6. from sympy.functions import adjoint
  7. from sympy.matrices.common import ShapeError
  8. from sympy.matrices.matrices import MatrixBase
  9. from sympy.matrices.expressions.transpose import transpose
  10. from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
  11. exhaust, do_one, glom)
  12. from sympy.matrices.expressions.matexpr import MatrixExpr
  13. from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix
  14. from sympy.utilities import sift
  15. # XXX: MatAdd should perhaps not subclass directly from Add
  16. class MatAdd(MatrixExpr, Add):
  17. """A Sum of Matrix Expressions
  18. MatAdd inherits from and operates like SymPy Add
  19. Examples
  20. ========
  21. >>> from sympy import MatAdd, MatrixSymbol
  22. >>> A = MatrixSymbol('A', 5, 5)
  23. >>> B = MatrixSymbol('B', 5, 5)
  24. >>> C = MatrixSymbol('C', 5, 5)
  25. >>> MatAdd(A, B, C)
  26. A + B + C
  27. """
  28. is_MatAdd = True
  29. identity = GenericZeroMatrix()
  30. def __new__(cls, *args, evaluate=False, check=False, _sympify=True):
  31. if not args:
  32. return cls.identity
  33. # This must be removed aggressively in the constructor to avoid
  34. # TypeErrors from GenericZeroMatrix().shape
  35. args = list(filter(lambda i: cls.identity != i, args))
  36. if _sympify:
  37. args = list(map(sympify, args))
  38. obj = Basic.__new__(cls, *args)
  39. if check:
  40. if not any(isinstance(i, MatrixExpr) for i in args):
  41. return Add.fromiter(args)
  42. validate(*args)
  43. if evaluate:
  44. if not any(isinstance(i, MatrixExpr) for i in args):
  45. return Add(*args, evaluate=True)
  46. obj = canonicalize(obj)
  47. return obj
  48. @property
  49. def shape(self):
  50. return self.args[0].shape
  51. def could_extract_minus_sign(self):
  52. return _could_extract_minus_sign(self)
  53. def _entry(self, i, j, **kwargs):
  54. return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])
  55. def _eval_transpose(self):
  56. return MatAdd(*[transpose(arg) for arg in self.args]).doit()
  57. def _eval_adjoint(self):
  58. return MatAdd(*[adjoint(arg) for arg in self.args]).doit()
  59. def _eval_trace(self):
  60. from .trace import trace
  61. return Add(*[trace(arg) for arg in self.args]).doit()
  62. def doit(self, **kwargs):
  63. deep = kwargs.get('deep', True)
  64. if deep:
  65. args = [arg.doit(**kwargs) for arg in self.args]
  66. else:
  67. args = self.args
  68. return canonicalize(MatAdd(*args))
  69. def _eval_derivative_matrix_lines(self, x):
  70. add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
  71. return [j for i in add_lines for j in i]
  72. add.register_handlerclass((Add, MatAdd), MatAdd)
  73. def validate(*args):
  74. if not all(arg.is_Matrix for arg in args):
  75. raise TypeError("Mix of Matrix and Scalar symbols")
  76. A = args[0]
  77. for B in args[1:]:
  78. if A.shape != B.shape:
  79. raise ShapeError("Matrices %s and %s are not aligned"%(A, B))
  80. factor_of = lambda arg: arg.as_coeff_mmul()[0]
  81. matrix_of = lambda arg: unpack(arg.as_coeff_mmul()[1])
  82. def combine(cnt, mat):
  83. if cnt == 1:
  84. return mat
  85. else:
  86. return cnt * mat
  87. def merge_explicit(matadd):
  88. """ Merge explicit MatrixBase arguments
  89. Examples
  90. ========
  91. >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
  92. >>> from sympy.matrices.expressions.matadd import merge_explicit
  93. >>> A = MatrixSymbol('A', 2, 2)
  94. >>> B = eye(2)
  95. >>> C = Matrix([[1, 2], [3, 4]])
  96. >>> X = MatAdd(A, B, C)
  97. >>> pprint(X)
  98. [1 0] [1 2]
  99. A + [ ] + [ ]
  100. [0 1] [3 4]
  101. >>> pprint(merge_explicit(X))
  102. [2 2]
  103. A + [ ]
  104. [3 5]
  105. """
  106. groups = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
  107. if len(groups[True]) > 1:
  108. return MatAdd(*(groups[False] + [reduce(operator.add, groups[True])]))
  109. else:
  110. return matadd
  111. rules = (rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)),
  112. unpack,
  113. flatten,
  114. glom(matrix_of, factor_of, combine),
  115. merge_explicit,
  116. sort(default_sort_key))
  117. canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd),
  118. do_one(*rules)))