matexpr.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. from typing import Tuple as tTuple
  2. from functools import wraps
  3. from sympy.core import S, Integer, Basic, Mul, Add
  4. from sympy.core.assumptions import check_assumptions
  5. from sympy.core.decorators import call_highest_priority
  6. from sympy.core.expr import Expr, ExprBuilder
  7. from sympy.core.logic import FuzzyBool
  8. from sympy.core.symbol import Str, Dummy, symbols, Symbol
  9. from sympy.core.sympify import SympifyError, _sympify
  10. from sympy.external.gmpy import SYMPY_INTS
  11. from sympy.functions import conjugate, adjoint
  12. from sympy.functions.special.tensor_functions import KroneckerDelta
  13. from sympy.matrices.common import NonSquareMatrixError
  14. from sympy.matrices.matrices import MatrixKind, MatrixBase
  15. from sympy.multipledispatch import dispatch
  16. from sympy.simplify import simplify
  17. from sympy.utilities.misc import filldedent
  18. def _sympifyit(arg, retval=None):
  19. # This version of _sympifyit sympifies MutableMatrix objects
  20. def deco(func):
  21. @wraps(func)
  22. def __sympifyit_wrapper(a, b):
  23. try:
  24. b = _sympify(b)
  25. return func(a, b)
  26. except SympifyError:
  27. return retval
  28. return __sympifyit_wrapper
  29. return deco
  30. class MatrixExpr(Expr):
  31. """Superclass for Matrix Expressions
  32. MatrixExprs represent abstract matrices, linear transformations represented
  33. within a particular basis.
  34. Examples
  35. ========
  36. >>> from sympy import MatrixSymbol
  37. >>> A = MatrixSymbol('A', 3, 3)
  38. >>> y = MatrixSymbol('y', 3, 1)
  39. >>> x = (A.T*A).I * A * y
  40. See Also
  41. ========
  42. MatrixSymbol, MatAdd, MatMul, Transpose, Inverse
  43. """
  44. # Should not be considered iterable by the
  45. # sympy.utilities.iterables.iterable function. Subclass that actually are
  46. # iterable (i.e., explicit matrices) should set this to True.
  47. _iterable = False
  48. _op_priority = 11.0
  49. is_Matrix = True # type: bool
  50. is_MatrixExpr = True # type: bool
  51. is_Identity = None # type: FuzzyBool
  52. is_Inverse = False
  53. is_Transpose = False
  54. is_ZeroMatrix = False
  55. is_MatAdd = False
  56. is_MatMul = False
  57. is_commutative = False
  58. is_number = False
  59. is_symbol = False
  60. is_scalar = False
  61. kind: MatrixKind = MatrixKind()
  62. def __new__(cls, *args, **kwargs):
  63. args = map(_sympify, args)
  64. return Basic.__new__(cls, *args, **kwargs)
  65. # The following is adapted from the core Expr object
  66. @property
  67. def shape(self) -> tTuple[Expr, Expr]:
  68. raise NotImplementedError
  69. @property
  70. def _add_handler(self):
  71. return MatAdd
  72. @property
  73. def _mul_handler(self):
  74. return MatMul
  75. def __neg__(self):
  76. return MatMul(S.NegativeOne, self).doit()
  77. def __abs__(self):
  78. raise NotImplementedError
  79. @_sympifyit('other', NotImplemented)
  80. @call_highest_priority('__radd__')
  81. def __add__(self, other):
  82. return MatAdd(self, other, check=True).doit()
  83. @_sympifyit('other', NotImplemented)
  84. @call_highest_priority('__add__')
  85. def __radd__(self, other):
  86. return MatAdd(other, self, check=True).doit()
  87. @_sympifyit('other', NotImplemented)
  88. @call_highest_priority('__rsub__')
  89. def __sub__(self, other):
  90. return MatAdd(self, -other, check=True).doit()
  91. @_sympifyit('other', NotImplemented)
  92. @call_highest_priority('__sub__')
  93. def __rsub__(self, other):
  94. return MatAdd(other, -self, check=True).doit()
  95. @_sympifyit('other', NotImplemented)
  96. @call_highest_priority('__rmul__')
  97. def __mul__(self, other):
  98. return MatMul(self, other).doit()
  99. @_sympifyit('other', NotImplemented)
  100. @call_highest_priority('__rmul__')
  101. def __matmul__(self, other):
  102. return MatMul(self, other).doit()
  103. @_sympifyit('other', NotImplemented)
  104. @call_highest_priority('__mul__')
  105. def __rmul__(self, other):
  106. return MatMul(other, self).doit()
  107. @_sympifyit('other', NotImplemented)
  108. @call_highest_priority('__mul__')
  109. def __rmatmul__(self, other):
  110. return MatMul(other, self).doit()
  111. @_sympifyit('other', NotImplemented)
  112. @call_highest_priority('__rpow__')
  113. def __pow__(self, other):
  114. return MatPow(self, other).doit()
  115. @_sympifyit('other', NotImplemented)
  116. @call_highest_priority('__pow__')
  117. def __rpow__(self, other):
  118. raise NotImplementedError("Matrix Power not defined")
  119. @_sympifyit('other', NotImplemented)
  120. @call_highest_priority('__rtruediv__')
  121. def __truediv__(self, other):
  122. return self * other**S.NegativeOne
  123. @_sympifyit('other', NotImplemented)
  124. @call_highest_priority('__truediv__')
  125. def __rtruediv__(self, other):
  126. raise NotImplementedError()
  127. #return MatMul(other, Pow(self, S.NegativeOne))
  128. @property
  129. def rows(self):
  130. return self.shape[0]
  131. @property
  132. def cols(self):
  133. return self.shape[1]
  134. @property
  135. def is_square(self):
  136. return self.rows == self.cols
  137. def _eval_conjugate(self):
  138. from sympy.matrices.expressions.adjoint import Adjoint
  139. return Adjoint(Transpose(self))
  140. def as_real_imag(self, deep=True, **hints):
  141. real = S.Half * (self + self._eval_conjugate())
  142. im = (self - self._eval_conjugate())/(2*S.ImaginaryUnit)
  143. return (real, im)
  144. def _eval_inverse(self):
  145. return Inverse(self)
  146. def _eval_determinant(self):
  147. return Determinant(self)
  148. def _eval_transpose(self):
  149. return Transpose(self)
  150. def _eval_power(self, exp):
  151. """
  152. Override this in sub-classes to implement simplification of powers. The cases where the exponent
  153. is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases.
  154. """
  155. return MatPow(self, exp)
  156. def _eval_simplify(self, **kwargs):
  157. if self.is_Atom:
  158. return self
  159. else:
  160. return self.func(*[simplify(x, **kwargs) for x in self.args])
  161. def _eval_adjoint(self):
  162. from sympy.matrices.expressions.adjoint import Adjoint
  163. return Adjoint(self)
  164. def _eval_derivative_n_times(self, x, n):
  165. return Basic._eval_derivative_n_times(self, x, n)
  166. def _eval_derivative(self, x):
  167. # `x` is a scalar:
  168. if self.has(x):
  169. # See if there are other methods using it:
  170. return super()._eval_derivative(x)
  171. else:
  172. return ZeroMatrix(*self.shape)
  173. @classmethod
  174. def _check_dim(cls, dim):
  175. """Helper function to check invalid matrix dimensions"""
  176. ok = check_assumptions(dim, integer=True, nonnegative=True)
  177. if ok is False:
  178. raise ValueError(
  179. "The dimension specification {} should be "
  180. "a nonnegative integer.".format(dim))
  181. def _entry(self, i, j, **kwargs):
  182. raise NotImplementedError(
  183. "Indexing not implemented for %s" % self.__class__.__name__)
  184. def adjoint(self):
  185. return adjoint(self)
  186. def as_coeff_Mul(self, rational=False):
  187. """Efficiently extract the coefficient of a product. """
  188. return S.One, self
  189. def conjugate(self):
  190. return conjugate(self)
  191. def transpose(self):
  192. from sympy.matrices.expressions.transpose import transpose
  193. return transpose(self)
  194. @property
  195. def T(self):
  196. '''Matrix transposition'''
  197. return self.transpose()
  198. def inverse(self):
  199. if not self.is_square:
  200. raise NonSquareMatrixError('Inverse of non-square matrix')
  201. return self._eval_inverse()
  202. def inv(self):
  203. return self.inverse()
  204. def det(self):
  205. from sympy.matrices.expressions.determinant import det
  206. return det(self)
  207. @property
  208. def I(self):
  209. return self.inverse()
  210. def valid_index(self, i, j):
  211. def is_valid(idx):
  212. return isinstance(idx, (int, Integer, Symbol, Expr))
  213. return (is_valid(i) and is_valid(j) and
  214. (self.rows is None or
  215. (0 <= i) != False and (i < self.rows) != False) and
  216. (0 <= j) != False and (j < self.cols) != False)
  217. def __getitem__(self, key):
  218. if not isinstance(key, tuple) and isinstance(key, slice):
  219. from sympy.matrices.expressions.slice import MatrixSlice
  220. return MatrixSlice(self, key, (0, None, 1))
  221. if isinstance(key, tuple) and len(key) == 2:
  222. i, j = key
  223. if isinstance(i, slice) or isinstance(j, slice):
  224. from sympy.matrices.expressions.slice import MatrixSlice
  225. return MatrixSlice(self, i, j)
  226. i, j = _sympify(i), _sympify(j)
  227. if self.valid_index(i, j) != False:
  228. return self._entry(i, j)
  229. else:
  230. raise IndexError("Invalid indices (%s, %s)" % (i, j))
  231. elif isinstance(key, (SYMPY_INTS, Integer)):
  232. # row-wise decomposition of matrix
  233. rows, cols = self.shape
  234. # allow single indexing if number of columns is known
  235. if not isinstance(cols, Integer):
  236. raise IndexError(filldedent('''
  237. Single indexing is only supported when the number
  238. of columns is known.'''))
  239. key = _sympify(key)
  240. i = key // cols
  241. j = key % cols
  242. if self.valid_index(i, j) != False:
  243. return self._entry(i, j)
  244. else:
  245. raise IndexError("Invalid index %s" % key)
  246. elif isinstance(key, (Symbol, Expr)):
  247. raise IndexError(filldedent('''
  248. Only integers may be used when addressing the matrix
  249. with a single index.'''))
  250. raise IndexError("Invalid index, wanted %s[i,j]" % self)
  251. def _is_shape_symbolic(self) -> bool:
  252. return (not isinstance(self.rows, (SYMPY_INTS, Integer))
  253. or not isinstance(self.cols, (SYMPY_INTS, Integer)))
  254. def as_explicit(self):
  255. """
  256. Returns a dense Matrix with elements represented explicitly
  257. Returns an object of type ImmutableDenseMatrix.
  258. Examples
  259. ========
  260. >>> from sympy import Identity
  261. >>> I = Identity(3)
  262. >>> I
  263. I
  264. >>> I.as_explicit()
  265. Matrix([
  266. [1, 0, 0],
  267. [0, 1, 0],
  268. [0, 0, 1]])
  269. See Also
  270. ========
  271. as_mutable: returns mutable Matrix type
  272. """
  273. if self._is_shape_symbolic():
  274. raise ValueError(
  275. 'Matrix with symbolic shape '
  276. 'cannot be represented explicitly.')
  277. from sympy.matrices.immutable import ImmutableDenseMatrix
  278. return ImmutableDenseMatrix([[self[i, j]
  279. for j in range(self.cols)]
  280. for i in range(self.rows)])
  281. def as_mutable(self):
  282. """
  283. Returns a dense, mutable matrix with elements represented explicitly
  284. Examples
  285. ========
  286. >>> from sympy import Identity
  287. >>> I = Identity(3)
  288. >>> I
  289. I
  290. >>> I.shape
  291. (3, 3)
  292. >>> I.as_mutable()
  293. Matrix([
  294. [1, 0, 0],
  295. [0, 1, 0],
  296. [0, 0, 1]])
  297. See Also
  298. ========
  299. as_explicit: returns ImmutableDenseMatrix
  300. """
  301. return self.as_explicit().as_mutable()
  302. def __array__(self):
  303. from numpy import empty
  304. a = empty(self.shape, dtype=object)
  305. for i in range(self.rows):
  306. for j in range(self.cols):
  307. a[i, j] = self[i, j]
  308. return a
  309. def equals(self, other):
  310. """
  311. Test elementwise equality between matrices, potentially of different
  312. types
  313. >>> from sympy import Identity, eye
  314. >>> Identity(3).equals(eye(3))
  315. True
  316. """
  317. return self.as_explicit().equals(other)
  318. def canonicalize(self):
  319. return self
  320. def as_coeff_mmul(self):
  321. return 1, MatMul(self)
  322. @staticmethod
  323. def from_index_summation(expr, first_index=None, last_index=None, dimensions=None):
  324. r"""
  325. Parse expression of matrices with explicitly summed indices into a
  326. matrix expression without indices, if possible.
  327. This transformation expressed in mathematical notation:
  328. `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`
  329. Optional parameter ``first_index``: specify which free index to use as
  330. the index starting the expression.
  331. Examples
  332. ========
  333. >>> from sympy import MatrixSymbol, MatrixExpr, Sum
  334. >>> from sympy.abc import i, j, k, l, N
  335. >>> A = MatrixSymbol("A", N, N)
  336. >>> B = MatrixSymbol("B", N, N)
  337. >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
  338. >>> MatrixExpr.from_index_summation(expr)
  339. A*B
  340. Transposition is detected:
  341. >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
  342. >>> MatrixExpr.from_index_summation(expr)
  343. A.T*B
  344. Detect the trace:
  345. >>> expr = Sum(A[i, i], (i, 0, N-1))
  346. >>> MatrixExpr.from_index_summation(expr)
  347. Trace(A)
  348. More complicated expressions:
  349. >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
  350. >>> MatrixExpr.from_index_summation(expr)
  351. A*B.T*A.T
  352. """
  353. from sympy.tensor.array.expressions.conv_indexed_to_array import convert_indexed_to_array
  354. from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
  355. first_indices = []
  356. if first_index is not None:
  357. first_indices.append(first_index)
  358. if last_index is not None:
  359. first_indices.append(last_index)
  360. arr = convert_indexed_to_array(expr, first_indices=first_indices)
  361. return convert_array_to_matrix(arr)
  362. def applyfunc(self, func):
  363. from .applyfunc import ElementwiseApplyFunction
  364. return ElementwiseApplyFunction(func, self)
  365. @dispatch(MatrixExpr, Expr)
  366. def _eval_is_eq(lhs, rhs): # noqa:F811
  367. return False
  368. @dispatch(MatrixExpr, MatrixExpr) # type: ignore
  369. def _eval_is_eq(lhs, rhs): # noqa:F811
  370. if lhs.shape != rhs.shape:
  371. return False
  372. if (lhs - rhs).is_ZeroMatrix:
  373. return True
  374. def get_postprocessor(cls):
  375. def _postprocessor(expr):
  376. # To avoid circular imports, we can't have MatMul/MatAdd on the top level
  377. mat_class = {Mul: MatMul, Add: MatAdd}[cls]
  378. nonmatrices = []
  379. matrices = []
  380. for term in expr.args:
  381. if isinstance(term, MatrixExpr):
  382. matrices.append(term)
  383. else:
  384. nonmatrices.append(term)
  385. if not matrices:
  386. return cls._from_args(nonmatrices)
  387. if nonmatrices:
  388. if cls == Mul:
  389. for i in range(len(matrices)):
  390. if not matrices[i].is_MatrixExpr:
  391. # If one of the matrices explicit, absorb the scalar into it
  392. # (doit will combine all explicit matrices into one, so it
  393. # doesn't matter which)
  394. matrices[i] = matrices[i].__mul__(cls._from_args(nonmatrices))
  395. nonmatrices = []
  396. break
  397. else:
  398. # Maintain the ability to create Add(scalar, matrix) without
  399. # raising an exception. That way different algorithms can
  400. # replace matrix expressions with non-commutative symbols to
  401. # manipulate them like non-commutative scalars.
  402. return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)])
  403. if mat_class == MatAdd:
  404. return mat_class(*matrices).doit(deep=False)
  405. return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False)
  406. return _postprocessor
  407. Basic._constructor_postprocessor_mapping[MatrixExpr] = {
  408. "Mul": [get_postprocessor(Mul)],
  409. "Add": [get_postprocessor(Add)],
  410. }
  411. def _matrix_derivative(expr, x, old_algorithm=False):
  412. if isinstance(expr, MatrixBase) or isinstance(x, MatrixBase):
  413. # Do not use array expressions for explicit matrices:
  414. old_algorithm = True
  415. if old_algorithm:
  416. return _matrix_derivative_old_algorithm(expr, x)
  417. from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
  418. from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive
  419. from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
  420. array_expr = convert_matrix_to_array(expr)
  421. diff_array_expr = array_derive(array_expr, x)
  422. diff_matrix_expr = convert_array_to_matrix(diff_array_expr)
  423. return diff_matrix_expr
  424. def _matrix_derivative_old_algorithm(expr, x):
  425. from sympy.tensor.array.array_derivatives import ArrayDerivative
  426. lines = expr._eval_derivative_matrix_lines(x)
  427. parts = [i.build() for i in lines]
  428. from sympy.tensor.array.expressions.conv_array_to_matrix import convert_array_to_matrix
  429. parts = [[convert_array_to_matrix(j) for j in i] for i in parts]
  430. def _get_shape(elem):
  431. if isinstance(elem, MatrixExpr):
  432. return elem.shape
  433. return 1, 1
  434. def get_rank(parts):
  435. return sum([j not in (1, None) for i in parts for j in _get_shape(i)])
  436. ranks = [get_rank(i) for i in parts]
  437. rank = ranks[0]
  438. def contract_one_dims(parts):
  439. if len(parts) == 1:
  440. return parts[0]
  441. else:
  442. p1, p2 = parts[:2]
  443. if p2.is_Matrix:
  444. p2 = p2.T
  445. if p1 == Identity(1):
  446. pbase = p2
  447. elif p2 == Identity(1):
  448. pbase = p1
  449. else:
  450. pbase = p1*p2
  451. if len(parts) == 2:
  452. return pbase
  453. else: # len(parts) > 2
  454. if pbase.is_Matrix:
  455. raise ValueError("")
  456. return pbase*Mul.fromiter(parts[2:])
  457. if rank <= 2:
  458. return Add.fromiter([contract_one_dims(i) for i in parts])
  459. return ArrayDerivative(expr, x)
  460. class MatrixElement(Expr):
  461. parent = property(lambda self: self.args[0])
  462. i = property(lambda self: self.args[1])
  463. j = property(lambda self: self.args[2])
  464. _diff_wrt = True
  465. is_symbol = True
  466. is_commutative = True
  467. def __new__(cls, name, n, m):
  468. n, m = map(_sympify, (n, m))
  469. from sympy.matrices.matrices import MatrixBase
  470. if isinstance(name, (MatrixBase,)):
  471. if n.is_Integer and m.is_Integer:
  472. return name[n, m]
  473. if isinstance(name, str):
  474. name = Symbol(name)
  475. else:
  476. name = _sympify(name)
  477. if not isinstance(name.kind, MatrixKind):
  478. raise TypeError("First argument of MatrixElement should be a matrix")
  479. obj = Expr.__new__(cls, name, n, m)
  480. return obj
  481. def doit(self, **kwargs):
  482. deep = kwargs.get('deep', True)
  483. if deep:
  484. args = [arg.doit(**kwargs) for arg in self.args]
  485. else:
  486. args = self.args
  487. return args[0][args[1], args[2]]
  488. @property
  489. def indices(self):
  490. return self.args[1:]
  491. def _eval_derivative(self, v):
  492. if not isinstance(v, MatrixElement):
  493. from sympy.matrices.matrices import MatrixBase
  494. if isinstance(self.parent, MatrixBase):
  495. return self.parent.diff(v)[self.i, self.j]
  496. return S.Zero
  497. M = self.args[0]
  498. m, n = self.parent.shape
  499. if M == v.args[0]:
  500. return KroneckerDelta(self.args[1], v.args[1], (0, m-1)) * \
  501. KroneckerDelta(self.args[2], v.args[2], (0, n-1))
  502. if isinstance(M, Inverse):
  503. from sympy.concrete.summations import Sum
  504. i, j = self.args[1:]
  505. i1, i2 = symbols("z1, z2", cls=Dummy)
  506. Y = M.args[0]
  507. r1, r2 = Y.shape
  508. return -Sum(M[i, i1]*Y[i1, i2].diff(v)*M[i2, j], (i1, 0, r1-1), (i2, 0, r2-1))
  509. if self.has(v.args[0]):
  510. return None
  511. return S.Zero
  512. class MatrixSymbol(MatrixExpr):
  513. """Symbolic representation of a Matrix object
  514. Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and
  515. can be included in Matrix Expressions
  516. Examples
  517. ========
  518. >>> from sympy import MatrixSymbol, Identity
  519. >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix
  520. >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix
  521. >>> A.shape
  522. (3, 4)
  523. >>> 2*A*B + Identity(3)
  524. I + 2*A*B
  525. """
  526. is_commutative = False
  527. is_symbol = True
  528. _diff_wrt = True
  529. def __new__(cls, name, n, m):
  530. n, m = _sympify(n), _sympify(m)
  531. cls._check_dim(m)
  532. cls._check_dim(n)
  533. if isinstance(name, str):
  534. name = Str(name)
  535. obj = Basic.__new__(cls, name, n, m)
  536. return obj
  537. @property
  538. def shape(self):
  539. return self.args[1], self.args[2]
  540. @property
  541. def name(self):
  542. return self.args[0].name
  543. def _entry(self, i, j, **kwargs):
  544. return MatrixElement(self, i, j)
  545. @property
  546. def free_symbols(self):
  547. return {self}
  548. def _eval_simplify(self, **kwargs):
  549. return self
  550. def _eval_derivative(self, x):
  551. # x is a scalar:
  552. return ZeroMatrix(self.shape[0], self.shape[1])
  553. def _eval_derivative_matrix_lines(self, x):
  554. if self != x:
  555. first = ZeroMatrix(x.shape[0], self.shape[0]) if self.shape[0] != 1 else S.Zero
  556. second = ZeroMatrix(x.shape[1], self.shape[1]) if self.shape[1] != 1 else S.Zero
  557. return [_LeftRightArgs(
  558. [first, second],
  559. )]
  560. else:
  561. first = Identity(self.shape[0]) if self.shape[0] != 1 else S.One
  562. second = Identity(self.shape[1]) if self.shape[1] != 1 else S.One
  563. return [_LeftRightArgs(
  564. [first, second],
  565. )]
  566. def matrix_symbols(expr):
  567. return [sym for sym in expr.free_symbols if sym.is_Matrix]
  568. class _LeftRightArgs:
  569. r"""
  570. Helper class to compute matrix derivatives.
  571. The logic: when an expression is derived by a matrix `X_{mn}`, two lines of
  572. matrix multiplications are created: the one contracted to `m` (first line),
  573. and the one contracted to `n` (second line).
  574. Transposition flips the side by which new matrices are connected to the
  575. lines.
  576. The trace connects the end of the two lines.
  577. """
  578. def __init__(self, lines, higher=S.One):
  579. self._lines = [i for i in lines]
  580. self._first_pointer_parent = self._lines
  581. self._first_pointer_index = 0
  582. self._first_line_index = 0
  583. self._second_pointer_parent = self._lines
  584. self._second_pointer_index = 1
  585. self._second_line_index = 1
  586. self.higher = higher
  587. @property
  588. def first_pointer(self):
  589. return self._first_pointer_parent[self._first_pointer_index]
  590. @first_pointer.setter
  591. def first_pointer(self, value):
  592. self._first_pointer_parent[self._first_pointer_index] = value
  593. @property
  594. def second_pointer(self):
  595. return self._second_pointer_parent[self._second_pointer_index]
  596. @second_pointer.setter
  597. def second_pointer(self, value):
  598. self._second_pointer_parent[self._second_pointer_index] = value
  599. def __repr__(self):
  600. built = [self._build(i) for i in self._lines]
  601. return "_LeftRightArgs(lines=%s, higher=%s)" % (
  602. built,
  603. self.higher,
  604. )
  605. def transpose(self):
  606. self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent
  607. self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index
  608. self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index
  609. return self
  610. @staticmethod
  611. def _build(expr):
  612. if isinstance(expr, ExprBuilder):
  613. return expr.build()
  614. if isinstance(expr, list):
  615. if len(expr) == 1:
  616. return expr[0]
  617. else:
  618. return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]])
  619. else:
  620. return expr
  621. def build(self):
  622. data = [self._build(i) for i in self._lines]
  623. if self.higher != 1:
  624. data += [self._build(self.higher)]
  625. data = [i for i in data]
  626. return data
  627. def matrix_form(self):
  628. if self.first != 1 and self.higher != 1:
  629. raise ValueError("higher dimensional array cannot be represented")
  630. def _get_shape(elem):
  631. if isinstance(elem, MatrixExpr):
  632. return elem.shape
  633. return (None, None)
  634. if _get_shape(self.first)[1] != _get_shape(self.second)[1]:
  635. # Remove one-dimensional identity matrices:
  636. # (this is needed by `a.diff(a)` where `a` is a vector)
  637. if _get_shape(self.second) == (1, 1):
  638. return self.first*self.second[0, 0]
  639. if _get_shape(self.first) == (1, 1):
  640. return self.first[1, 1]*self.second.T
  641. raise ValueError("incompatible shapes")
  642. if self.first != 1:
  643. return self.first*self.second.T
  644. else:
  645. return self.higher
  646. def rank(self):
  647. """
  648. Number of dimensions different from trivial (warning: not related to
  649. matrix rank).
  650. """
  651. rank = 0
  652. if self.first != 1:
  653. rank += sum([i != 1 for i in self.first.shape])
  654. if self.second != 1:
  655. rank += sum([i != 1 for i in self.second.shape])
  656. if self.higher != 1:
  657. rank += 2
  658. return rank
  659. def _multiply_pointer(self, pointer, other):
  660. from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
  661. from ...tensor.array.expressions.array_expressions import ArrayContraction
  662. subexpr = ExprBuilder(
  663. ArrayContraction,
  664. [
  665. ExprBuilder(
  666. ArrayTensorProduct,
  667. [
  668. pointer,
  669. other
  670. ]
  671. ),
  672. (1, 2)
  673. ],
  674. validator=ArrayContraction._validate
  675. )
  676. return subexpr
  677. def append_first(self, other):
  678. self.first_pointer *= other
  679. def append_second(self, other):
  680. self.second_pointer *= other
  681. def _make_matrix(x):
  682. from sympy.matrices.immutable import ImmutableDenseMatrix
  683. if isinstance(x, MatrixExpr):
  684. return x
  685. return ImmutableDenseMatrix([[x]])
  686. from .matmul import MatMul
  687. from .matadd import MatAdd
  688. from .matpow import MatPow
  689. from .transpose import Transpose
  690. from .inverse import Inverse
  691. from .special import ZeroMatrix, Identity
  692. from .determinant import Determinant