transpose.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from sympy.core.basic import Basic
  2. from sympy.functions import adjoint, conjugate
  3. from sympy.matrices.expressions.matexpr import MatrixExpr
  4. class Transpose(MatrixExpr):
  5. """
  6. The transpose of a matrix expression.
  7. This is a symbolic object that simply stores its argument without
  8. evaluating it. To actually compute the transpose, use the ``transpose()``
  9. function, or the ``.T`` attribute of matrices.
  10. Examples
  11. ========
  12. >>> from sympy import MatrixSymbol, Transpose, transpose
  13. >>> A = MatrixSymbol('A', 3, 5)
  14. >>> B = MatrixSymbol('B', 5, 3)
  15. >>> Transpose(A)
  16. A.T
  17. >>> A.T == transpose(A) == Transpose(A)
  18. True
  19. >>> Transpose(A*B)
  20. (A*B).T
  21. >>> transpose(A*B)
  22. B.T*A.T
  23. """
  24. is_Transpose = True
  25. def doit(self, **hints):
  26. arg = self.arg
  27. if hints.get('deep', True) and isinstance(arg, Basic):
  28. arg = arg.doit(**hints)
  29. _eval_transpose = getattr(arg, '_eval_transpose', None)
  30. if _eval_transpose is not None:
  31. result = _eval_transpose()
  32. return result if result is not None else Transpose(arg)
  33. else:
  34. return Transpose(arg)
  35. @property
  36. def arg(self):
  37. return self.args[0]
  38. @property
  39. def shape(self):
  40. return self.arg.shape[::-1]
  41. def _entry(self, i, j, expand=False, **kwargs):
  42. return self.arg._entry(j, i, expand=expand, **kwargs)
  43. def _eval_adjoint(self):
  44. return conjugate(self.arg)
  45. def _eval_conjugate(self):
  46. return adjoint(self.arg)
  47. def _eval_transpose(self):
  48. return self.arg
  49. def _eval_trace(self):
  50. from .trace import Trace
  51. return Trace(self.arg) # Trace(X.T) => Trace(X)
  52. def _eval_determinant(self):
  53. from sympy.matrices.expressions.determinant import det
  54. return det(self.arg)
  55. def _eval_derivative(self, x):
  56. # x is a scalar:
  57. return self.arg._eval_derivative(x)
  58. def _eval_derivative_matrix_lines(self, x):
  59. lines = self.args[0]._eval_derivative_matrix_lines(x)
  60. return [i.transpose() for i in lines]
  61. def transpose(expr):
  62. """Matrix transpose"""
  63. return Transpose(expr).doit(deep=False)
  64. from sympy.assumptions.ask import ask, Q
  65. from sympy.assumptions.refine import handlers_dict
  66. def refine_Transpose(expr, assumptions):
  67. """
  68. >>> from sympy import MatrixSymbol, Q, assuming, refine
  69. >>> X = MatrixSymbol('X', 2, 2)
  70. >>> X.T
  71. X.T
  72. >>> with assuming(Q.symmetric(X)):
  73. ... print(refine(X.T))
  74. X
  75. """
  76. if ask(Q.symmetric(expr), assumptions):
  77. return expr.arg
  78. return expr
  79. handlers_dict['Transpose'] = refine_Transpose