diagonal.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. from sympy.core.sympify import _sympify
  2. from sympy.matrices.expressions import MatrixExpr
  3. from sympy.core import S, Eq, Ge
  4. from sympy.core.mul import Mul
  5. from sympy.functions.special.tensor_functions import KroneckerDelta
  6. class DiagonalMatrix(MatrixExpr):
  7. """DiagonalMatrix(M) will create a matrix expression that
  8. behaves as though all off-diagonal elements,
  9. `M[i, j]` where `i != j`, are zero.
  10. Examples
  11. ========
  12. >>> from sympy import MatrixSymbol, DiagonalMatrix, Symbol
  13. >>> n = Symbol('n', integer=True)
  14. >>> m = Symbol('m', integer=True)
  15. >>> D = DiagonalMatrix(MatrixSymbol('x', 2, 3))
  16. >>> D[1, 2]
  17. 0
  18. >>> D[1, 1]
  19. x[1, 1]
  20. The length of the diagonal -- the lesser of the two dimensions of `M` --
  21. is accessed through the `diagonal_length` property:
  22. >>> D.diagonal_length
  23. 2
  24. >>> DiagonalMatrix(MatrixSymbol('x', n + 1, n)).diagonal_length
  25. n
  26. When one of the dimensions is symbolic the other will be treated as
  27. though it is smaller:
  28. >>> tall = DiagonalMatrix(MatrixSymbol('x', n, 3))
  29. >>> tall.diagonal_length
  30. 3
  31. >>> tall[10, 1]
  32. 0
  33. When the size of the diagonal is not known, a value of None will
  34. be returned:
  35. >>> DiagonalMatrix(MatrixSymbol('x', n, m)).diagonal_length is None
  36. True
  37. """
  38. arg = property(lambda self: self.args[0])
  39. shape = property(lambda self: self.arg.shape) # type:ignore
  40. @property
  41. def diagonal_length(self):
  42. r, c = self.shape
  43. if r.is_Integer and c.is_Integer:
  44. m = min(r, c)
  45. elif r.is_Integer and not c.is_Integer:
  46. m = r
  47. elif c.is_Integer and not r.is_Integer:
  48. m = c
  49. elif r == c:
  50. m = r
  51. else:
  52. try:
  53. m = min(r, c)
  54. except TypeError:
  55. m = None
  56. return m
  57. def _entry(self, i, j, **kwargs):
  58. if self.diagonal_length is not None:
  59. if Ge(i, self.diagonal_length) is S.true:
  60. return S.Zero
  61. elif Ge(j, self.diagonal_length) is S.true:
  62. return S.Zero
  63. eq = Eq(i, j)
  64. if eq is S.true:
  65. return self.arg[i, i]
  66. elif eq is S.false:
  67. return S.Zero
  68. return self.arg[i, j]*KroneckerDelta(i, j)
  69. class DiagonalOf(MatrixExpr):
  70. """DiagonalOf(M) will create a matrix expression that
  71. is equivalent to the diagonal of `M`, represented as
  72. a single column matrix.
  73. Examples
  74. ========
  75. >>> from sympy import MatrixSymbol, DiagonalOf, Symbol
  76. >>> n = Symbol('n', integer=True)
  77. >>> m = Symbol('m', integer=True)
  78. >>> x = MatrixSymbol('x', 2, 3)
  79. >>> diag = DiagonalOf(x)
  80. >>> diag.shape
  81. (2, 1)
  82. The diagonal can be addressed like a matrix or vector and will
  83. return the corresponding element of the original matrix:
  84. >>> diag[1, 0] == diag[1] == x[1, 1]
  85. True
  86. The length of the diagonal -- the lesser of the two dimensions of `M` --
  87. is accessed through the `diagonal_length` property:
  88. >>> diag.diagonal_length
  89. 2
  90. >>> DiagonalOf(MatrixSymbol('x', n + 1, n)).diagonal_length
  91. n
  92. When only one of the dimensions is symbolic the other will be
  93. treated as though it is smaller:
  94. >>> dtall = DiagonalOf(MatrixSymbol('x', n, 3))
  95. >>> dtall.diagonal_length
  96. 3
  97. When the size of the diagonal is not known, a value of None will
  98. be returned:
  99. >>> DiagonalOf(MatrixSymbol('x', n, m)).diagonal_length is None
  100. True
  101. """
  102. arg = property(lambda self: self.args[0])
  103. @property
  104. def shape(self):
  105. r, c = self.arg.shape
  106. if r.is_Integer and c.is_Integer:
  107. m = min(r, c)
  108. elif r.is_Integer and not c.is_Integer:
  109. m = r
  110. elif c.is_Integer and not r.is_Integer:
  111. m = c
  112. elif r == c:
  113. m = r
  114. else:
  115. try:
  116. m = min(r, c)
  117. except TypeError:
  118. m = None
  119. return m, S.One
  120. @property
  121. def diagonal_length(self):
  122. return self.shape[0]
  123. def _entry(self, i, j, **kwargs):
  124. return self.arg._entry(i, i, **kwargs)
  125. class DiagMatrix(MatrixExpr):
  126. """
  127. Turn a vector into a diagonal matrix.
  128. """
  129. def __new__(cls, vector):
  130. vector = _sympify(vector)
  131. obj = MatrixExpr.__new__(cls, vector)
  132. shape = vector.shape
  133. dim = shape[1] if shape[0] == 1 else shape[0]
  134. if vector.shape[0] != 1:
  135. obj._iscolumn = True
  136. else:
  137. obj._iscolumn = False
  138. obj._shape = (dim, dim)
  139. obj._vector = vector
  140. return obj
  141. @property
  142. def shape(self):
  143. return self._shape
  144. def _entry(self, i, j, **kwargs):
  145. if self._iscolumn:
  146. result = self._vector._entry(i, 0, **kwargs)
  147. else:
  148. result = self._vector._entry(0, j, **kwargs)
  149. if i != j:
  150. result *= KroneckerDelta(i, j)
  151. return result
  152. def _eval_transpose(self):
  153. return self
  154. def as_explicit(self):
  155. from sympy.matrices.dense import diag
  156. return diag(*list(self._vector.as_explicit()))
  157. def doit(self, **hints):
  158. from sympy.assumptions import ask, Q
  159. from sympy.matrices.expressions.matmul import MatMul
  160. from sympy.matrices.expressions.transpose import Transpose
  161. from sympy.matrices.dense import eye
  162. from sympy.matrices.matrices import MatrixBase
  163. vector = self._vector
  164. # This accounts for shape (1, 1) and identity matrices, among others:
  165. if ask(Q.diagonal(vector)):
  166. return vector
  167. if isinstance(vector, MatrixBase):
  168. ret = eye(max(vector.shape))
  169. for i in range(ret.shape[0]):
  170. ret[i, i] = vector[i]
  171. return type(vector)(ret)
  172. if vector.is_MatMul:
  173. matrices = [arg for arg in vector.args if arg.is_Matrix]
  174. scalars = [arg for arg in vector.args if arg not in matrices]
  175. if scalars:
  176. return Mul.fromiter(scalars)*DiagMatrix(MatMul.fromiter(matrices).doit()).doit()
  177. if isinstance(vector, Transpose):
  178. vector = vector.arg
  179. return DiagMatrix(vector)
  180. def diagonalize_vector(vector):
  181. return DiagMatrix(vector).doit()