123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- from collections.abc import Iterable
- from functools import singledispatch
- from sympy.core.expr import Expr
- from sympy.core.mul import Mul
- from sympy.core.singleton import S
- from sympy.core.sympify import sympify
- from sympy.core.parameters import global_parameters
class TensorProduct(Expr):
    """
    Generic class for tensor products.

    When evaluated, array-like arguments (iterables, ``MatrixBase``,
    ``NDimArray``) are combined eagerly with ``tensorproduct``, matrix
    expressions (``MatrixExpr``) are kept as symbolic arguments, and every
    other argument is folded into a scalar coefficient.
    """
    is_number = False

    def __new__(cls, *args, **kwargs):
        """
        Build the tensor product of ``args``.

        Parameters
        ==========

        *args : objects to combine; each one is sympified first.
        evaluate : bool, optional (keyword only)
            When false, the sympified arguments are stored unevaluated.
            Defaults to ``global_parameters.evaluate``.

        Returns
        =======

        If no argument is a ``MatrixExpr``, the evaluated product
        ``scalar*tensorproduct(*arrays)`` itself is returned (not a
        ``TensorProduct``). Otherwise a ``TensorProduct`` whose args are the
        coefficient (omitted when it equals 1) followed by the matrix
        expressions, passed through ``sympy.strategies.flatten``.
        """
        from sympy.tensor.array import NDimArray, tensorproduct, Array
        from sympy.matrices.expressions.matexpr import MatrixExpr
        from sympy.matrices.matrices import MatrixBase
        from sympy.strategies import flatten

        args = [sympify(arg) for arg in args]
        evaluate = kwargs.get("evaluate", global_parameters.evaluate)

        if not evaluate:
            obj = Expr.__new__(cls, *args)
            return obj

        arrays = []
        other = []
        scalar = S.One
        for arg in args:
            if isinstance(arg, (Iterable, MatrixBase, NDimArray)):
                arrays.append(Array(arg))
            elif isinstance(arg, (MatrixExpr,)):
                other.append(arg)
            else:
                scalar *= arg

        coeff = scalar*tensorproduct(*arrays)
        if len(other) == 0:
            return coeff
        if coeff != 1:
            newargs = [coeff] + other
        else:
            newargs = other
        # Basic.__new__ takes positional args only; forwarding **kwargs
        # (e.g. an explicit ``evaluate=True``) here would raise TypeError.
        obj = Expr.__new__(cls, *newargs)
        return flatten(obj)

    def rank(self):
        """Return the number of indices, i.e. ``len(self.shape)``."""
        return len(self.shape)

    def _get_args_shapes(self):
        """
        Return a list with the shape tuple of each arg.

        Args without a ``shape`` attribute are wrapped in ``Array`` to get
        one. For a scalar coefficient this is expected to be ``()``.
        """
        from sympy.tensor.array import Array
        return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args]

    @property
    def shape(self):
        """Concatenation of the shapes of all args, in order."""
        shape_list = self._get_args_shapes()
        return sum(shape_list, ())

    def __getitem__(self, index):
        """
        Return the component at ``index``.

        ``index`` is a sequence of integers. Consecutive runs of it are
        routed to each arg according to that arg's shape, and the selected
        entries are multiplied together. Extra trailing entries are ignored.

        Raises
        ======

        IndexError : if ``index`` has fewer entries than ``len(self.shape)``.
        """
        index = tuple(index)
        factors = []
        pos = 0
        for arg, shp in zip(self.args, self._get_args_shapes()):
            width = len(shp)
            sub_index = index[pos:pos + width]
            if len(sub_index) != width:
                raise IndexError(
                    "index %s has too few entries for shape %s"
                    % (index, self.shape))
            pos += width
            if not shp and not hasattr(arg, "__getitem__"):
                # A plain scalar coefficient (empty shape) is not
                # subscriptable; it multiplies every component unchanged.
                factors.append(arg)
            else:
                factors.append(arg[sub_index])
        return Mul.fromiter(factors)
@singledispatch
def shape(expr):
    """
    Return the shape of *expr* as a tuple. *expr* should represent a
    suitable object such as a matrix or an array.

    Parameters
    ==========

    expr : SymPy object having ``MatrixKind`` or ``ArrayKind``.

    Raises
    ======

    NoShapeError : Raised when an object of the wrong kind is passed.

    Examples
    ========

    This function returns the shape of any object representing a matrix or
    an array.

    >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral
    >>> from sympy.abc import x
    >>> A = Array([1, 2])
    >>> shape(A)
    (2,)
    >>> shape(Integral(A, x))
    (2,)
    >>> M = ImmutableDenseMatrix([1, 2])
    >>> shape(M)
    (2, 1)
    >>> shape(Integral(M, x))
    (2, 1)

    You can support a new type by dispatching.

    >>> from sympy import Expr
    >>> class NewExpr(Expr):
    ...     pass
    >>> @shape.register(NewExpr)
    ... def _(expr):
    ...     return shape(expr.args[0])
    >>> shape(NewExpr(M))
    (2, 1)

    If an unsuitable expression is passed, ``NoShapeError()`` is raised.

    >>> shape(Integral(x, x))
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x)

    Notes
    =====

    Array-like classes (such as ``Matrix`` or ``NDimArray``) have a ``shape``
    property that returns their shape, but it cannot be used for non-array
    classes containing an array. This function returns the shape of any
    registered object representing an array. The default (unregistered)
    implementation returns ``expr.shape`` when that attribute exists.
    """
    if hasattr(expr, "shape"):
        return expr.shape
    # Wrap expr in a 1-tuple so that tuple inputs do not break %-formatting
    # (which would raise TypeError instead of NoShapeError).
    raise NoShapeError("shape() called on non-array object: %s" % (expr,))
class NoShapeError(Exception):
    """
    Exception signalling that ``shape()`` received an object that does not
    represent an array.

    It can be imported from ``sympy.tensor.functions``.

    Examples
    ========

    >>> from sympy import shape
    >>> from sympy.abc import x
    >>> shape(x)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: shape() called on non-array object: x
    """
|