123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360 |
- from typing import Any, Dict as tDict
- from sympy.simplify import simplify as simp, trigsimp as tsimp
- from sympy.core.decorators import call_highest_priority, _sympifyit
- from sympy.core.assumptions import StdFactKB
- from sympy.core.function import diff as df
- from sympy.integrals.integrals import Integral
- from sympy.polys.polytools import factor as fctr
- from sympy.core import S, Add, Mul
- from sympy.core.expr import Expr
class BasisDependent(Expr):
    """
    Super class containing functionality common to vectors and
    dyadics.

    Named so because the representation of these quantities in
    sympy.vector is dependent on the basis they are expressed in.

    This class relies on its subclasses to supply ``_add_func``,
    ``_mul_func``, ``_div_helper``, ``zero`` and a ``components``
    mapping from basis objects to scalar coefficients; the methods
    below are all expressed in terms of those hooks.
    """

    @call_highest_priority('__radd__')
    def __add__(self, other):
        return self._add_func(self, other)

    @call_highest_priority('__add__')
    def __radd__(self, other):
        return self._add_func(other, self)

    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return self._add_func(self, -other)

    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return self._add_func(other, -self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return self._mul_func(self, other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return self._mul_func(other, self)

    def __neg__(self):
        return self._mul_func(S.NegativeOne, self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        return self._div_helper(other)

    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        # Nothing can be divided by a basis-dependent quantity. The error
        # must be *raised*: returning the TypeError instance (as before)
        # handed the exception object back to the caller as the "quotient".
        raise TypeError("Invalid divisor for division")

    def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False):
        """
        Implements the SymPy evalf routine for this quantity.

        Each scalar coefficient is evaluated numerically with the given
        options and the result is re-assembled on the same basis.

        evalf's documentation
        =====================

        """
        options = {'subs': subs, 'maxn': maxn, 'chop': chop, 'strict': strict,
                   'quad': quad, 'verbose': verbose}
        vec = self.zero
        for k, v in self.components.items():
            vec += v.evalf(n, **options) * k
        return vec

    evalf.__doc__ += Expr.evalf.__doc__  # type: ignore

    n = evalf

    def simplify(self, **kwargs):
        """
        Implements the SymPy simplify routine for this quantity.

        Only the scalar coefficients are simplified; the basis is kept.

        simplify's documentation
        ========================

        """
        simp_components = [simp(v, **kwargs) * k for
                           k, v in self.components.items()]
        return self._add_func(*simp_components)

    simplify.__doc__ += simp.__doc__  # type: ignore

    def trigsimp(self, **opts):
        """
        Implements the SymPy trigsimp routine, for this quantity.

        Only the scalar coefficients are simplified; the basis is kept.

        trigsimp's documentation
        ========================

        """
        trig_components = [tsimp(v, **opts) * k for
                           k, v in self.components.items()]
        return self._add_func(*trig_components)

    trigsimp.__doc__ += tsimp.__doc__  # type: ignore

    def _eval_simplify(self, **kwargs):
        return self.simplify(**kwargs)

    def _eval_trigsimp(self, **opts):
        return self.trigsimp(**opts)

    def _eval_derivative(self, wrt):
        return self.diff(wrt)

    def _eval_Integral(self, *symbols, **assumptions):
        # Integrate coefficient-wise; the basis objects are treated as
        # constants with respect to the integration variables.
        integral_components = [Integral(v, *symbols, **assumptions) * k
                               for k, v in self.components.items()]
        return self._add_func(*integral_components)

    def as_numer_denom(self):
        """
        Returns the expression as a tuple wrt the following
        transformation -

        expression -> a/b -> a, b

        A basis-dependent quantity is never split, so the denominator
        is always ``S.One``.
        """
        return self, S.One

    def factor(self, *args, **kwargs):
        """
        Implements the SymPy factor routine, on the scalar parts
        of a basis-dependent expression.

        factor's documentation
        ========================

        """
        fctr_components = [fctr(v, *args, **kwargs) * k for
                           k, v in self.components.items()]
        return self._add_func(*fctr_components)

    factor.__doc__ += fctr.__doc__  # type: ignore

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product. """
        return (S.One, self)

    def as_coeff_add(self, *deps):
        """Efficiently extract the coefficient of a summation.

        A basis-dependent quantity has no scalar additive part, so the
        coefficient is always ``S.Zero`` (a SymPy zero, as the core
        ``as_coeff_add`` protocol expects) and the terms are the
        individual ``basis * coefficient`` products.
        """
        terms = tuple(basis * coeff
                      for basis, coeff in self.components.items())
        return S.Zero, terms

    def diff(self, *args, **kwargs):
        """
        Implements the SymPy diff routine, for vectors.

        Differentiation with respect to a basis-dependent quantity is
        rejected with a TypeError.

        diff's documentation
        ========================

        """
        for x in args:
            if isinstance(x, BasisDependent):
                raise TypeError("Invalid arg for differentiation")
        diff_components = [df(v, *args, **kwargs) * k for
                           k, v in self.components.items()]
        return self._add_func(*diff_components)

    diff.__doc__ += df.__doc__  # type: ignore

    def doit(self, **hints):
        """Calls .doit() on each scalar coefficient and rebuilds the sum."""
        doit_components = [coeff.doit(**hints) * basis
                           for basis, coeff in self.components.items()]
        return self._add_func(*doit_components)
class BasisDependentAdd(BasisDependent, Add):
    """
    Denotes sum of basis dependent quantities such that they cannot
    be expressed as base or Mul instances.

    Construction merges the coefficients of equal basis objects, drops
    terms whose coefficient cancels to zero, and returns ``cls.zero``
    when nothing remains.
    """

    def __new__(cls, *args, **options):
        components = {}

        # Check each arg and simultaneously learn the components
        for arg in args:
            if not isinstance(arg, cls._expr_type):
                # A plain Mul/Add may still hide a basis-dependent term;
                # rebuild it through the specialised constructors.
                if isinstance(arg, Mul):
                    arg = cls._mul_func(*(arg.args))
                elif isinstance(arg, Add):
                    arg = cls._add_func(*(arg.args))
                else:
                    raise TypeError(str(arg) +
                                    " cannot be interpreted correctly")
                # If the rebuild produced no basis-dependent quantity the
                # term is a pure scalar. It used to be dropped silently
                # (so ``v + 2*x`` evaluated to ``v``); reject it the same
                # way other scalar terms are rejected above.
                if not isinstance(arg, cls._expr_type):
                    raise TypeError(str(arg) +
                                    " cannot be interpreted correctly")
            # If argument is zero, ignore
            if arg == cls.zero:
                continue
            # Else, update components accordingly
            if hasattr(arg, "components"):
                for basis, coeff in arg.components.items():
                    components[basis] = components.get(basis, 0) + coeff

        # Drop basis objects whose coefficients cancelled out
        components = {basis: coeff for basis, coeff in components.items()
                      if coeff != 0}

        # Handle case of zero vector
        if not components:
            return cls.zero

        # Build object
        newargs = [basis * coeff for basis, coeff in components.items()]
        obj = super().__new__(cls, *newargs, **options)
        # Add's constructor may collapse a single term into a Mul.
        if isinstance(obj, Mul):
            return cls._mul_func(*obj.args)
        assumptions = {'commutative': True}
        obj._assumptions = StdFactKB(assumptions)
        obj._components = components
        # The coordinate system is taken from the first basis object.
        obj._sys = next(iter(components))._sys

        return obj
class BasisDependentMul(BasisDependent, Mul):
    """
    Denotes product of base- basis dependent quantity with a scalar.

    The constructor folds every plain scalar factor into a single
    measure number and accepts at most one basis-dependent factor.
    """

    def __new__(cls, *args, **options):
        from sympy.vector import Cross, Dot, Curl, Gradient

        operator_types = (Cross, Dot, Curl, Gradient)
        n_vector_args = 0        # basis-dependent factors seen so far
        coeff = S.One            # product of all plain scalar factors
        is_zero = False
        operator_factors = []    # unevaluated Cross/Dot/Curl/Gradient

        for factor in args:
            if isinstance(factor, cls._zero_func):
                n_vector_args += 1
                is_zero = True
                continue
            if factor == S.Zero:
                is_zero = True
                continue
            if isinstance(factor, (cls._base_func, cls._mul_func)):
                n_vector_args += 1
                vector_part = factor._base_instance
                coeff *= factor._measure_number
            elif isinstance(factor, cls._add_func):
                n_vector_args += 1
                vector_part = factor
            elif isinstance(factor, operator_types):
                operator_factors.append(factor)
            else:
                coeff *= factor

        # Two basis-dependent factors cannot be multiplied together.
        if n_vector_args > 1:
            raise ValueError("Invalid multiplication")
        # No basis-dependent factor at all: an ordinary scalar product.
        if n_vector_args == 0:
            return Mul(*args, **options)
        if is_zero:
            return cls.zero

        # Distribute the scalar over a sum of basis-dependent terms.
        # NOTE(review): operator_factors are not carried into the
        # distributed terms here -- confirm this is intended.
        if isinstance(vector_part, cls._add_func):
            return cls._add_func(*(cls._mul_func(coeff, term)
                                   for term in vector_part.args))

        base = vector_part._base_instance
        obj = super().__new__(cls, coeff, base, *operator_factors, **options)
        if isinstance(obj, Add):
            return cls._add_func(*obj.args)
        obj._base_instance = base
        obj._measure_number = coeff
        obj._assumptions = StdFactKB({'commutative': True})
        # NOTE(review): operator_factors are among the Mul's args but are
        # not reflected in _components -- confirm this is intended.
        obj._components = {base: coeff}
        obj._sys = base._sys
        return obj

    def _sympystr(self, printer):
        coeff_str = printer._print(self._measure_number)
        # Parenthesise compound coefficients so the product reads
        # unambiguously, e.g. (a + b)*i rather than a + b*i.
        if any(ch in coeff_str for ch in '(-+'):
            coeff_str = '(' + coeff_str + ')'
        return coeff_str + '*' + printer._print(self._base_instance)
class BasisDependentZero(BasisDependent):
    """
    Class to denote a zero basis dependent instance.

    It is the identity element for addition, is its own negation and
    normalization, has no components, and compares equal only to other
    instances of ``_zero_func``.
    """
    # XXX: Can't type the keys as BaseVector because of cyclic import
    # problems.
    components = {}  # type: tDict[Any, Expr]

    def __new__(cls):
        obj = super().__new__(cls)
        # The hash depends only on the concrete class, so every zero
        # instance of a given subclass hashes identically.
        obj._hash = hash((S.Zero, cls))
        return obj

    def __hash__(self):
        return self._hash

    @call_highest_priority('__req__')
    def __eq__(self, other):
        return isinstance(other, self._zero_func)

    __req__ = __eq__

    @call_highest_priority('__radd__')
    def __add__(self, other):
        if not isinstance(other, self._expr_type):
            raise TypeError("Invalid argument types for addition")
        return other

    @call_highest_priority('__add__')
    def __radd__(self, other):
        if not isinstance(other, self._expr_type):
            raise TypeError("Invalid argument types for addition")
        return other

    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        # 0 - other
        if not isinstance(other, self._expr_type):
            raise TypeError("Invalid argument types for subtraction")
        return -other

    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        # other - 0
        if not isinstance(other, self._expr_type):
            raise TypeError("Invalid argument types for subtraction")
        return other

    def __neg__(self):
        return self

    def normalize(self):
        """
        Returns the normalized version of this vector; for the zero
        quantity that is the instance itself.
        """
        return self

    def _sympystr(self, printer):
        return '0'
|