matrix_nodes.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. """
  2. Additional AST nodes for operations on matrices. The nodes in this module
  3. are meant to represent optimization of matrix expressions within codegen's
  4. target languages that cannot be represented by SymPy expressions.
  5. As an example, we can use :meth:`sympy.codegen.rewriting.optimize` and the
  6. ``matin_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to
  7. transform matrix multiplication under certain assumptions:
  8. >>> from sympy import symbols, MatrixSymbol
  9. >>> n = symbols('n', integer=True)
  10. >>> A = MatrixSymbol('A', n, n)
  11. >>> x = MatrixSymbol('x', n, 1)
  12. >>> expr = A**(-1) * x
  13. >>> from sympy import assuming, Q
  14. >>> from sympy.codegen.rewriting import matinv_opt, optimize
  15. >>> with assuming(Q.fullrank(A)):
  16. ... optimize(expr, [matinv_opt])
  17. MatrixSolve(A, vector=x)
  18. """
  19. from .ast import Token
  20. from sympy.matrices import MatrixExpr
  21. from sympy.core.sympify import sympify
  22. class MatrixSolve(Token, MatrixExpr):
  23. """Represents an operation to solve a linear matrix equation.
  24. Parameters
  25. ==========
  26. matrix : MatrixSymbol
  27. Matrix representing the coefficients of variables in the linear
  28. equation. This matrix must be square and full-rank (i.e. all columns must
  29. be linearly independent) for the solving operation to be valid.
  30. vector : MatrixSymbol
  31. One-column matrix representing the solutions to the equations
  32. represented in ``matrix``.
  33. Examples
  34. ========
  35. >>> from sympy import symbols, MatrixSymbol
  36. >>> from sympy.codegen.matrix_nodes import MatrixSolve
  37. >>> n = symbols('n', integer=True)
  38. >>> A = MatrixSymbol('A', n, n)
  39. >>> x = MatrixSymbol('x', n, 1)
  40. >>> from sympy.printing.numpy import NumPyPrinter
  41. >>> NumPyPrinter().doprint(MatrixSolve(A, x))
  42. 'numpy.linalg.solve(A, x)'
  43. >>> from sympy import octave_code
  44. >>> octave_code(MatrixSolve(A, x))
  45. 'A \\\\ x'
  46. """
  47. __slots__ = ('matrix', 'vector')
  48. _construct_matrix = staticmethod(sympify)
  49. @property
  50. def shape(self):
  51. return self.vector.shape