matrices.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716
"""
This module contains query handlers responsible for matrix queries:
Square, Symmetric, Invertible, etc.
"""
  5. from sympy.logic.boolalg import conjuncts
  6. from sympy.assumptions import Q, ask
  7. from sympy.assumptions.handlers import test_closed_group
  8. from sympy.matrices import MatrixBase
  9. from sympy.matrices.expressions import (BlockMatrix, BlockDiagMatrix, Determinant,
  10. DiagMatrix, DiagonalMatrix, HadamardProduct, Identity, Inverse, MatAdd, MatMul,
  11. MatPow, MatrixExpr, MatrixSlice, MatrixSymbol, OneMatrix, Trace, Transpose,
  12. ZeroMatrix)
  13. from sympy.matrices.expressions.factorizations import Factorization
  14. from sympy.matrices.expressions.fourier import DFT
  15. from sympy.core.logic import fuzzy_and
  16. from sympy.utilities.iterables import sift
  17. from sympy.core import Basic
  18. from ..predicates.matrices import (SquarePredicate, SymmetricPredicate,
  19. InvertiblePredicate, OrthogonalPredicate, UnitaryPredicate,
  20. FullRankPredicate, PositiveDefinitePredicate, UpperTriangularPredicate,
  21. LowerTriangularPredicate, DiagonalPredicate, IntegerElementsPredicate,
  22. RealElementsPredicate, ComplexElementsPredicate)
  23. def _Factorization(predicate, expr, assumptions):
  24. if predicate in expr.predicates:
  25. return True
  26. # SquarePredicate
  27. @SquarePredicate.register(MatrixExpr)
  28. def _(expr, assumptions):
  29. return expr.shape[0] == expr.shape[1]
  30. # SymmetricPredicate
  31. @SymmetricPredicate.register(MatMul)
  32. def _(expr, assumptions):
  33. factor, mmul = expr.as_coeff_mmul()
  34. if all(ask(Q.symmetric(arg), assumptions) for arg in mmul.args):
  35. return True
  36. # TODO: implement sathandlers system for the matrices.
  37. # Now it duplicates the general fact: Implies(Q.diagonal, Q.symmetric).
  38. if ask(Q.diagonal(expr), assumptions):
  39. return True
  40. if len(mmul.args) >= 2 and mmul.args[0] == mmul.args[-1].T:
  41. if len(mmul.args) == 2:
  42. return True
  43. return ask(Q.symmetric(MatMul(*mmul.args[1:-1])), assumptions)
  44. @SymmetricPredicate.register(MatPow)
  45. def _(expr, assumptions):
  46. # only for integer powers
  47. base, exp = expr.args
  48. int_exp = ask(Q.integer(exp), assumptions)
  49. if not int_exp:
  50. return None
  51. non_negative = ask(~Q.negative(exp), assumptions)
  52. if (non_negative or non_negative == False
  53. and ask(Q.invertible(base), assumptions)):
  54. return ask(Q.symmetric(base), assumptions)
  55. return None
  56. @SymmetricPredicate.register(MatAdd)
  57. def _(expr, assumptions):
  58. return all(ask(Q.symmetric(arg), assumptions) for arg in expr.args)
  59. @SymmetricPredicate.register(MatrixSymbol)
  60. def _(expr, assumptions):
  61. if not expr.is_square:
  62. return False
  63. # TODO: implement sathandlers system for the matrices.
  64. # Now it duplicates the general fact: Implies(Q.diagonal, Q.symmetric).
  65. if ask(Q.diagonal(expr), assumptions):
  66. return True
  67. if Q.symmetric(expr) in conjuncts(assumptions):
  68. return True
  69. @SymmetricPredicate.register_many(OneMatrix, ZeroMatrix)
  70. def _(expr, assumptions):
  71. return ask(Q.square(expr), assumptions)
  72. @SymmetricPredicate.register_many(Inverse, Transpose)
  73. def _(expr, assumptions):
  74. return ask(Q.symmetric(expr.arg), assumptions)
  75. @SymmetricPredicate.register(MatrixSlice)
  76. def _(expr, assumptions):
  77. # TODO: implement sathandlers system for the matrices.
  78. # Now it duplicates the general fact: Implies(Q.diagonal, Q.symmetric).
  79. if ask(Q.diagonal(expr), assumptions):
  80. return True
  81. if not expr.on_diag:
  82. return None
  83. else:
  84. return ask(Q.symmetric(expr.parent), assumptions)
  85. @SymmetricPredicate.register(Identity)
  86. def _(expr, assumptions):
  87. return True
  88. # InvertiblePredicate
  89. @InvertiblePredicate.register(MatMul)
  90. def _(expr, assumptions):
  91. factor, mmul = expr.as_coeff_mmul()
  92. if all(ask(Q.invertible(arg), assumptions) for arg in mmul.args):
  93. return True
  94. if any(ask(Q.invertible(arg), assumptions) is False
  95. for arg in mmul.args):
  96. return False
  97. @InvertiblePredicate.register(MatPow)
  98. def _(expr, assumptions):
  99. # only for integer powers
  100. base, exp = expr.args
  101. int_exp = ask(Q.integer(exp), assumptions)
  102. if not int_exp:
  103. return None
  104. if exp.is_negative == False:
  105. return ask(Q.invertible(base), assumptions)
  106. return None
  107. @InvertiblePredicate.register(MatAdd)
  108. def _(expr, assumptions):
  109. return None
  110. @InvertiblePredicate.register(MatrixSymbol)
  111. def _(expr, assumptions):
  112. if not expr.is_square:
  113. return False
  114. if Q.invertible(expr) in conjuncts(assumptions):
  115. return True
  116. @InvertiblePredicate.register_many(Identity, Inverse)
  117. def _(expr, assumptions):
  118. return True
  119. @InvertiblePredicate.register(ZeroMatrix)
  120. def _(expr, assumptions):
  121. return False
  122. @InvertiblePredicate.register(OneMatrix)
  123. def _(expr, assumptions):
  124. return expr.shape[0] == 1 and expr.shape[1] == 1
  125. @InvertiblePredicate.register(Transpose)
  126. def _(expr, assumptions):
  127. return ask(Q.invertible(expr.arg), assumptions)
  128. @InvertiblePredicate.register(MatrixSlice)
  129. def _(expr, assumptions):
  130. if not expr.on_diag:
  131. return None
  132. else:
  133. return ask(Q.invertible(expr.parent), assumptions)
  134. @InvertiblePredicate.register(MatrixBase)
  135. def _(expr, assumptions):
  136. if not expr.is_square:
  137. return False
  138. return expr.rank() == expr.rows
  139. @InvertiblePredicate.register(MatrixExpr)
  140. def _(expr, assumptions):
  141. if not expr.is_square:
  142. return False
  143. return None
  144. @InvertiblePredicate.register(BlockMatrix)
  145. def _(expr, assumptions):
  146. from sympy.matrices.expressions.blockmatrix import reblock_2x2
  147. if not expr.is_square:
  148. return False
  149. if expr.blockshape == (1, 1):
  150. return ask(Q.invertible(expr.blocks[0, 0]), assumptions)
  151. expr = reblock_2x2(expr)
  152. if expr.blockshape == (2, 2):
  153. [[A, B], [C, D]] = expr.blocks.tolist()
  154. if ask(Q.invertible(A), assumptions) == True:
  155. invertible = ask(Q.invertible(D - C * A.I * B), assumptions)
  156. if invertible is not None:
  157. return invertible
  158. if ask(Q.invertible(B), assumptions) == True:
  159. invertible = ask(Q.invertible(C - D * B.I * A), assumptions)
  160. if invertible is not None:
  161. return invertible
  162. if ask(Q.invertible(C), assumptions) == True:
  163. invertible = ask(Q.invertible(B - A * C.I * D), assumptions)
  164. if invertible is not None:
  165. return invertible
  166. if ask(Q.invertible(D), assumptions) == True:
  167. invertible = ask(Q.invertible(A - B * D.I * C), assumptions)
  168. if invertible is not None:
  169. return invertible
  170. return None
  171. @InvertiblePredicate.register(BlockDiagMatrix)
  172. def _(expr, assumptions):
  173. if expr.rowblocksizes != expr.colblocksizes:
  174. return None
  175. return fuzzy_and([ask(Q.invertible(a), assumptions) for a in expr.diag])
  176. # OrthogonalPredicate
  177. @OrthogonalPredicate.register(MatMul)
  178. def _(expr, assumptions):
  179. factor, mmul = expr.as_coeff_mmul()
  180. if (all(ask(Q.orthogonal(arg), assumptions) for arg in mmul.args) and
  181. factor == 1):
  182. return True
  183. if any(ask(Q.invertible(arg), assumptions) is False
  184. for arg in mmul.args):
  185. return False
  186. @OrthogonalPredicate.register(MatPow)
  187. def _(expr, assumptions):
  188. # only for integer powers
  189. base, exp = expr.args
  190. int_exp = ask(Q.integer(exp), assumptions)
  191. if int_exp:
  192. return ask(Q.orthogonal(base), assumptions)
  193. return None
  194. @OrthogonalPredicate.register(MatAdd)
  195. def _(expr, assumptions):
  196. if (len(expr.args) == 1 and
  197. ask(Q.orthogonal(expr.args[0]), assumptions)):
  198. return True
  199. @OrthogonalPredicate.register(MatrixSymbol)
  200. def _(expr, assumptions):
  201. if (not expr.is_square or
  202. ask(Q.invertible(expr), assumptions) is False):
  203. return False
  204. if Q.orthogonal(expr) in conjuncts(assumptions):
  205. return True
  206. @OrthogonalPredicate.register(Identity)
  207. def _(expr, assumptions):
  208. return True
  209. @OrthogonalPredicate.register(ZeroMatrix)
  210. def _(expr, assumptions):
  211. return False
  212. @OrthogonalPredicate.register_many(Inverse, Transpose)
  213. def _(expr, assumptions):
  214. return ask(Q.orthogonal(expr.arg), assumptions)
  215. @OrthogonalPredicate.register(MatrixSlice)
  216. def _(expr, assumptions):
  217. if not expr.on_diag:
  218. return None
  219. else:
  220. return ask(Q.orthogonal(expr.parent), assumptions)
  221. @OrthogonalPredicate.register(Factorization)
  222. def _(expr, assumptions):
  223. return _Factorization(Q.orthogonal, expr, assumptions)
  224. # UnitaryPredicate
  225. @UnitaryPredicate.register(MatMul)
  226. def _(expr, assumptions):
  227. factor, mmul = expr.as_coeff_mmul()
  228. if (all(ask(Q.unitary(arg), assumptions) for arg in mmul.args) and
  229. abs(factor) == 1):
  230. return True
  231. if any(ask(Q.invertible(arg), assumptions) is False
  232. for arg in mmul.args):
  233. return False
  234. @UnitaryPredicate.register(MatPow)
  235. def _(expr, assumptions):
  236. # only for integer powers
  237. base, exp = expr.args
  238. int_exp = ask(Q.integer(exp), assumptions)
  239. if int_exp:
  240. return ask(Q.unitary(base), assumptions)
  241. return None
  242. @UnitaryPredicate.register(MatrixSymbol)
  243. def _(expr, assumptions):
  244. if (not expr.is_square or
  245. ask(Q.invertible(expr), assumptions) is False):
  246. return False
  247. if Q.unitary(expr) in conjuncts(assumptions):
  248. return True
  249. @UnitaryPredicate.register_many(Inverse, Transpose)
  250. def _(expr, assumptions):
  251. return ask(Q.unitary(expr.arg), assumptions)
  252. @UnitaryPredicate.register(MatrixSlice)
  253. def _(expr, assumptions):
  254. if not expr.on_diag:
  255. return None
  256. else:
  257. return ask(Q.unitary(expr.parent), assumptions)
  258. @UnitaryPredicate.register_many(DFT, Identity)
  259. def _(expr, assumptions):
  260. return True
  261. @UnitaryPredicate.register(ZeroMatrix)
  262. def _(expr, assumptions):
  263. return False
  264. @UnitaryPredicate.register(Factorization)
  265. def _(expr, assumptions):
  266. return _Factorization(Q.unitary, expr, assumptions)
  267. # FullRankPredicate
  268. @FullRankPredicate.register(MatMul)
  269. def _(expr, assumptions):
  270. if all(ask(Q.fullrank(arg), assumptions) for arg in expr.args):
  271. return True
  272. @FullRankPredicate.register(MatPow)
  273. def _(expr, assumptions):
  274. # only for integer powers
  275. base, exp = expr.args
  276. int_exp = ask(Q.integer(exp), assumptions)
  277. if int_exp and ask(~Q.negative(exp), assumptions):
  278. return ask(Q.fullrank(base), assumptions)
  279. return None
  280. @FullRankPredicate.register(Identity)
  281. def _(expr, assumptions):
  282. return True
  283. @FullRankPredicate.register(ZeroMatrix)
  284. def _(expr, assumptions):
  285. return False
  286. @FullRankPredicate.register(OneMatrix)
  287. def _(expr, assumptions):
  288. return expr.shape[0] == 1 and expr.shape[1] == 1
  289. @FullRankPredicate.register_many(Inverse, Transpose)
  290. def _(expr, assumptions):
  291. return ask(Q.fullrank(expr.arg), assumptions)
  292. @FullRankPredicate.register(MatrixSlice)
  293. def _(expr, assumptions):
  294. if ask(Q.orthogonal(expr.parent), assumptions):
  295. return True
  296. # PositiveDefinitePredicate
  297. @PositiveDefinitePredicate.register(MatMul)
  298. def _(expr, assumptions):
  299. factor, mmul = expr.as_coeff_mmul()
  300. if (all(ask(Q.positive_definite(arg), assumptions)
  301. for arg in mmul.args) and factor > 0):
  302. return True
  303. if (len(mmul.args) >= 2
  304. and mmul.args[0] == mmul.args[-1].T
  305. and ask(Q.fullrank(mmul.args[0]), assumptions)):
  306. return ask(Q.positive_definite(
  307. MatMul(*mmul.args[1:-1])), assumptions)
  308. @PositiveDefinitePredicate.register(MatPow)
  309. def _(expr, assumptions):
  310. # a power of a positive definite matrix is positive definite
  311. if ask(Q.positive_definite(expr.args[0]), assumptions):
  312. return True
  313. @PositiveDefinitePredicate.register(MatAdd)
  314. def _(expr, assumptions):
  315. if all(ask(Q.positive_definite(arg), assumptions)
  316. for arg in expr.args):
  317. return True
  318. @PositiveDefinitePredicate.register(MatrixSymbol)
  319. def _(expr, assumptions):
  320. if not expr.is_square:
  321. return False
  322. if Q.positive_definite(expr) in conjuncts(assumptions):
  323. return True
  324. @PositiveDefinitePredicate.register(Identity)
  325. def _(expr, assumptions):
  326. return True
  327. @PositiveDefinitePredicate.register(ZeroMatrix)
  328. def _(expr, assumptions):
  329. return False
  330. @PositiveDefinitePredicate.register(OneMatrix)
  331. def _(expr, assumptions):
  332. return expr.shape[0] == 1 and expr.shape[1] == 1
  333. @PositiveDefinitePredicate.register_many(Inverse, Transpose)
  334. def _(expr, assumptions):
  335. return ask(Q.positive_definite(expr.arg), assumptions)
  336. @PositiveDefinitePredicate.register(MatrixSlice)
  337. def _(expr, assumptions):
  338. if not expr.on_diag:
  339. return None
  340. else:
  341. return ask(Q.positive_definite(expr.parent), assumptions)
  342. # UpperTriangularPredicate
  343. @UpperTriangularPredicate.register(MatMul)
  344. def _(expr, assumptions):
  345. factor, matrices = expr.as_coeff_matrices()
  346. if all(ask(Q.upper_triangular(m), assumptions) for m in matrices):
  347. return True
  348. @UpperTriangularPredicate.register(MatAdd)
  349. def _(expr, assumptions):
  350. if all(ask(Q.upper_triangular(arg), assumptions) for arg in expr.args):
  351. return True
  352. @UpperTriangularPredicate.register(MatPow)
  353. def _(expr, assumptions):
  354. # only for integer powers
  355. base, exp = expr.args
  356. int_exp = ask(Q.integer(exp), assumptions)
  357. if not int_exp:
  358. return None
  359. non_negative = ask(~Q.negative(exp), assumptions)
  360. if (non_negative or non_negative == False
  361. and ask(Q.invertible(base), assumptions)):
  362. return ask(Q.upper_triangular(base), assumptions)
  363. return None
  364. @UpperTriangularPredicate.register(MatrixSymbol)
  365. def _(expr, assumptions):
  366. if Q.upper_triangular(expr) in conjuncts(assumptions):
  367. return True
  368. @UpperTriangularPredicate.register_many(Identity, ZeroMatrix)
  369. def _(expr, assumptions):
  370. return True
  371. @UpperTriangularPredicate.register(OneMatrix)
  372. def _(expr, assumptions):
  373. return expr.shape[0] == 1 and expr.shape[1] == 1
  374. @UpperTriangularPredicate.register(Transpose)
  375. def _(expr, assumptions):
  376. return ask(Q.lower_triangular(expr.arg), assumptions)
  377. @UpperTriangularPredicate.register(Inverse)
  378. def _(expr, assumptions):
  379. return ask(Q.upper_triangular(expr.arg), assumptions)
  380. @UpperTriangularPredicate.register(MatrixSlice)
  381. def _(expr, assumptions):
  382. if not expr.on_diag:
  383. return None
  384. else:
  385. return ask(Q.upper_triangular(expr.parent), assumptions)
  386. @UpperTriangularPredicate.register(Factorization)
  387. def _(expr, assumptions):
  388. return _Factorization(Q.upper_triangular, expr, assumptions)
  389. # LowerTriangularPredicate
  390. @LowerTriangularPredicate.register(MatMul)
  391. def _(expr, assumptions):
  392. factor, matrices = expr.as_coeff_matrices()
  393. if all(ask(Q.lower_triangular(m), assumptions) for m in matrices):
  394. return True
  395. @LowerTriangularPredicate.register(MatAdd)
  396. def _(expr, assumptions):
  397. if all(ask(Q.lower_triangular(arg), assumptions) for arg in expr.args):
  398. return True
  399. @LowerTriangularPredicate.register(MatPow)
  400. def _(expr, assumptions):
  401. # only for integer powers
  402. base, exp = expr.args
  403. int_exp = ask(Q.integer(exp), assumptions)
  404. if not int_exp:
  405. return None
  406. non_negative = ask(~Q.negative(exp), assumptions)
  407. if (non_negative or non_negative == False
  408. and ask(Q.invertible(base), assumptions)):
  409. return ask(Q.lower_triangular(base), assumptions)
  410. return None
  411. @LowerTriangularPredicate.register(MatrixSymbol)
  412. def _(expr, assumptions):
  413. if Q.lower_triangular(expr) in conjuncts(assumptions):
  414. return True
  415. @LowerTriangularPredicate.register_many(Identity, ZeroMatrix)
  416. def _(expr, assumptions):
  417. return True
  418. @LowerTriangularPredicate.register(OneMatrix)
  419. def _(expr, assumptions):
  420. return expr.shape[0] == 1 and expr.shape[1] == 1
  421. @LowerTriangularPredicate.register(Transpose)
  422. def _(expr, assumptions):
  423. return ask(Q.upper_triangular(expr.arg), assumptions)
  424. @LowerTriangularPredicate.register(Inverse)
  425. def _(expr, assumptions):
  426. return ask(Q.lower_triangular(expr.arg), assumptions)
  427. @LowerTriangularPredicate.register(MatrixSlice)
  428. def _(expr, assumptions):
  429. if not expr.on_diag:
  430. return None
  431. else:
  432. return ask(Q.lower_triangular(expr.parent), assumptions)
  433. @LowerTriangularPredicate.register(Factorization)
  434. def _(expr, assumptions):
  435. return _Factorization(Q.lower_triangular, expr, assumptions)
  436. # DiagonalPredicate
  437. def _is_empty_or_1x1(expr):
  438. return expr.shape in ((0, 0), (1, 1))
  439. @DiagonalPredicate.register(MatMul)
  440. def _(expr, assumptions):
  441. if _is_empty_or_1x1(expr):
  442. return True
  443. factor, matrices = expr.as_coeff_matrices()
  444. if all(ask(Q.diagonal(m), assumptions) for m in matrices):
  445. return True
  446. @DiagonalPredicate.register(MatPow)
  447. def _(expr, assumptions):
  448. # only for integer powers
  449. base, exp = expr.args
  450. int_exp = ask(Q.integer(exp), assumptions)
  451. if not int_exp:
  452. return None
  453. non_negative = ask(~Q.negative(exp), assumptions)
  454. if (non_negative or non_negative == False
  455. and ask(Q.invertible(base), assumptions)):
  456. return ask(Q.diagonal(base), assumptions)
  457. return None
  458. @DiagonalPredicate.register(MatAdd)
  459. def _(expr, assumptions):
  460. if all(ask(Q.diagonal(arg), assumptions) for arg in expr.args):
  461. return True
  462. @DiagonalPredicate.register(MatrixSymbol)
  463. def _(expr, assumptions):
  464. if _is_empty_or_1x1(expr):
  465. return True
  466. if Q.diagonal(expr) in conjuncts(assumptions):
  467. return True
  468. @DiagonalPredicate.register(OneMatrix)
  469. def _(expr, assumptions):
  470. return expr.shape[0] == 1 and expr.shape[1] == 1
  471. @DiagonalPredicate.register_many(Inverse, Transpose)
  472. def _(expr, assumptions):
  473. return ask(Q.diagonal(expr.arg), assumptions)
  474. @DiagonalPredicate.register(MatrixSlice)
  475. def _(expr, assumptions):
  476. if _is_empty_or_1x1(expr):
  477. return True
  478. if not expr.on_diag:
  479. return None
  480. else:
  481. return ask(Q.diagonal(expr.parent), assumptions)
  482. @DiagonalPredicate.register_many(DiagonalMatrix, DiagMatrix, Identity, ZeroMatrix)
  483. def _(expr, assumptions):
  484. return True
  485. @DiagonalPredicate.register(Factorization)
  486. def _(expr, assumptions):
  487. return _Factorization(Q.diagonal, expr, assumptions)
  488. # IntegerElementsPredicate
  489. def BM_elements(predicate, expr, assumptions):
  490. """ Block Matrix elements. """
  491. return all(ask(predicate(b), assumptions) for b in expr.blocks)
  492. def MS_elements(predicate, expr, assumptions):
  493. """ Matrix Slice elements. """
  494. return ask(predicate(expr.parent), assumptions)
  495. def MatMul_elements(matrix_predicate, scalar_predicate, expr, assumptions):
  496. d = sift(expr.args, lambda x: isinstance(x, MatrixExpr))
  497. factors, matrices = d[False], d[True]
  498. return fuzzy_and([
  499. test_closed_group(Basic(*factors), assumptions, scalar_predicate),
  500. test_closed_group(Basic(*matrices), assumptions, matrix_predicate)])
  501. @IntegerElementsPredicate.register_many(Determinant, HadamardProduct, MatAdd,
  502. Trace, Transpose)
  503. def _(expr, assumptions):
  504. return test_closed_group(expr, assumptions, Q.integer_elements)
  505. @IntegerElementsPredicate.register(MatPow)
  506. def _(expr, assumptions):
  507. # only for integer powers
  508. base, exp = expr.args
  509. int_exp = ask(Q.integer(exp), assumptions)
  510. if not int_exp:
  511. return None
  512. if exp.is_negative == False:
  513. return ask(Q.integer_elements(base), assumptions)
  514. return None
  515. @IntegerElementsPredicate.register_many(Identity, OneMatrix, ZeroMatrix)
  516. def _(expr, assumptions):
  517. return True
  518. @IntegerElementsPredicate.register(MatMul)
  519. def _(expr, assumptions):
  520. return MatMul_elements(Q.integer_elements, Q.integer, expr, assumptions)
  521. @IntegerElementsPredicate.register(MatrixSlice)
  522. def _(expr, assumptions):
  523. return MS_elements(Q.integer_elements, expr, assumptions)
  524. @IntegerElementsPredicate.register(BlockMatrix)
  525. def _(expr, assumptions):
  526. return BM_elements(Q.integer_elements, expr, assumptions)
  527. # RealElementsPredicate
  528. @RealElementsPredicate.register_many(Determinant, Factorization, HadamardProduct,
  529. MatAdd, Trace, Transpose)
  530. def _(expr, assumptions):
  531. return test_closed_group(expr, assumptions, Q.real_elements)
  532. @RealElementsPredicate.register(MatPow)
  533. def _(expr, assumptions):
  534. # only for integer powers
  535. base, exp = expr.args
  536. int_exp = ask(Q.integer(exp), assumptions)
  537. if not int_exp:
  538. return None
  539. non_negative = ask(~Q.negative(exp), assumptions)
  540. if (non_negative or non_negative == False
  541. and ask(Q.invertible(base), assumptions)):
  542. return ask(Q.real_elements(base), assumptions)
  543. return None
  544. @RealElementsPredicate.register(MatMul)
  545. def _(expr, assumptions):
  546. return MatMul_elements(Q.real_elements, Q.real, expr, assumptions)
  547. @RealElementsPredicate.register(MatrixSlice)
  548. def _(expr, assumptions):
  549. return MS_elements(Q.real_elements, expr, assumptions)
  550. @RealElementsPredicate.register(BlockMatrix)
  551. def _(expr, assumptions):
  552. return BM_elements(Q.real_elements, expr, assumptions)
  553. # ComplexElementsPredicate
  554. @ComplexElementsPredicate.register_many(Determinant, Factorization, HadamardProduct,
  555. Inverse, MatAdd, Trace, Transpose)
  556. def _(expr, assumptions):
  557. return test_closed_group(expr, assumptions, Q.complex_elements)
  558. @ComplexElementsPredicate.register(MatPow)
  559. def _(expr, assumptions):
  560. # only for integer powers
  561. base, exp = expr.args
  562. int_exp = ask(Q.integer(exp), assumptions)
  563. if not int_exp:
  564. return None
  565. non_negative = ask(~Q.negative(exp), assumptions)
  566. if (non_negative or non_negative == False
  567. and ask(Q.invertible(base), assumptions)):
  568. return ask(Q.complex_elements(base), assumptions)
  569. return None
  570. @ComplexElementsPredicate.register(MatMul)
  571. def _(expr, assumptions):
  572. return MatMul_elements(Q.complex_elements, Q.complex, expr, assumptions)
  573. @ComplexElementsPredicate.register(MatrixSlice)
  574. def _(expr, assumptions):
  575. return MS_elements(Q.complex_elements, expr, assumptions)
  576. @ComplexElementsPredicate.register(BlockMatrix)
  577. def _(expr, assumptions):
  578. return BM_elements(Q.complex_elements, expr, assumptions)
  579. @ComplexElementsPredicate.register(DFT)
  580. def _(expr, assumptions):
  581. return True