hadamard.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. from sympy.core import Mul, sympify
  2. from sympy.core.add import Add
  3. from sympy.core.expr import ExprBuilder
  4. from sympy.core.sorting import default_sort_key
  5. from sympy.matrices.common import ShapeError
  6. from sympy.matrices.expressions.matexpr import MatrixExpr
  7. from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix
  8. from sympy.strategies import (
  9. unpack, flatten, condition, exhaust, rm_id, sort
  10. )
  11. def hadamard_product(*matrices):
  12. """
  13. Return the elementwise (aka Hadamard) product of matrices.
  14. Examples
  15. ========
  16. >>> from sympy import hadamard_product, MatrixSymbol
  17. >>> A = MatrixSymbol('A', 2, 3)
  18. >>> B = MatrixSymbol('B', 2, 3)
  19. >>> hadamard_product(A)
  20. A
  21. >>> hadamard_product(A, B)
  22. HadamardProduct(A, B)
  23. >>> hadamard_product(A, B)[0, 1]
  24. A[0, 1]*B[0, 1]
  25. """
  26. if not matrices:
  27. raise TypeError("Empty Hadamard product is undefined")
  28. validate(*matrices)
  29. if len(matrices) == 1:
  30. return matrices[0]
  31. else:
  32. matrices = [i for i in matrices if not i.is_Identity]
  33. return HadamardProduct(*matrices).doit()
  34. class HadamardProduct(MatrixExpr):
  35. """
  36. Elementwise product of matrix expressions
  37. Examples
  38. ========
  39. Hadamard product for matrix symbols:
  40. >>> from sympy import hadamard_product, HadamardProduct, MatrixSymbol
  41. >>> A = MatrixSymbol('A', 5, 5)
  42. >>> B = MatrixSymbol('B', 5, 5)
  43. >>> isinstance(hadamard_product(A, B), HadamardProduct)
  44. True
  45. Notes
  46. =====
  47. This is a symbolic object that simply stores its argument without
  48. evaluating it. To actually compute the product, use the function
  49. ``hadamard_product()`` or ``HadamardProduct.doit``
  50. """
  51. is_HadamardProduct = True
  52. def __new__(cls, *args, evaluate=False, check=True):
  53. args = list(map(sympify, args))
  54. if check:
  55. validate(*args)
  56. obj = super().__new__(cls, *args)
  57. if evaluate:
  58. obj = obj.doit(deep=False)
  59. return obj
  60. @property
  61. def shape(self):
  62. return self.args[0].shape
  63. def _entry(self, i, j, **kwargs):
  64. return Mul(*[arg._entry(i, j, **kwargs) for arg in self.args])
  65. def _eval_transpose(self):
  66. from sympy.matrices.expressions.transpose import transpose
  67. return HadamardProduct(*list(map(transpose, self.args)))
  68. def doit(self, **ignored):
  69. expr = self.func(*[i.doit(**ignored) for i in self.args])
  70. # Check for explicit matrices:
  71. from sympy.matrices.matrices import MatrixBase
  72. from sympy.matrices.immutable import ImmutableMatrix
  73. explicit = [i for i in expr.args if isinstance(i, MatrixBase)]
  74. if explicit:
  75. remainder = [i for i in expr.args if i not in explicit]
  76. expl_mat = ImmutableMatrix([
  77. Mul.fromiter(i) for i in zip(*explicit)
  78. ]).reshape(*self.shape)
  79. expr = HadamardProduct(*([expl_mat] + remainder))
  80. return canonicalize(expr)
  81. def _eval_derivative(self, x):
  82. terms = []
  83. args = list(self.args)
  84. for i in range(len(args)):
  85. factors = args[:i] + [args[i].diff(x)] + args[i+1:]
  86. terms.append(hadamard_product(*factors))
  87. return Add.fromiter(terms)
  88. def _eval_derivative_matrix_lines(self, x):
  89. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal
  90. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  91. from sympy.matrices.expressions.matexpr import _make_matrix
  92. with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
  93. lines = []
  94. for ind in with_x_ind:
  95. left_args = self.args[:ind]
  96. right_args = self.args[ind+1:]
  97. d = self.args[ind]._eval_derivative_matrix_lines(x)
  98. hadam = hadamard_product(*(right_args + left_args))
  99. diagonal = [(0, 2), (3, 4)]
  100. diagonal = [e for j, e in enumerate(diagonal) if self.shape[j] != 1]
  101. for i in d:
  102. l1 = i._lines[i._first_line_index]
  103. l2 = i._lines[i._second_line_index]
  104. subexpr = ExprBuilder(
  105. ArrayDiagonal,
  106. [
  107. ExprBuilder(
  108. ArrayTensorProduct,
  109. [
  110. ExprBuilder(_make_matrix, [l1]),
  111. hadam,
  112. ExprBuilder(_make_matrix, [l2]),
  113. ]
  114. ),
  115. *diagonal],
  116. )
  117. i._first_pointer_parent = subexpr.args[0].args[0].args
  118. i._first_pointer_index = 0
  119. i._second_pointer_parent = subexpr.args[0].args[2].args
  120. i._second_pointer_index = 0
  121. i._lines = [subexpr]
  122. lines.append(i)
  123. return lines
  124. def validate(*args):
  125. if not all(arg.is_Matrix for arg in args):
  126. raise TypeError("Mix of Matrix and Scalar symbols")
  127. A = args[0]
  128. for B in args[1:]:
  129. if A.shape != B.shape:
  130. raise ShapeError("Matrices %s and %s are not aligned" % (A, B))
# TODO Implement an algorithm for rewriting a Hadamard product as a diagonal
# matrix when one of the factors is the matmul identity matrix (``Identity``).
  133. def canonicalize(x):
  134. """Canonicalize the Hadamard product ``x`` with mathematical properties.
  135. Examples
  136. ========
  137. >>> from sympy import MatrixSymbol, HadamardProduct
  138. >>> from sympy import OneMatrix, ZeroMatrix
  139. >>> from sympy.matrices.expressions.hadamard import canonicalize
  140. >>> from sympy import init_printing
  141. >>> init_printing(use_unicode=False)
  142. >>> A = MatrixSymbol('A', 2, 2)
  143. >>> B = MatrixSymbol('B', 2, 2)
  144. >>> C = MatrixSymbol('C', 2, 2)
  145. Hadamard product associativity:
  146. >>> X = HadamardProduct(A, HadamardProduct(B, C))
  147. >>> X
  148. A.*(B.*C)
  149. >>> canonicalize(X)
  150. A.*B.*C
  151. Hadamard product commutativity:
  152. >>> X = HadamardProduct(A, B)
  153. >>> Y = HadamardProduct(B, A)
  154. >>> X
  155. A.*B
  156. >>> Y
  157. B.*A
  158. >>> canonicalize(X)
  159. A.*B
  160. >>> canonicalize(Y)
  161. A.*B
  162. Hadamard product identity:
  163. >>> X = HadamardProduct(A, OneMatrix(2, 2))
  164. >>> X
  165. A.*1
  166. >>> canonicalize(X)
  167. A
  168. Absorbing element of Hadamard product:
  169. >>> X = HadamardProduct(A, ZeroMatrix(2, 2))
  170. >>> X
  171. A.*0
  172. >>> canonicalize(X)
  173. 0
  174. Rewriting to Hadamard Power
  175. >>> X = HadamardProduct(A, A, A)
  176. >>> X
  177. A.*A.*A
  178. >>> canonicalize(X)
  179. .3
  180. A
  181. Notes
  182. =====
  183. As the Hadamard product is associative, nested products can be flattened.
  184. The Hadamard product is commutative so that factors can be sorted for
  185. canonical form.
  186. A matrix of only ones is an identity for Hadamard product,
  187. so every matrices of only ones can be removed.
  188. Any zero matrix will make the whole product a zero matrix.
  189. Duplicate elements can be collected and rewritten as HadamardPower
  190. References
  191. ==========
  192. .. [1] https://en.wikipedia.org/wiki/Hadamard_product_(matrices)
  193. """
  194. # Associativity
  195. rule = condition(
  196. lambda x: isinstance(x, HadamardProduct),
  197. flatten
  198. )
  199. fun = exhaust(rule)
  200. x = fun(x)
  201. # Identity
  202. fun = condition(
  203. lambda x: isinstance(x, HadamardProduct),
  204. rm_id(lambda x: isinstance(x, OneMatrix))
  205. )
  206. x = fun(x)
  207. # Absorbing by Zero Matrix
  208. def absorb(x):
  209. if any(isinstance(c, ZeroMatrix) for c in x.args):
  210. return ZeroMatrix(*x.shape)
  211. else:
  212. return x
  213. fun = condition(
  214. lambda x: isinstance(x, HadamardProduct),
  215. absorb
  216. )
  217. x = fun(x)
  218. # Rewriting with HadamardPower
  219. if isinstance(x, HadamardProduct):
  220. from collections import Counter
  221. tally = Counter(x.args)
  222. new_arg = []
  223. for base, exp in tally.items():
  224. if exp == 1:
  225. new_arg.append(base)
  226. else:
  227. new_arg.append(HadamardPower(base, exp))
  228. x = HadamardProduct(*new_arg)
  229. # Commutativity
  230. fun = condition(
  231. lambda x: isinstance(x, HadamardProduct),
  232. sort(default_sort_key)
  233. )
  234. x = fun(x)
  235. # Unpacking
  236. x = unpack(x)
  237. return x
  238. def hadamard_power(base, exp):
  239. base = sympify(base)
  240. exp = sympify(exp)
  241. if exp == 1:
  242. return base
  243. if not base.is_Matrix:
  244. return base**exp
  245. if exp.is_Matrix:
  246. raise ValueError("cannot raise expression to a matrix")
  247. return HadamardPower(base, exp)
  248. class HadamardPower(MatrixExpr):
  249. r"""
  250. Elementwise power of matrix expressions
  251. Parameters
  252. ==========
  253. base : scalar or matrix
  254. exp : scalar or matrix
  255. Notes
  256. =====
  257. There are four definitions for the hadamard power which can be used.
  258. Let's consider `A, B` as `(m, n)` matrices, and `a, b` as scalars.
  259. Matrix raised to a scalar exponent:
  260. .. math::
  261. A^{\circ b} = \begin{bmatrix}
  262. A_{0, 0}^b & A_{0, 1}^b & \cdots & A_{0, n-1}^b \\
  263. A_{1, 0}^b & A_{1, 1}^b & \cdots & A_{1, n-1}^b \\
  264. \vdots & \vdots & \ddots & \vdots \\
  265. A_{m-1, 0}^b & A_{m-1, 1}^b & \cdots & A_{m-1, n-1}^b
  266. \end{bmatrix}
  267. Scalar raised to a matrix exponent:
  268. .. math::
  269. a^{\circ B} = \begin{bmatrix}
  270. a^{B_{0, 0}} & a^{B_{0, 1}} & \cdots & a^{B_{0, n-1}} \\
  271. a^{B_{1, 0}} & a^{B_{1, 1}} & \cdots & a^{B_{1, n-1}} \\
  272. \vdots & \vdots & \ddots & \vdots \\
  273. a^{B_{m-1, 0}} & a^{B_{m-1, 1}} & \cdots & a^{B_{m-1, n-1}}
  274. \end{bmatrix}
  275. Matrix raised to a matrix exponent:
  276. .. math::
  277. A^{\circ B} = \begin{bmatrix}
  278. A_{0, 0}^{B_{0, 0}} & A_{0, 1}^{B_{0, 1}} &
  279. \cdots & A_{0, n-1}^{B_{0, n-1}} \\
  280. A_{1, 0}^{B_{1, 0}} & A_{1, 1}^{B_{1, 1}} &
  281. \cdots & A_{1, n-1}^{B_{1, n-1}} \\
  282. \vdots & \vdots &
  283. \ddots & \vdots \\
  284. A_{m-1, 0}^{B_{m-1, 0}} & A_{m-1, 1}^{B_{m-1, 1}} &
  285. \cdots & A_{m-1, n-1}^{B_{m-1, n-1}}
  286. \end{bmatrix}
  287. Scalar raised to a scalar exponent:
  288. .. math::
  289. a^{\circ b} = a^b
  290. """
  291. def __new__(cls, base, exp):
  292. base = sympify(base)
  293. exp = sympify(exp)
  294. if base.is_scalar and exp.is_scalar:
  295. return base ** exp
  296. if base.is_Matrix and exp.is_Matrix and base.shape != exp.shape:
  297. raise ValueError(
  298. 'The shape of the base {} and '
  299. 'the shape of the exponent {} do not match.'
  300. .format(base.shape, exp.shape)
  301. )
  302. obj = super().__new__(cls, base, exp)
  303. return obj
  304. @property
  305. def base(self):
  306. return self._args[0]
  307. @property
  308. def exp(self):
  309. return self._args[1]
  310. @property
  311. def shape(self):
  312. if self.base.is_Matrix:
  313. return self.base.shape
  314. return self.exp.shape
  315. def _entry(self, i, j, **kwargs):
  316. base = self.base
  317. exp = self.exp
  318. if base.is_Matrix:
  319. a = base._entry(i, j, **kwargs)
  320. elif base.is_scalar:
  321. a = base
  322. else:
  323. raise ValueError(
  324. 'The base {} must be a scalar or a matrix.'.format(base))
  325. if exp.is_Matrix:
  326. b = exp._entry(i, j, **kwargs)
  327. elif exp.is_scalar:
  328. b = exp
  329. else:
  330. raise ValueError(
  331. 'The exponent {} must be a scalar or a matrix.'.format(exp))
  332. return a ** b
  333. def _eval_transpose(self):
  334. from sympy.matrices.expressions.transpose import transpose
  335. return HadamardPower(transpose(self.base), self.exp)
  336. def _eval_derivative(self, x):
  337. from sympy.functions.elementary.exponential import log
  338. dexp = self.exp.diff(x)
  339. logbase = self.base.applyfunc(log)
  340. dlbase = logbase.diff(x)
  341. return hadamard_product(
  342. dexp*logbase + self.exp*dlbase,
  343. self
  344. )
  345. def _eval_derivative_matrix_lines(self, x):
  346. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  347. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal
  348. from sympy.matrices.expressions.matexpr import _make_matrix
  349. lr = self.base._eval_derivative_matrix_lines(x)
  350. for i in lr:
  351. diagonal = [(1, 2), (3, 4)]
  352. diagonal = [e for j, e in enumerate(diagonal) if self.base.shape[j] != 1]
  353. l1 = i._lines[i._first_line_index]
  354. l2 = i._lines[i._second_line_index]
  355. subexpr = ExprBuilder(
  356. ArrayDiagonal,
  357. [
  358. ExprBuilder(
  359. ArrayTensorProduct,
  360. [
  361. ExprBuilder(_make_matrix, [l1]),
  362. self.exp*hadamard_power(self.base, self.exp-1),
  363. ExprBuilder(_make_matrix, [l2]),
  364. ]
  365. ),
  366. *diagonal],
  367. validator=ArrayDiagonal._validate
  368. )
  369. i._first_pointer_parent = subexpr.args[0].args[0].args
  370. i._first_pointer_index = 0
  371. i._first_line_index = 0
  372. i._second_pointer_parent = subexpr.args[0].args[2].args
  373. i._second_pointer_index = 0
  374. i._second_line_index = 0
  375. i._lines = [subexpr]
  376. return lr