special.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. from sympy.assumptions.ask import ask, Q
  2. from sympy.core.relational import Eq
  3. from sympy.core.singleton import S
  4. from sympy.core.sympify import _sympify
  5. from sympy.functions.special.tensor_functions import KroneckerDelta
  6. from sympy.matrices.common import NonInvertibleMatrixError
  7. from .matexpr import MatrixExpr
  8. class ZeroMatrix(MatrixExpr):
  9. """The Matrix Zero 0 - additive identity
  10. Examples
  11. ========
  12. >>> from sympy import MatrixSymbol, ZeroMatrix
  13. >>> A = MatrixSymbol('A', 3, 5)
  14. >>> Z = ZeroMatrix(3, 5)
  15. >>> A + Z
  16. A
  17. >>> Z*A.T
  18. 0
  19. """
  20. is_ZeroMatrix = True
  21. def __new__(cls, m, n):
  22. m, n = _sympify(m), _sympify(n)
  23. cls._check_dim(m)
  24. cls._check_dim(n)
  25. return super().__new__(cls, m, n)
  26. @property
  27. def shape(self):
  28. return (self.args[0], self.args[1])
  29. def _eval_power(self, exp):
  30. # exp = -1, 0, 1 are already handled at this stage
  31. if (exp < 0) == True:
  32. raise NonInvertibleMatrixError("Matrix det == 0; not invertible")
  33. return self
  34. def _eval_transpose(self):
  35. return ZeroMatrix(self.cols, self.rows)
  36. def _eval_trace(self):
  37. return S.Zero
  38. def _eval_determinant(self):
  39. return S.Zero
  40. def _eval_inverse(self):
  41. raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
  42. def conjugate(self):
  43. return self
  44. def _entry(self, i, j, **kwargs):
  45. return S.Zero
  46. class GenericZeroMatrix(ZeroMatrix):
  47. """
  48. A zero matrix without a specified shape
  49. This exists primarily so MatAdd() with no arguments can return something
  50. meaningful.
  51. """
  52. def __new__(cls):
  53. # super(ZeroMatrix, cls) instead of super(GenericZeroMatrix, cls)
  54. # because ZeroMatrix.__new__ doesn't have the same signature
  55. return super(ZeroMatrix, cls).__new__(cls)
  56. @property
  57. def rows(self):
  58. raise TypeError("GenericZeroMatrix does not have a specified shape")
  59. @property
  60. def cols(self):
  61. raise TypeError("GenericZeroMatrix does not have a specified shape")
  62. @property
  63. def shape(self):
  64. raise TypeError("GenericZeroMatrix does not have a specified shape")
  65. # Avoid Matrix.__eq__ which might call .shape
  66. def __eq__(self, other):
  67. return isinstance(other, GenericZeroMatrix)
  68. def __ne__(self, other):
  69. return not (self == other)
  70. def __hash__(self):
  71. return super().__hash__()
  72. class Identity(MatrixExpr):
  73. """The Matrix Identity I - multiplicative identity
  74. Examples
  75. ========
  76. >>> from sympy import Identity, MatrixSymbol
  77. >>> A = MatrixSymbol('A', 3, 5)
  78. >>> I = Identity(3)
  79. >>> I*A
  80. A
  81. """
  82. is_Identity = True
  83. def __new__(cls, n):
  84. n = _sympify(n)
  85. cls._check_dim(n)
  86. return super().__new__(cls, n)
  87. @property
  88. def rows(self):
  89. return self.args[0]
  90. @property
  91. def cols(self):
  92. return self.args[0]
  93. @property
  94. def shape(self):
  95. return (self.args[0], self.args[0])
  96. @property
  97. def is_square(self):
  98. return True
  99. def _eval_transpose(self):
  100. return self
  101. def _eval_trace(self):
  102. return self.rows
  103. def _eval_inverse(self):
  104. return self
  105. def conjugate(self):
  106. return self
  107. def _entry(self, i, j, **kwargs):
  108. eq = Eq(i, j)
  109. if eq is S.true:
  110. return S.One
  111. elif eq is S.false:
  112. return S.Zero
  113. return KroneckerDelta(i, j, (0, self.cols-1))
  114. def _eval_determinant(self):
  115. return S.One
  116. def _eval_power(self, exp):
  117. return self
  118. class GenericIdentity(Identity):
  119. """
  120. An identity matrix without a specified shape
  121. This exists primarily so MatMul() with no arguments can return something
  122. meaningful.
  123. """
  124. def __new__(cls):
  125. # super(Identity, cls) instead of super(GenericIdentity, cls) because
  126. # Identity.__new__ doesn't have the same signature
  127. return super(Identity, cls).__new__(cls)
  128. @property
  129. def rows(self):
  130. raise TypeError("GenericIdentity does not have a specified shape")
  131. @property
  132. def cols(self):
  133. raise TypeError("GenericIdentity does not have a specified shape")
  134. @property
  135. def shape(self):
  136. raise TypeError("GenericIdentity does not have a specified shape")
  137. # Avoid Matrix.__eq__ which might call .shape
  138. def __eq__(self, other):
  139. return isinstance(other, GenericIdentity)
  140. def __ne__(self, other):
  141. return not (self == other)
  142. def __hash__(self):
  143. return super().__hash__()
  144. class OneMatrix(MatrixExpr):
  145. """
  146. Matrix whose all entries are ones.
  147. """
  148. def __new__(cls, m, n, evaluate=False):
  149. m, n = _sympify(m), _sympify(n)
  150. cls._check_dim(m)
  151. cls._check_dim(n)
  152. if evaluate:
  153. condition = Eq(m, 1) & Eq(n, 1)
  154. if condition == True:
  155. return Identity(1)
  156. obj = super().__new__(cls, m, n)
  157. return obj
  158. @property
  159. def shape(self):
  160. return self._args
  161. @property
  162. def is_Identity(self):
  163. return self._is_1x1() == True
  164. def as_explicit(self):
  165. from sympy.matrices.immutable import ImmutableDenseMatrix
  166. return ImmutableDenseMatrix.ones(*self.shape)
  167. def doit(self, **hints):
  168. args = self.args
  169. if hints.get('deep', True):
  170. args = [a.doit(**hints) for a in args]
  171. return self.func(*args, evaluate=True)
  172. def _eval_power(self, exp):
  173. # exp = -1, 0, 1 are already handled at this stage
  174. if self._is_1x1() == True:
  175. return Identity(1)
  176. if (exp < 0) == True:
  177. raise NonInvertibleMatrixError("Matrix det == 0; not invertible")
  178. if ask(Q.integer(exp)):
  179. return self.shape[0] ** (exp - 1) * OneMatrix(*self.shape)
  180. return super()._eval_power(exp)
  181. def _eval_transpose(self):
  182. return OneMatrix(self.cols, self.rows)
  183. def _eval_trace(self):
  184. return S.One*self.rows
  185. def _is_1x1(self):
  186. """Returns true if the matrix is known to be 1x1"""
  187. shape = self.shape
  188. return Eq(shape[0], 1) & Eq(shape[1], 1)
  189. def _eval_determinant(self):
  190. condition = self._is_1x1()
  191. if condition == True:
  192. return S.One
  193. elif condition == False:
  194. return S.Zero
  195. else:
  196. from sympy.matrices.expressions.determinant import Determinant
  197. return Determinant(self)
  198. def _eval_inverse(self):
  199. condition = self._is_1x1()
  200. if condition == True:
  201. return Identity(1)
  202. elif condition == False:
  203. raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
  204. else:
  205. from .inverse import Inverse
  206. return Inverse(self)
  207. def conjugate(self):
  208. return self
  209. def _entry(self, i, j, **kwargs):
  210. return S.One