functions.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from collections.abc import Iterable
  2. from functools import singledispatch
  3. from sympy.core.expr import Expr
  4. from sympy.core.mul import Mul
  5. from sympy.core.singleton import S
  6. from sympy.core.sympify import sympify
  7. from sympy.core.parameters import global_parameters
  8. class TensorProduct(Expr):
  9. """
  10. Generic class for tensor products.
  11. """
  12. is_number = False
  13. def __new__(cls, *args, **kwargs):
  14. from sympy.tensor.array import NDimArray, tensorproduct, Array
  15. from sympy.matrices.expressions.matexpr import MatrixExpr
  16. from sympy.matrices.matrices import MatrixBase
  17. from sympy.strategies import flatten
  18. args = [sympify(arg) for arg in args]
  19. evaluate = kwargs.get("evaluate", global_parameters.evaluate)
  20. if not evaluate:
  21. obj = Expr.__new__(cls, *args)
  22. return obj
  23. arrays = []
  24. other = []
  25. scalar = S.One
  26. for arg in args:
  27. if isinstance(arg, (Iterable, MatrixBase, NDimArray)):
  28. arrays.append(Array(arg))
  29. elif isinstance(arg, (MatrixExpr,)):
  30. other.append(arg)
  31. else:
  32. scalar *= arg
  33. coeff = scalar*tensorproduct(*arrays)
  34. if len(other) == 0:
  35. return coeff
  36. if coeff != 1:
  37. newargs = [coeff] + other
  38. else:
  39. newargs = other
  40. obj = Expr.__new__(cls, *newargs, **kwargs)
  41. return flatten(obj)
  42. def rank(self):
  43. return len(self.shape)
  44. def _get_args_shapes(self):
  45. from sympy.tensor.array import Array
  46. return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args]
  47. @property
  48. def shape(self):
  49. shape_list = self._get_args_shapes()
  50. return sum(shape_list, ())
  51. def __getitem__(self, index):
  52. index = iter(index)
  53. return Mul.fromiter(
  54. arg.__getitem__(tuple(next(index) for i in shp))
  55. for arg, shp in zip(self.args, self._get_args_shapes())
  56. )
  57. @singledispatch
  58. def shape(expr):
  59. """
  60. Return the shape of the *expr* as a tuple. *expr* should represent
  61. suitable object such as matrix or array.
  62. Parameters
  63. ==========
  64. expr : SymPy object having ``MatrixKind`` or ``ArrayKind``.
  65. Raises
  66. ======
  67. NoShapeError : Raised when object with wrong kind is passed.
  68. Examples
  69. ========
  70. This function returns the shape of any object representing matrix or array.
  71. >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral
  72. >>> from sympy.abc import x
  73. >>> A = Array([1, 2])
  74. >>> shape(A)
  75. (2,)
  76. >>> shape(Integral(A, x))
  77. (2,)
  78. >>> M = ImmutableDenseMatrix([1, 2])
  79. >>> shape(M)
  80. (2, 1)
  81. >>> shape(Integral(M, x))
  82. (2, 1)
  83. You can support new type by dispatching.
  84. >>> from sympy import Expr
  85. >>> class NewExpr(Expr):
  86. ... pass
  87. >>> @shape.register(NewExpr)
  88. ... def _(expr):
  89. ... return shape(expr.args[0])
  90. >>> shape(NewExpr(M))
  91. (2, 1)
  92. If unsuitable expression is passed, ``NoShapeError()`` will be raised.
  93. >>> shape(Integral(x, x))
  94. Traceback (most recent call last):
  95. ...
  96. sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x)
  97. Notes
  98. =====
  99. Array-like classes (such as ``Matrix`` or ``NDimArray``) has ``shape``
  100. property which returns its shape, but it cannot be used for non-array
  101. classes containing array. This function returns the shape of any
  102. registered object representing array.
  103. """
  104. if hasattr(expr, "shape"):
  105. return expr.shape
  106. raise NoShapeError(
  107. "%s does not have shape, or its type is not registered to shape()." % expr)
  108. class NoShapeError(Exception):
  109. """
  110. Raised when ``shape()`` is called on non-array object.
  111. This error can be imported from ``sympy.tensor.functions``.
  112. Examples
  113. ========
  114. >>> from sympy import shape
  115. >>> from sympy.abc import x
  116. >>> shape(x)
  117. Traceback (most recent call last):
  118. ...
  119. sympy.tensor.functions.NoShapeError: shape() called on non-array object: x
  120. """
  121. pass