blockmatrix.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969
  1. from sympy.assumptions.ask import (Q, ask)
  2. from sympy.core import Basic, Add, Mul, S
  3. from sympy.core.sympify import _sympify
  4. from sympy.functions.elementary.complexes import re, im
  5. from sympy.strategies import typed, exhaust, condition, do_one, unpack
  6. from sympy.strategies.traverse import bottom_up
  7. from sympy.utilities.iterables import is_sequence, sift
  8. from sympy.utilities.misc import filldedent
  9. from sympy.matrices import Matrix, ShapeError
  10. from sympy.matrices.common import NonInvertibleMatrixError
  11. from sympy.matrices.expressions.determinant import det, Determinant
  12. from sympy.matrices.expressions.inverse import Inverse
  13. from sympy.matrices.expressions.matadd import MatAdd
  14. from sympy.matrices.expressions.matexpr import MatrixExpr, MatrixElement
  15. from sympy.matrices.expressions.matmul import MatMul
  16. from sympy.matrices.expressions.matpow import MatPow
  17. from sympy.matrices.expressions.slice import MatrixSlice
  18. from sympy.matrices.expressions.special import ZeroMatrix, Identity
  19. from sympy.matrices.expressions.trace import trace
  20. from sympy.matrices.expressions.transpose import Transpose, transpose
  21. class BlockMatrix(MatrixExpr):
  22. """A BlockMatrix is a Matrix comprised of other matrices.
  23. The submatrices are stored in a SymPy Matrix object but accessed as part of
  24. a Matrix Expression
  25. >>> from sympy import (MatrixSymbol, BlockMatrix, symbols,
  26. ... Identity, ZeroMatrix, block_collapse)
  27. >>> n,m,l = symbols('n m l')
  28. >>> X = MatrixSymbol('X', n, n)
  29. >>> Y = MatrixSymbol('Y', m, m)
  30. >>> Z = MatrixSymbol('Z', n, m)
  31. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]])
  32. >>> print(B)
  33. Matrix([
  34. [X, Z],
  35. [0, Y]])
  36. >>> C = BlockMatrix([[Identity(n), Z]])
  37. >>> print(C)
  38. Matrix([[I, Z]])
  39. >>> print(block_collapse(C*B))
  40. Matrix([[X, Z + Z*Y]])
  41. Some matrices might be comprised of rows of blocks with
  42. the matrices in each row having the same height and the
  43. rows all having the same total number of columns but
  44. not having the same number of columns for each matrix
  45. in each row. In this case, the matrix is not a block
  46. matrix and should be instantiated by Matrix.
  47. >>> from sympy import ones, Matrix
  48. >>> dat = [
  49. ... [ones(3,2), ones(3,3)*2],
  50. ... [ones(2,3)*3, ones(2,2)*4]]
  51. ...
  52. >>> BlockMatrix(dat)
  53. Traceback (most recent call last):
  54. ...
  55. ValueError:
  56. Although this matrix is comprised of blocks, the blocks do not fill
  57. the matrix in a size-symmetric fashion. To create a full matrix from
  58. these arguments, pass them directly to Matrix.
  59. >>> Matrix(dat)
  60. Matrix([
  61. [1, 1, 2, 2, 2],
  62. [1, 1, 2, 2, 2],
  63. [1, 1, 2, 2, 2],
  64. [3, 3, 3, 4, 4],
  65. [3, 3, 3, 4, 4]])
  66. See Also
  67. ========
  68. sympy.matrices.matrices.MatrixBase.irregular
  69. """
  70. def __new__(cls, *args, **kwargs):
  71. from sympy.matrices.immutable import ImmutableDenseMatrix
  72. isMat = lambda i: getattr(i, 'is_Matrix', False)
  73. if len(args) != 1 or \
  74. not is_sequence(args[0]) or \
  75. len({isMat(r) for r in args[0]}) != 1:
  76. raise ValueError(filldedent('''
  77. expecting a sequence of 1 or more rows
  78. containing Matrices.'''))
  79. rows = args[0] if args else []
  80. if not isMat(rows):
  81. if rows and isMat(rows[0]):
  82. rows = [rows] # rows is not list of lists or []
  83. # regularity check
  84. # same number of matrices in each row
  85. blocky = ok = len({len(r) for r in rows}) == 1
  86. if ok:
  87. # same number of rows for each matrix in a row
  88. for r in rows:
  89. ok = len({i.rows for i in r}) == 1
  90. if not ok:
  91. break
  92. blocky = ok
  93. if ok:
  94. # same number of cols for each matrix in each col
  95. for c in range(len(rows[0])):
  96. ok = len({rows[i][c].cols
  97. for i in range(len(rows))}) == 1
  98. if not ok:
  99. break
  100. if not ok:
  101. # same total cols in each row
  102. ok = len({
  103. sum([i.cols for i in r]) for r in rows}) == 1
  104. if blocky and ok:
  105. raise ValueError(filldedent('''
  106. Although this matrix is comprised of blocks,
  107. the blocks do not fill the matrix in a
  108. size-symmetric fashion. To create a full matrix
  109. from these arguments, pass them directly to
  110. Matrix.'''))
  111. raise ValueError(filldedent('''
  112. When there are not the same number of rows in each
  113. row's matrices or there are not the same number of
  114. total columns in each row, the matrix is not a
  115. block matrix. If this matrix is known to consist of
  116. blocks fully filling a 2-D space then see
  117. Matrix.irregular.'''))
  118. mat = ImmutableDenseMatrix(rows, evaluate=False)
  119. obj = Basic.__new__(cls, mat)
  120. return obj
  121. @property
  122. def shape(self):
  123. numrows = numcols = 0
  124. M = self.blocks
  125. for i in range(M.shape[0]):
  126. numrows += M[i, 0].shape[0]
  127. for i in range(M.shape[1]):
  128. numcols += M[0, i].shape[1]
  129. return (numrows, numcols)
  130. @property
  131. def blockshape(self):
  132. return self.blocks.shape
  133. @property
  134. def blocks(self):
  135. return self.args[0]
  136. @property
  137. def rowblocksizes(self):
  138. return [self.blocks[i, 0].rows for i in range(self.blockshape[0])]
  139. @property
  140. def colblocksizes(self):
  141. return [self.blocks[0, i].cols for i in range(self.blockshape[1])]
  142. def structurally_equal(self, other):
  143. return (isinstance(other, BlockMatrix)
  144. and self.shape == other.shape
  145. and self.blockshape == other.blockshape
  146. and self.rowblocksizes == other.rowblocksizes
  147. and self.colblocksizes == other.colblocksizes)
  148. def _blockmul(self, other):
  149. if (isinstance(other, BlockMatrix) and
  150. self.colblocksizes == other.rowblocksizes):
  151. return BlockMatrix(self.blocks*other.blocks)
  152. return self * other
  153. def _blockadd(self, other):
  154. if (isinstance(other, BlockMatrix)
  155. and self.structurally_equal(other)):
  156. return BlockMatrix(self.blocks + other.blocks)
  157. return self + other
  158. def _eval_transpose(self):
  159. # Flip all the individual matrices
  160. matrices = [transpose(matrix) for matrix in self.blocks]
  161. # Make a copy
  162. M = Matrix(self.blockshape[0], self.blockshape[1], matrices)
  163. # Transpose the block structure
  164. M = M.transpose()
  165. return BlockMatrix(M)
  166. def _eval_trace(self):
  167. if self.rowblocksizes == self.colblocksizes:
  168. return Add(*[trace(self.blocks[i, i])
  169. for i in range(self.blockshape[0])])
  170. raise NotImplementedError(
  171. "Can't perform trace of irregular blockshape")
  172. def _eval_determinant(self):
  173. if self.blockshape == (1, 1):
  174. return det(self.blocks[0, 0])
  175. if self.blockshape == (2, 2):
  176. [[A, B],
  177. [C, D]] = self.blocks.tolist()
  178. if ask(Q.invertible(A)):
  179. return det(A)*det(D - C*A.I*B)
  180. elif ask(Q.invertible(D)):
  181. return det(D)*det(A - B*D.I*C)
  182. return Determinant(self)
  183. def as_real_imag(self):
  184. real_matrices = [re(matrix) for matrix in self.blocks]
  185. real_matrices = Matrix(self.blockshape[0], self.blockshape[1], real_matrices)
  186. im_matrices = [im(matrix) for matrix in self.blocks]
  187. im_matrices = Matrix(self.blockshape[0], self.blockshape[1], im_matrices)
  188. return (real_matrices, im_matrices)
  189. def transpose(self):
  190. """Return transpose of matrix.
  191. Examples
  192. ========
  193. >>> from sympy import MatrixSymbol, BlockMatrix, ZeroMatrix
  194. >>> from sympy.abc import m, n
  195. >>> X = MatrixSymbol('X', n, n)
  196. >>> Y = MatrixSymbol('Y', m, m)
  197. >>> Z = MatrixSymbol('Z', n, m)
  198. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]])
  199. >>> B.transpose()
  200. Matrix([
  201. [X.T, 0],
  202. [Z.T, Y.T]])
  203. >>> _.transpose()
  204. Matrix([
  205. [X, Z],
  206. [0, Y]])
  207. """
  208. return self._eval_transpose()
  209. def schur(self, mat = 'A', generalized = False):
  210. """Return the Schur Complement of the 2x2 BlockMatrix
  211. Parameters
  212. ==========
  213. mat : String, optional
  214. The matrix with respect to which the
  215. Schur Complement is calculated. 'A' is
  216. used by default
  217. generalized : bool, optional
  218. If True, returns the generalized Schur
  219. Component which uses Moore-Penrose Inverse
  220. Examples
  221. ========
  222. >>> from sympy import symbols, MatrixSymbol, BlockMatrix
  223. >>> m, n = symbols('m n')
  224. >>> A = MatrixSymbol('A', n, n)
  225. >>> B = MatrixSymbol('B', n, m)
  226. >>> C = MatrixSymbol('C', m, n)
  227. >>> D = MatrixSymbol('D', m, m)
  228. >>> X = BlockMatrix([[A, B], [C, D]])
  229. The default Schur Complement is evaluated with "A"
  230. >>> X.schur()
  231. -C*A**(-1)*B + D
  232. >>> X.schur('D')
  233. A - B*D**(-1)*C
  234. Schur complement with non-invertible matrices is not
  235. defined. Instead, the generalized Schur complement can
  236. be calculated which uses the Moore-Penrose Inverse. To
  237. achieve this, `generalized` must be set to `True`
  238. >>> X.schur('B', generalized=True)
  239. C - D*(B.T*B)**(-1)*B.T*A
  240. >>> X.schur('C', generalized=True)
  241. -A*(C.T*C)**(-1)*C.T*D + B
  242. Returns
  243. =======
  244. M : Matrix
  245. The Schur Complement Matrix
  246. Raises
  247. ======
  248. ShapeError
  249. If the block matrix is not a 2x2 matrix
  250. NonInvertibleMatrixError
  251. If given matrix is non-invertible
  252. References
  253. ==========
  254. .. [1] Wikipedia Article on Schur Component : https://en.wikipedia.org/wiki/Schur_complement
  255. See Also
  256. ========
  257. sympy.matrices.matrices.MatrixBase.pinv
  258. """
  259. if self.blockshape == (2, 2):
  260. [[A, B],
  261. [C, D]] = self.blocks.tolist()
  262. d={'A' : A, 'B' : B, 'C' : C, 'D' : D}
  263. try:
  264. inv = (d[mat].T*d[mat]).inv()*d[mat].T if generalized else d[mat].inv()
  265. if mat == 'A':
  266. return D - C * inv * B
  267. elif mat == 'B':
  268. return C - D * inv * A
  269. elif mat == 'C':
  270. return B - A * inv * D
  271. elif mat == 'D':
  272. return A - B * inv * C
  273. #For matrices where no sub-matrix is square
  274. return self
  275. except NonInvertibleMatrixError:
  276. raise NonInvertibleMatrixError('The given matrix is not invertible. Please set generalized=True \
  277. to compute the generalized Schur Complement which uses Moore-Penrose Inverse')
  278. else:
  279. raise ShapeError('Schur Complement can only be calculated for 2x2 block matrices')
  280. def LDUdecomposition(self):
  281. """Returns the Block LDU decomposition of
  282. a 2x2 Block Matrix
  283. Returns
  284. =======
  285. (L, D, U) : Matrices
  286. L : Lower Diagonal Matrix
  287. D : Diagonal Matrix
  288. U : Upper Diagonal Matrix
  289. Examples
  290. ========
  291. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  292. >>> m, n = symbols('m n')
  293. >>> A = MatrixSymbol('A', n, n)
  294. >>> B = MatrixSymbol('B', n, m)
  295. >>> C = MatrixSymbol('C', m, n)
  296. >>> D = MatrixSymbol('D', m, m)
  297. >>> X = BlockMatrix([[A, B], [C, D]])
  298. >>> L, D, U = X.LDUdecomposition()
  299. >>> block_collapse(L*D*U)
  300. Matrix([
  301. [A, B],
  302. [C, D]])
  303. Raises
  304. ======
  305. ShapeError
  306. If the block matrix is not a 2x2 matrix
  307. NonInvertibleMatrixError
  308. If the matrix "A" is non-invertible
  309. See Also
  310. ========
  311. sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition
  312. sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition
  313. """
  314. if self.blockshape == (2,2):
  315. [[A, B],
  316. [C, D]] = self.blocks.tolist()
  317. try:
  318. AI = A.I
  319. except NonInvertibleMatrixError:
  320. raise NonInvertibleMatrixError('Block LDU decomposition cannot be calculated when\
  321. "A" is singular')
  322. Ip = Identity(B.shape[0])
  323. Iq = Identity(B.shape[1])
  324. Z = ZeroMatrix(*B.shape)
  325. L = BlockMatrix([[Ip, Z], [C*AI, Iq]])
  326. D = BlockDiagMatrix(A, self.schur())
  327. U = BlockMatrix([[Ip, AI*B],[Z.T, Iq]])
  328. return L, D, U
  329. else:
  330. raise ShapeError("Block LDU decomposition is supported only for 2x2 block matrices")
  331. def UDLdecomposition(self):
  332. """Returns the Block UDL decomposition of
  333. a 2x2 Block Matrix
  334. Returns
  335. =======
  336. (U, D, L) : Matrices
  337. U : Upper Diagonal Matrix
  338. D : Diagonal Matrix
  339. L : Lower Diagonal Matrix
  340. Examples
  341. ========
  342. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  343. >>> m, n = symbols('m n')
  344. >>> A = MatrixSymbol('A', n, n)
  345. >>> B = MatrixSymbol('B', n, m)
  346. >>> C = MatrixSymbol('C', m, n)
  347. >>> D = MatrixSymbol('D', m, m)
  348. >>> X = BlockMatrix([[A, B], [C, D]])
  349. >>> U, D, L = X.UDLdecomposition()
  350. >>> block_collapse(U*D*L)
  351. Matrix([
  352. [A, B],
  353. [C, D]])
  354. Raises
  355. ======
  356. ShapeError
  357. If the block matrix is not a 2x2 matrix
  358. NonInvertibleMatrixError
  359. If the matrix "D" is non-invertible
  360. See Also
  361. ========
  362. sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition
  363. sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition
  364. """
  365. if self.blockshape == (2,2):
  366. [[A, B],
  367. [C, D]] = self.blocks.tolist()
  368. try:
  369. DI = D.I
  370. except NonInvertibleMatrixError:
  371. raise NonInvertibleMatrixError('Block UDL decomposition cannot be calculated when\
  372. "D" is singular')
  373. Ip = Identity(A.shape[0])
  374. Iq = Identity(B.shape[1])
  375. Z = ZeroMatrix(*B.shape)
  376. U = BlockMatrix([[Ip, B*DI], [Z.T, Iq]])
  377. D = BlockDiagMatrix(self.schur('D'), D)
  378. L = BlockMatrix([[Ip, Z],[DI*C, Iq]])
  379. return U, D, L
  380. else:
  381. raise ShapeError("Block UDL decomposition is supported only for 2x2 block matrices")
  382. def LUdecomposition(self):
  383. """Returns the Block LU decomposition of
  384. a 2x2 Block Matrix
  385. Returns
  386. =======
  387. (L, U) : Matrices
  388. L : Lower Diagonal Matrix
  389. U : Upper Diagonal Matrix
  390. Examples
  391. ========
  392. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  393. >>> m, n = symbols('m n')
  394. >>> A = MatrixSymbol('A', n, n)
  395. >>> B = MatrixSymbol('B', n, m)
  396. >>> C = MatrixSymbol('C', m, n)
  397. >>> D = MatrixSymbol('D', m, m)
  398. >>> X = BlockMatrix([[A, B], [C, D]])
  399. >>> L, U = X.LUdecomposition()
  400. >>> block_collapse(L*U)
  401. Matrix([
  402. [A, B],
  403. [C, D]])
  404. Raises
  405. ======
  406. ShapeError
  407. If the block matrix is not a 2x2 matrix
  408. NonInvertibleMatrixError
  409. If the matrix "A" is non-invertible
  410. See Also
  411. ========
  412. sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition
  413. sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition
  414. """
  415. if self.blockshape == (2,2):
  416. [[A, B],
  417. [C, D]] = self.blocks.tolist()
  418. try:
  419. A = A**0.5
  420. AI = A.I
  421. except NonInvertibleMatrixError:
  422. raise NonInvertibleMatrixError('Block LU decomposition cannot be calculated when\
  423. "A" is singular')
  424. Z = ZeroMatrix(*B.shape)
  425. Q = self.schur()**0.5
  426. L = BlockMatrix([[A, Z], [C*AI, Q]])
  427. U = BlockMatrix([[A, AI*B],[Z.T, Q]])
  428. return L, U
  429. else:
  430. raise ShapeError("Block LU decomposition is supported only for 2x2 block matrices")
  431. def _entry(self, i, j, **kwargs):
  432. # Find row entry
  433. orig_i, orig_j = i, j
  434. for row_block, numrows in enumerate(self.rowblocksizes):
  435. cmp = i < numrows
  436. if cmp == True:
  437. break
  438. elif cmp == False:
  439. i -= numrows
  440. elif row_block < self.blockshape[0] - 1:
  441. # Can't tell which block and it's not the last one, return unevaluated
  442. return MatrixElement(self, orig_i, orig_j)
  443. for col_block, numcols in enumerate(self.colblocksizes):
  444. cmp = j < numcols
  445. if cmp == True:
  446. break
  447. elif cmp == False:
  448. j -= numcols
  449. elif col_block < self.blockshape[1] - 1:
  450. return MatrixElement(self, orig_i, orig_j)
  451. return self.blocks[row_block, col_block][i, j]
  452. @property
  453. def is_Identity(self):
  454. if self.blockshape[0] != self.blockshape[1]:
  455. return False
  456. for i in range(self.blockshape[0]):
  457. for j in range(self.blockshape[1]):
  458. if i==j and not self.blocks[i, j].is_Identity:
  459. return False
  460. if i!=j and not self.blocks[i, j].is_ZeroMatrix:
  461. return False
  462. return True
  463. @property
  464. def is_structurally_symmetric(self):
  465. return self.rowblocksizes == self.colblocksizes
  466. def equals(self, other):
  467. if self == other:
  468. return True
  469. if (isinstance(other, BlockMatrix) and self.blocks == other.blocks):
  470. return True
  471. return super().equals(other)
  472. class BlockDiagMatrix(BlockMatrix):
  473. """A sparse matrix with block matrices along its diagonals
  474. Examples
  475. ========
  476. >>> from sympy import MatrixSymbol, BlockDiagMatrix, symbols
  477. >>> n, m, l = symbols('n m l')
  478. >>> X = MatrixSymbol('X', n, n)
  479. >>> Y = MatrixSymbol('Y', m, m)
  480. >>> BlockDiagMatrix(X, Y)
  481. Matrix([
  482. [X, 0],
  483. [0, Y]])
  484. Notes
  485. =====
  486. If you want to get the individual diagonal blocks, use
  487. :meth:`get_diag_blocks`.
  488. See Also
  489. ========
  490. sympy.matrices.dense.diag
  491. """
  492. def __new__(cls, *mats):
  493. return Basic.__new__(BlockDiagMatrix, *[_sympify(m) for m in mats])
  494. @property
  495. def diag(self):
  496. return self.args
  497. @property
  498. def blocks(self):
  499. from sympy.matrices.immutable import ImmutableDenseMatrix
  500. mats = self.args
  501. data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols)
  502. for j in range(len(mats))]
  503. for i in range(len(mats))]
  504. return ImmutableDenseMatrix(data, evaluate=False)
  505. @property
  506. def shape(self):
  507. return (sum(block.rows for block in self.args),
  508. sum(block.cols for block in self.args))
  509. @property
  510. def blockshape(self):
  511. n = len(self.args)
  512. return (n, n)
  513. @property
  514. def rowblocksizes(self):
  515. return [block.rows for block in self.args]
  516. @property
  517. def colblocksizes(self):
  518. return [block.cols for block in self.args]
  519. def _all_square_blocks(self):
  520. """Returns true if all blocks are square"""
  521. return all(mat.is_square for mat in self.args)
  522. def _eval_determinant(self):
  523. if self._all_square_blocks():
  524. return Mul(*[det(mat) for mat in self.args])
  525. # At least one block is non-square. Since the entire matrix must be square we know there must
  526. # be at least two blocks in this matrix, in which case the entire matrix is necessarily rank-deficient
  527. return S.Zero
  528. def _eval_inverse(self, expand='ignored'):
  529. if self._all_square_blocks():
  530. return BlockDiagMatrix(*[mat.inverse() for mat in self.args])
  531. # See comment in _eval_determinant()
  532. raise NonInvertibleMatrixError('Matrix det == 0; not invertible.')
  533. def _eval_transpose(self):
  534. return BlockDiagMatrix(*[mat.transpose() for mat in self.args])
  535. def _blockmul(self, other):
  536. if (isinstance(other, BlockDiagMatrix) and
  537. self.colblocksizes == other.rowblocksizes):
  538. return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)])
  539. else:
  540. return BlockMatrix._blockmul(self, other)
  541. def _blockadd(self, other):
  542. if (isinstance(other, BlockDiagMatrix) and
  543. self.blockshape == other.blockshape and
  544. self.rowblocksizes == other.rowblocksizes and
  545. self.colblocksizes == other.colblocksizes):
  546. return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)])
  547. else:
  548. return BlockMatrix._blockadd(self, other)
  549. def get_diag_blocks(self):
  550. """Return the list of diagonal blocks of the matrix.
  551. Examples
  552. ========
  553. >>> from sympy import BlockDiagMatrix, Matrix
  554. >>> A = Matrix([[1, 2], [3, 4]])
  555. >>> B = Matrix([[5, 6], [7, 8]])
  556. >>> M = BlockDiagMatrix(A, B)
  557. How to get diagonal blocks from the block diagonal matrix:
  558. >>> diag_blocks = M.get_diag_blocks()
  559. >>> diag_blocks[0]
  560. Matrix([
  561. [1, 2],
  562. [3, 4]])
  563. >>> diag_blocks[1]
  564. Matrix([
  565. [5, 6],
  566. [7, 8]])
  567. """
  568. return self.args
  569. def block_collapse(expr):
  570. """Evaluates a block matrix expression
  571. >>> from sympy import MatrixSymbol, BlockMatrix, symbols, Identity, ZeroMatrix, block_collapse
  572. >>> n,m,l = symbols('n m l')
  573. >>> X = MatrixSymbol('X', n, n)
  574. >>> Y = MatrixSymbol('Y', m, m)
  575. >>> Z = MatrixSymbol('Z', n, m)
  576. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
  577. >>> print(B)
  578. Matrix([
  579. [X, Z],
  580. [0, Y]])
  581. >>> C = BlockMatrix([[Identity(n), Z]])
  582. >>> print(C)
  583. Matrix([[I, Z]])
  584. >>> print(block_collapse(C*B))
  585. Matrix([[X, Z + Z*Y]])
  586. """
  587. from sympy.strategies.util import expr_fns
  588. hasbm = lambda expr: isinstance(expr, MatrixExpr) and expr.has(BlockMatrix)
  589. conditioned_rl = condition(
  590. hasbm,
  591. typed(
  592. {MatAdd: do_one(bc_matadd, bc_block_plus_ident),
  593. MatMul: do_one(bc_matmul, bc_dist),
  594. MatPow: bc_matmul,
  595. Transpose: bc_transpose,
  596. Inverse: bc_inverse,
  597. BlockMatrix: do_one(bc_unpack, deblock)}
  598. )
  599. )
  600. rule = exhaust(
  601. bottom_up(
  602. exhaust(conditioned_rl),
  603. fns=expr_fns
  604. )
  605. )
  606. result = rule(expr)
  607. doit = getattr(result, 'doit', None)
  608. if doit is not None:
  609. return doit()
  610. else:
  611. return result
  612. def bc_unpack(expr):
  613. if expr.blockshape == (1, 1):
  614. return expr.blocks[0, 0]
  615. return expr
  616. def bc_matadd(expr):
  617. args = sift(expr.args, lambda M: isinstance(M, BlockMatrix))
  618. blocks = args[True]
  619. if not blocks:
  620. return expr
  621. nonblocks = args[False]
  622. block = blocks[0]
  623. for b in blocks[1:]:
  624. block = block._blockadd(b)
  625. if nonblocks:
  626. return MatAdd(*nonblocks) + block
  627. else:
  628. return block
  629. def bc_block_plus_ident(expr):
  630. idents = [arg for arg in expr.args if arg.is_Identity]
  631. if not idents:
  632. return expr
  633. blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)]
  634. if (blocks and all(b.structurally_equal(blocks[0]) for b in blocks)
  635. and blocks[0].is_structurally_symmetric):
  636. block_id = BlockDiagMatrix(*[Identity(k)
  637. for k in blocks[0].rowblocksizes])
  638. rest = [arg for arg in expr.args if not arg.is_Identity and not isinstance(arg, BlockMatrix)]
  639. return MatAdd(block_id * len(idents), *blocks, *rest).doit()
  640. return expr
  641. def bc_dist(expr):
  642. """ Turn a*[X, Y] into [a*X, a*Y] """
  643. factor, mat = expr.as_coeff_mmul()
  644. if factor == 1:
  645. return expr
  646. unpacked = unpack(mat)
  647. if isinstance(unpacked, BlockDiagMatrix):
  648. B = unpacked.diag
  649. new_B = [factor * mat for mat in B]
  650. return BlockDiagMatrix(*new_B)
  651. elif isinstance(unpacked, BlockMatrix):
  652. B = unpacked.blocks
  653. new_B = [
  654. [factor * B[i, j] for j in range(B.cols)] for i in range(B.rows)]
  655. return BlockMatrix(new_B)
  656. return expr
  657. def bc_matmul(expr):
  658. if isinstance(expr, MatPow):
  659. if expr.args[1].is_Integer:
  660. factor, matrices = (1, [expr.args[0]]*expr.args[1])
  661. else:
  662. return expr
  663. else:
  664. factor, matrices = expr.as_coeff_matrices()
  665. i = 0
  666. while (i+1 < len(matrices)):
  667. A, B = matrices[i:i+2]
  668. if isinstance(A, BlockMatrix) and isinstance(B, BlockMatrix):
  669. matrices[i] = A._blockmul(B)
  670. matrices.pop(i+1)
  671. elif isinstance(A, BlockMatrix):
  672. matrices[i] = A._blockmul(BlockMatrix([[B]]))
  673. matrices.pop(i+1)
  674. elif isinstance(B, BlockMatrix):
  675. matrices[i] = BlockMatrix([[A]])._blockmul(B)
  676. matrices.pop(i+1)
  677. else:
  678. i+=1
  679. return MatMul(factor, *matrices).doit()
  680. def bc_transpose(expr):
  681. collapse = block_collapse(expr.arg)
  682. return collapse._eval_transpose()
  683. def bc_inverse(expr):
  684. if isinstance(expr.arg, BlockDiagMatrix):
  685. return expr.inverse()
  686. expr2 = blockinverse_1x1(expr)
  687. if expr != expr2:
  688. return expr2
  689. return blockinverse_2x2(Inverse(reblock_2x2(expr.arg)))
  690. def blockinverse_1x1(expr):
  691. if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (1, 1):
  692. mat = Matrix([[expr.arg.blocks[0].inverse()]])
  693. return BlockMatrix(mat)
  694. return expr
  695. def blockinverse_2x2(expr):
  696. if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2):
  697. # See: Inverses of 2x2 Block Matrices, Tzon-Tzer Lu and Sheng-Hua Shiou
  698. [[A, B],
  699. [C, D]] = expr.arg.blocks.tolist()
  700. formula = _choose_2x2_inversion_formula(A, B, C, D)
  701. if formula != None:
  702. MI = expr.arg.schur(formula).I
  703. if formula == 'A':
  704. AI = A.I
  705. return BlockMatrix([[AI + AI * B * MI * C * AI, -AI * B * MI], [-MI * C * AI, MI]])
  706. if formula == 'B':
  707. BI = B.I
  708. return BlockMatrix([[-MI * D * BI, MI], [BI + BI * A * MI * D * BI, -BI * A * MI]])
  709. if formula == 'C':
  710. CI = C.I
  711. return BlockMatrix([[-CI * D * MI, CI + CI * D * MI * A * CI], [MI, -MI * A * CI]])
  712. if formula == 'D':
  713. DI = D.I
  714. return BlockMatrix([[MI, -MI * B * DI], [-DI * C * MI, DI + DI * C * MI * B * DI]])
  715. return expr
  716. def _choose_2x2_inversion_formula(A, B, C, D):
  717. """
  718. Assuming [[A, B], [C, D]] would form a valid square block matrix, find
  719. which of the classical 2x2 block matrix inversion formulas would be
  720. best suited.
  721. Returns 'A', 'B', 'C', 'D' to represent the algorithm involving inversion
  722. of the given argument or None if the matrix cannot be inverted using
  723. any of those formulas.
  724. """
  725. # Try to find a known invertible matrix. Note that the Schur complement
  726. # is currently not being considered for this
  727. A_inv = ask(Q.invertible(A))
  728. if A_inv == True:
  729. return 'A'
  730. B_inv = ask(Q.invertible(B))
  731. if B_inv == True:
  732. return 'B'
  733. C_inv = ask(Q.invertible(C))
  734. if C_inv == True:
  735. return 'C'
  736. D_inv = ask(Q.invertible(D))
  737. if D_inv == True:
  738. return 'D'
  739. # Otherwise try to find a matrix that isn't known to be non-invertible
  740. if A_inv != False:
  741. return 'A'
  742. if B_inv != False:
  743. return 'B'
  744. if C_inv != False:
  745. return 'C'
  746. if D_inv != False:
  747. return 'D'
  748. return None
  749. def deblock(B):
  750. """ Flatten a BlockMatrix of BlockMatrices """
  751. if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix):
  752. return B
  753. wrap = lambda x: x if isinstance(x, BlockMatrix) else BlockMatrix([[x]])
  754. bb = B.blocks.applyfunc(wrap) # everything is a block
  755. try:
  756. MM = Matrix(0, sum(bb[0, i].blocks.shape[1] for i in range(bb.shape[1])), [])
  757. for row in range(0, bb.shape[0]):
  758. M = Matrix(bb[row, 0].blocks)
  759. for col in range(1, bb.shape[1]):
  760. M = M.row_join(bb[row, col].blocks)
  761. MM = MM.col_join(M)
  762. return BlockMatrix(MM)
  763. except ShapeError:
  764. return B
  765. def reblock_2x2(expr):
  766. """
  767. Reblock a BlockMatrix so that it has 2x2 blocks of block matrices. If
  768. possible in such a way that the matrix continues to be invertible using the
  769. classical 2x2 block inversion formulas.
  770. """
  771. if not isinstance(expr, BlockMatrix) or not all(d > 2 for d in expr.blockshape):
  772. return expr
  773. BM = BlockMatrix # for brevity's sake
  774. rowblocks, colblocks = expr.blockshape
  775. blocks = expr.blocks
  776. for i in range(1, rowblocks):
  777. for j in range(1, colblocks):
  778. # try to split rows at i and cols at j
  779. A = bc_unpack(BM(blocks[:i, :j]))
  780. B = bc_unpack(BM(blocks[:i, j:]))
  781. C = bc_unpack(BM(blocks[i:, :j]))
  782. D = bc_unpack(BM(blocks[i:, j:]))
  783. formula = _choose_2x2_inversion_formula(A, B, C, D)
  784. if formula is not None:
  785. return BlockMatrix([[A, B], [C, D]])
  786. # else: nothing worked, just split upper left corner
  787. return BM([[blocks[0, 0], BM(blocks[0, 1:])],
  788. [BM(blocks[1:, 0]), BM(blocks[1:, 1:])]])
  789. def bounds(sizes):
  790. """ Convert sequence of numbers into pairs of low-high pairs
  791. >>> from sympy.matrices.expressions.blockmatrix import bounds
  792. >>> bounds((1, 10, 50))
  793. [(0, 1), (1, 11), (11, 61)]
  794. """
  795. low = 0
  796. rv = []
  797. for size in sizes:
  798. rv.append((low, low + size))
  799. low += size
  800. return rv
  801. def blockcut(expr, rowsizes, colsizes):
  802. """ Cut a matrix expression into Blocks
  803. >>> from sympy import ImmutableMatrix, blockcut
  804. >>> M = ImmutableMatrix(4, 4, range(16))
  805. >>> B = blockcut(M, (1, 3), (1, 3))
  806. >>> type(B).__name__
  807. 'BlockMatrix'
  808. >>> ImmutableMatrix(B.blocks[0, 1])
  809. Matrix([[1, 2, 3]])
  810. """
  811. rowbounds = bounds(rowsizes)
  812. colbounds = bounds(colsizes)
  813. return BlockMatrix([[MatrixSlice(expr, rowbound, colbound)
  814. for colbound in colbounds]
  815. for rowbound in rowbounds])