kronecker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. """Implementation of the Kronecker product"""
  2. from sympy.core import Mul, prod, sympify
  3. from sympy.functions import adjoint
  4. from sympy.matrices.common import ShapeError
  5. from sympy.matrices.expressions.matexpr import MatrixExpr
  6. from sympy.matrices.expressions.transpose import transpose
  7. from sympy.matrices.expressions.special import Identity
  8. from sympy.matrices.matrices import MatrixBase
  9. from sympy.strategies import (
  10. canon, condition, distribute, do_one, exhaust, flatten, typed, unpack)
  11. from sympy.strategies.traverse import bottom_up
  12. from sympy.utilities import sift
  13. from .matadd import MatAdd
  14. from .matmul import MatMul
  15. from .matpow import MatPow
  16. def kronecker_product(*matrices):
  17. """
  18. The Kronecker product of two or more arguments.
  19. This computes the explicit Kronecker product for subclasses of
  20. ``MatrixBase`` i.e. explicit matrices. Otherwise, a symbolic
  21. ``KroneckerProduct`` object is returned.
  22. Examples
  23. ========
  24. For ``MatrixSymbol`` arguments a ``KroneckerProduct`` object is returned.
  25. Elements of this matrix can be obtained by indexing, or for MatrixSymbols
  26. with known dimension the explicit matrix can be obtained with
  27. ``.as_explicit()``
  28. >>> from sympy import kronecker_product, MatrixSymbol
  29. >>> A = MatrixSymbol('A', 2, 2)
  30. >>> B = MatrixSymbol('B', 2, 2)
  31. >>> kronecker_product(A)
  32. A
  33. >>> kronecker_product(A, B)
  34. KroneckerProduct(A, B)
  35. >>> kronecker_product(A, B)[0, 1]
  36. A[0, 0]*B[0, 1]
  37. >>> kronecker_product(A, B).as_explicit()
  38. Matrix([
  39. [A[0, 0]*B[0, 0], A[0, 0]*B[0, 1], A[0, 1]*B[0, 0], A[0, 1]*B[0, 1]],
  40. [A[0, 0]*B[1, 0], A[0, 0]*B[1, 1], A[0, 1]*B[1, 0], A[0, 1]*B[1, 1]],
  41. [A[1, 0]*B[0, 0], A[1, 0]*B[0, 1], A[1, 1]*B[0, 0], A[1, 1]*B[0, 1]],
  42. [A[1, 0]*B[1, 0], A[1, 0]*B[1, 1], A[1, 1]*B[1, 0], A[1, 1]*B[1, 1]]])
  43. For explicit matrices the Kronecker product is returned as a Matrix
  44. >>> from sympy import Matrix, kronecker_product
  45. >>> sigma_x = Matrix([
  46. ... [0, 1],
  47. ... [1, 0]])
  48. ...
  49. >>> Isigma_y = Matrix([
  50. ... [0, 1],
  51. ... [-1, 0]])
  52. ...
  53. >>> kronecker_product(sigma_x, Isigma_y)
  54. Matrix([
  55. [ 0, 0, 0, 1],
  56. [ 0, 0, -1, 0],
  57. [ 0, 1, 0, 0],
  58. [-1, 0, 0, 0]])
  59. See Also
  60. ========
  61. KroneckerProduct
  62. """
  63. if not matrices:
  64. raise TypeError("Empty Kronecker product is undefined")
  65. validate(*matrices)
  66. if len(matrices) == 1:
  67. return matrices[0]
  68. else:
  69. return KroneckerProduct(*matrices).doit()
  70. class KroneckerProduct(MatrixExpr):
  71. """
  72. The Kronecker product of two or more arguments.
  73. The Kronecker product is a non-commutative product of matrices.
  74. Given two matrices of dimension (m, n) and (s, t) it produces a matrix
  75. of dimension (m s, n t).
  76. This is a symbolic object that simply stores its argument without
  77. evaluating it. To actually compute the product, use the function
  78. ``kronecker_product()`` or call the ``.doit()`` or ``.as_explicit()``
  79. methods.
  80. >>> from sympy import KroneckerProduct, MatrixSymbol
  81. >>> A = MatrixSymbol('A', 5, 5)
  82. >>> B = MatrixSymbol('B', 5, 5)
  83. >>> isinstance(KroneckerProduct(A, B), KroneckerProduct)
  84. True
  85. """
  86. is_KroneckerProduct = True
  87. def __new__(cls, *args, check=True):
  88. args = list(map(sympify, args))
  89. if all(a.is_Identity for a in args):
  90. ret = Identity(prod(a.rows for a in args))
  91. if all(isinstance(a, MatrixBase) for a in args):
  92. return ret.as_explicit()
  93. else:
  94. return ret
  95. if check:
  96. validate(*args)
  97. return super().__new__(cls, *args)
  98. @property
  99. def shape(self):
  100. rows, cols = self.args[0].shape
  101. for mat in self.args[1:]:
  102. rows *= mat.rows
  103. cols *= mat.cols
  104. return (rows, cols)
  105. def _entry(self, i, j, **kwargs):
  106. result = 1
  107. for mat in reversed(self.args):
  108. i, m = divmod(i, mat.rows)
  109. j, n = divmod(j, mat.cols)
  110. result *= mat[m, n]
  111. return result
  112. def _eval_adjoint(self):
  113. return KroneckerProduct(*list(map(adjoint, self.args))).doit()
  114. def _eval_conjugate(self):
  115. return KroneckerProduct(*[a.conjugate() for a in self.args]).doit()
  116. def _eval_transpose(self):
  117. return KroneckerProduct(*list(map(transpose, self.args))).doit()
  118. def _eval_trace(self):
  119. from .trace import trace
  120. return prod(trace(a) for a in self.args)
  121. def _eval_determinant(self):
  122. from .determinant import det, Determinant
  123. if not all(a.is_square for a in self.args):
  124. return Determinant(self)
  125. m = self.rows
  126. return prod(det(a)**(m/a.rows) for a in self.args)
  127. def _eval_inverse(self):
  128. try:
  129. return KroneckerProduct(*[a.inverse() for a in self.args])
  130. except ShapeError:
  131. from sympy.matrices.expressions.inverse import Inverse
  132. return Inverse(self)
  133. def structurally_equal(self, other):
  134. '''Determine whether two matrices have the same Kronecker product structure
  135. Examples
  136. ========
  137. >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
  138. >>> m, n = symbols(r'm, n', integer=True)
  139. >>> A = MatrixSymbol('A', m, m)
  140. >>> B = MatrixSymbol('B', n, n)
  141. >>> C = MatrixSymbol('C', m, m)
  142. >>> D = MatrixSymbol('D', n, n)
  143. >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(C, D))
  144. True
  145. >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(D, C))
  146. False
  147. >>> KroneckerProduct(A, B).structurally_equal(C)
  148. False
  149. '''
  150. # Inspired by BlockMatrix
  151. return (isinstance(other, KroneckerProduct)
  152. and self.shape == other.shape
  153. and len(self.args) == len(other.args)
  154. and all(a.shape == b.shape for (a, b) in zip(self.args, other.args)))
  155. def has_matching_shape(self, other):
  156. '''Determine whether two matrices have the appropriate structure to bring matrix
  157. multiplication inside the KroneckerProdut
  158. Examples
  159. ========
  160. >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
  161. >>> m, n = symbols(r'm, n', integer=True)
  162. >>> A = MatrixSymbol('A', m, n)
  163. >>> B = MatrixSymbol('B', n, m)
  164. >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(B, A))
  165. True
  166. >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(A, B))
  167. False
  168. >>> KroneckerProduct(A, B).has_matching_shape(A)
  169. False
  170. '''
  171. return (isinstance(other, KroneckerProduct)
  172. and self.cols == other.rows
  173. and len(self.args) == len(other.args)
  174. and all(a.cols == b.rows for (a, b) in zip(self.args, other.args)))
  175. def _eval_expand_kroneckerproduct(self, **hints):
  176. return flatten(canon(typed({KroneckerProduct: distribute(KroneckerProduct, MatAdd)}))(self))
  177. def _kronecker_add(self, other):
  178. if self.structurally_equal(other):
  179. return self.__class__(*[a + b for (a, b) in zip(self.args, other.args)])
  180. else:
  181. return self + other
  182. def _kronecker_mul(self, other):
  183. if self.has_matching_shape(other):
  184. return self.__class__(*[a*b for (a, b) in zip(self.args, other.args)])
  185. else:
  186. return self * other
  187. def doit(self, **kwargs):
  188. deep = kwargs.get('deep', True)
  189. if deep:
  190. args = [arg.doit(**kwargs) for arg in self.args]
  191. else:
  192. args = self.args
  193. return canonicalize(KroneckerProduct(*args))
  194. def validate(*args):
  195. if not all(arg.is_Matrix for arg in args):
  196. raise TypeError("Mix of Matrix and Scalar symbols")
  197. # rules
  198. def extract_commutative(kron):
  199. c_part = []
  200. nc_part = []
  201. for arg in kron.args:
  202. c, nc = arg.args_cnc()
  203. c_part.extend(c)
  204. nc_part.append(Mul._from_args(nc))
  205. c_part = Mul(*c_part)
  206. if c_part != 1:
  207. return c_part*KroneckerProduct(*nc_part)
  208. return kron
  209. def matrix_kronecker_product(*matrices):
  210. """Compute the Kronecker product of a sequence of SymPy Matrices.
  211. This is the standard Kronecker product of matrices [1].
  212. Parameters
  213. ==========
  214. matrices : tuple of MatrixBase instances
  215. The matrices to take the Kronecker product of.
  216. Returns
  217. =======
  218. matrix : MatrixBase
  219. The Kronecker product matrix.
  220. Examples
  221. ========
  222. >>> from sympy import Matrix
  223. >>> from sympy.matrices.expressions.kronecker import (
  224. ... matrix_kronecker_product)
  225. >>> m1 = Matrix([[1,2],[3,4]])
  226. >>> m2 = Matrix([[1,0],[0,1]])
  227. >>> matrix_kronecker_product(m1, m2)
  228. Matrix([
  229. [1, 0, 2, 0],
  230. [0, 1, 0, 2],
  231. [3, 0, 4, 0],
  232. [0, 3, 0, 4]])
  233. >>> matrix_kronecker_product(m2, m1)
  234. Matrix([
  235. [1, 2, 0, 0],
  236. [3, 4, 0, 0],
  237. [0, 0, 1, 2],
  238. [0, 0, 3, 4]])
  239. References
  240. ==========
  241. .. [1] https://en.wikipedia.org/wiki/Kronecker_product
  242. """
  243. # Make sure we have a sequence of Matrices
  244. if not all(isinstance(m, MatrixBase) for m in matrices):
  245. raise TypeError(
  246. 'Sequence of Matrices expected, got: %s' % repr(matrices)
  247. )
  248. # Pull out the first element in the product.
  249. matrix_expansion = matrices[-1]
  250. # Do the kronecker product working from right to left.
  251. for mat in reversed(matrices[:-1]):
  252. rows = mat.rows
  253. cols = mat.cols
  254. # Go through each row appending kronecker product to.
  255. # running matrix_expansion.
  256. for i in range(rows):
  257. start = matrix_expansion*mat[i*cols]
  258. # Go through each column joining each item
  259. for j in range(cols - 1):
  260. start = start.row_join(
  261. matrix_expansion*mat[i*cols + j + 1]
  262. )
  263. # If this is the first element, make it the start of the
  264. # new row.
  265. if i == 0:
  266. next = start
  267. else:
  268. next = next.col_join(start)
  269. matrix_expansion = next
  270. MatrixClass = max(matrices, key=lambda M: M._class_priority).__class__
  271. if isinstance(matrix_expansion, MatrixClass):
  272. return matrix_expansion
  273. else:
  274. return MatrixClass(matrix_expansion)
  275. def explicit_kronecker_product(kron):
  276. # Make sure we have a sequence of Matrices
  277. if not all(isinstance(m, MatrixBase) for m in kron.args):
  278. return kron
  279. return matrix_kronecker_product(*kron.args)
  280. rules = (unpack,
  281. explicit_kronecker_product,
  282. flatten,
  283. extract_commutative)
  284. canonicalize = exhaust(condition(lambda x: isinstance(x, KroneckerProduct),
  285. do_one(*rules)))
  286. def _kronecker_dims_key(expr):
  287. if isinstance(expr, KroneckerProduct):
  288. return tuple(a.shape for a in expr.args)
  289. else:
  290. return (0,)
  291. def kronecker_mat_add(expr):
  292. from functools import reduce
  293. args = sift(expr.args, _kronecker_dims_key)
  294. nonkrons = args.pop((0,), None)
  295. if not args:
  296. return expr
  297. krons = [reduce(lambda x, y: x._kronecker_add(y), group)
  298. for group in args.values()]
  299. if not nonkrons:
  300. return MatAdd(*krons)
  301. else:
  302. return MatAdd(*krons) + nonkrons
  303. def kronecker_mat_mul(expr):
  304. # modified from block matrix code
  305. factor, matrices = expr.as_coeff_matrices()
  306. i = 0
  307. while i < len(matrices) - 1:
  308. A, B = matrices[i:i+2]
  309. if isinstance(A, KroneckerProduct) and isinstance(B, KroneckerProduct):
  310. matrices[i] = A._kronecker_mul(B)
  311. matrices.pop(i+1)
  312. else:
  313. i += 1
  314. return factor*MatMul(*matrices)
  315. def kronecker_mat_pow(expr):
  316. if isinstance(expr.base, KroneckerProduct) and all(a.is_square for a in expr.base.args):
  317. return KroneckerProduct(*[MatPow(a, expr.exp) for a in expr.base.args])
  318. else:
  319. return expr
  320. def combine_kronecker(expr):
  321. """Combine KronekeckerProduct with expression.
  322. If possible write operations on KroneckerProducts of compatible shapes
  323. as a single KroneckerProduct.
  324. Examples
  325. ========
  326. >>> from sympy.matrices.expressions import combine_kronecker
  327. >>> from sympy import MatrixSymbol, KroneckerProduct, symbols
  328. >>> m, n = symbols(r'm, n', integer=True)
  329. >>> A = MatrixSymbol('A', m, n)
  330. >>> B = MatrixSymbol('B', n, m)
  331. >>> combine_kronecker(KroneckerProduct(A, B)*KroneckerProduct(B, A))
  332. KroneckerProduct(A*B, B*A)
  333. >>> combine_kronecker(KroneckerProduct(A, B)+KroneckerProduct(B.T, A.T))
  334. KroneckerProduct(A + B.T, B + A.T)
  335. >>> C = MatrixSymbol('C', n, n)
  336. >>> D = MatrixSymbol('D', m, m)
  337. >>> combine_kronecker(KroneckerProduct(C, D)**m)
  338. KroneckerProduct(C**m, D**m)
  339. """
  340. def haskron(expr):
  341. return isinstance(expr, MatrixExpr) and expr.has(KroneckerProduct)
  342. rule = exhaust(
  343. bottom_up(exhaust(condition(haskron, typed(
  344. {MatAdd: kronecker_mat_add,
  345. MatMul: kronecker_mat_mul,
  346. MatPow: kronecker_mat_pow})))))
  347. result = rule(expr)
  348. doit = getattr(result, 'doit', None)
  349. if doit is not None:
  350. return doit()
  351. else:
  352. return result