basisdependent.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. from typing import Any, Dict as tDict
  2. from sympy.simplify import simplify as simp, trigsimp as tsimp
  3. from sympy.core.decorators import call_highest_priority, _sympifyit
  4. from sympy.core.assumptions import StdFactKB
  5. from sympy.core.function import diff as df
  6. from sympy.integrals.integrals import Integral
  7. from sympy.polys.polytools import factor as fctr
  8. from sympy.core import S, Add, Mul
  9. from sympy.core.expr import Expr
  10. class BasisDependent(Expr):
  11. """
  12. Super class containing functionality common to vectors and
  13. dyadics.
  14. Named so because the representation of these quantities in
  15. sympy.vector is dependent on the basis they are expressed in.
  16. """
  17. @call_highest_priority('__radd__')
  18. def __add__(self, other):
  19. return self._add_func(self, other)
  20. @call_highest_priority('__add__')
  21. def __radd__(self, other):
  22. return self._add_func(other, self)
  23. @call_highest_priority('__rsub__')
  24. def __sub__(self, other):
  25. return self._add_func(self, -other)
  26. @call_highest_priority('__sub__')
  27. def __rsub__(self, other):
  28. return self._add_func(other, -self)
  29. @_sympifyit('other', NotImplemented)
  30. @call_highest_priority('__rmul__')
  31. def __mul__(self, other):
  32. return self._mul_func(self, other)
  33. @_sympifyit('other', NotImplemented)
  34. @call_highest_priority('__mul__')
  35. def __rmul__(self, other):
  36. return self._mul_func(other, self)
  37. def __neg__(self):
  38. return self._mul_func(S.NegativeOne, self)
  39. @_sympifyit('other', NotImplemented)
  40. @call_highest_priority('__rtruediv__')
  41. def __truediv__(self, other):
  42. return self._div_helper(other)
  43. @call_highest_priority('__truediv__')
  44. def __rtruediv__(self, other):
  45. return TypeError("Invalid divisor for division")
  46. def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False):
  47. """
  48. Implements the SymPy evalf routine for this quantity.
  49. evalf's documentation
  50. =====================
  51. """
  52. options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict,
  53. 'quad':quad, 'verbose':verbose}
  54. vec = self.zero
  55. for k, v in self.components.items():
  56. vec += v.evalf(n, **options) * k
  57. return vec
  58. evalf.__doc__ += Expr.evalf.__doc__ # type: ignore
  59. n = evalf
  60. def simplify(self, **kwargs):
  61. """
  62. Implements the SymPy simplify routine for this quantity.
  63. simplify's documentation
  64. ========================
  65. """
  66. simp_components = [simp(v, **kwargs) * k for
  67. k, v in self.components.items()]
  68. return self._add_func(*simp_components)
  69. simplify.__doc__ += simp.__doc__ # type: ignore
  70. def trigsimp(self, **opts):
  71. """
  72. Implements the SymPy trigsimp routine, for this quantity.
  73. trigsimp's documentation
  74. ========================
  75. """
  76. trig_components = [tsimp(v, **opts) * k for
  77. k, v in self.components.items()]
  78. return self._add_func(*trig_components)
  79. trigsimp.__doc__ += tsimp.__doc__ # type: ignore
  80. def _eval_simplify(self, **kwargs):
  81. return self.simplify(**kwargs)
  82. def _eval_trigsimp(self, **opts):
  83. return self.trigsimp(**opts)
  84. def _eval_derivative(self, wrt):
  85. return self.diff(wrt)
  86. def _eval_Integral(self, *symbols, **assumptions):
  87. integral_components = [Integral(v, *symbols, **assumptions) * k
  88. for k, v in self.components.items()]
  89. return self._add_func(*integral_components)
  90. def as_numer_denom(self):
  91. """
  92. Returns the expression as a tuple wrt the following
  93. transformation -
  94. expression -> a/b -> a, b
  95. """
  96. return self, S.One
  97. def factor(self, *args, **kwargs):
  98. """
  99. Implements the SymPy factor routine, on the scalar parts
  100. of a basis-dependent expression.
  101. factor's documentation
  102. ========================
  103. """
  104. fctr_components = [fctr(v, *args, **kwargs) * k for
  105. k, v in self.components.items()]
  106. return self._add_func(*fctr_components)
  107. factor.__doc__ += fctr.__doc__ # type: ignore
  108. def as_coeff_Mul(self, rational=False):
  109. """Efficiently extract the coefficient of a product. """
  110. return (S.One, self)
  111. def as_coeff_add(self, *deps):
  112. """Efficiently extract the coefficient of a summation. """
  113. l = [x * self.components[x] for x in self.components]
  114. return 0, tuple(l)
  115. def diff(self, *args, **kwargs):
  116. """
  117. Implements the SymPy diff routine, for vectors.
  118. diff's documentation
  119. ========================
  120. """
  121. for x in args:
  122. if isinstance(x, BasisDependent):
  123. raise TypeError("Invalid arg for differentiation")
  124. diff_components = [df(v, *args, **kwargs) * k for
  125. k, v in self.components.items()]
  126. return self._add_func(*diff_components)
  127. diff.__doc__ += df.__doc__ # type: ignore
  128. def doit(self, **hints):
  129. """Calls .doit() on each term in the Dyadic"""
  130. doit_components = [self.components[x].doit(**hints) * x
  131. for x in self.components]
  132. return self._add_func(*doit_components)
  133. class BasisDependentAdd(BasisDependent, Add):
  134. """
  135. Denotes sum of basis dependent quantities such that they cannot
  136. be expressed as base or Mul instances.
  137. """
  138. def __new__(cls, *args, **options):
  139. components = {}
  140. # Check each arg and simultaneously learn the components
  141. for i, arg in enumerate(args):
  142. if not isinstance(arg, cls._expr_type):
  143. if isinstance(arg, Mul):
  144. arg = cls._mul_func(*(arg.args))
  145. elif isinstance(arg, Add):
  146. arg = cls._add_func(*(arg.args))
  147. else:
  148. raise TypeError(str(arg) +
  149. " cannot be interpreted correctly")
  150. # If argument is zero, ignore
  151. if arg == cls.zero:
  152. continue
  153. # Else, update components accordingly
  154. if hasattr(arg, "components"):
  155. for x in arg.components:
  156. components[x] = components.get(x, 0) + arg.components[x]
  157. temp = list(components.keys())
  158. for x in temp:
  159. if components[x] == 0:
  160. del components[x]
  161. # Handle case of zero vector
  162. if len(components) == 0:
  163. return cls.zero
  164. # Build object
  165. newargs = [x * components[x] for x in components]
  166. obj = super().__new__(cls, *newargs, **options)
  167. if isinstance(obj, Mul):
  168. return cls._mul_func(*obj.args)
  169. assumptions = {'commutative': True}
  170. obj._assumptions = StdFactKB(assumptions)
  171. obj._components = components
  172. obj._sys = (list(components.keys()))[0]._sys
  173. return obj
  174. class BasisDependentMul(BasisDependent, Mul):
  175. """
  176. Denotes product of base- basis dependent quantity with a scalar.
  177. """
  178. def __new__(cls, *args, **options):
  179. from sympy.vector import Cross, Dot, Curl, Gradient
  180. count = 0
  181. measure_number = S.One
  182. zeroflag = False
  183. extra_args = []
  184. # Determine the component and check arguments
  185. # Also keep a count to ensure two vectors aren't
  186. # being multiplied
  187. for arg in args:
  188. if isinstance(arg, cls._zero_func):
  189. count += 1
  190. zeroflag = True
  191. elif arg == S.Zero:
  192. zeroflag = True
  193. elif isinstance(arg, (cls._base_func, cls._mul_func)):
  194. count += 1
  195. expr = arg._base_instance
  196. measure_number *= arg._measure_number
  197. elif isinstance(arg, cls._add_func):
  198. count += 1
  199. expr = arg
  200. elif isinstance(arg, (Cross, Dot, Curl, Gradient)):
  201. extra_args.append(arg)
  202. else:
  203. measure_number *= arg
  204. # Make sure incompatible types weren't multiplied
  205. if count > 1:
  206. raise ValueError("Invalid multiplication")
  207. elif count == 0:
  208. return Mul(*args, **options)
  209. # Handle zero vector case
  210. if zeroflag:
  211. return cls.zero
  212. # If one of the args was a VectorAdd, return an
  213. # appropriate VectorAdd instance
  214. if isinstance(expr, cls._add_func):
  215. newargs = [cls._mul_func(measure_number, x) for
  216. x in expr.args]
  217. return cls._add_func(*newargs)
  218. obj = super().__new__(cls, measure_number,
  219. expr._base_instance,
  220. *extra_args,
  221. **options)
  222. if isinstance(obj, Add):
  223. return cls._add_func(*obj.args)
  224. obj._base_instance = expr._base_instance
  225. obj._measure_number = measure_number
  226. assumptions = {'commutative': True}
  227. obj._assumptions = StdFactKB(assumptions)
  228. obj._components = {expr._base_instance: measure_number}
  229. obj._sys = expr._base_instance._sys
  230. return obj
  231. def _sympystr(self, printer):
  232. measure_str = printer._print(self._measure_number)
  233. if ('(' in measure_str or '-' in measure_str or
  234. '+' in measure_str):
  235. measure_str = '(' + measure_str + ')'
  236. return measure_str + '*' + printer._print(self._base_instance)
  237. class BasisDependentZero(BasisDependent):
  238. """
  239. Class to denote a zero basis dependent instance.
  240. """
  241. # XXX: Can't type the keys as BaseVector because of cyclic import
  242. # problems.
  243. components = {} # type: tDict[Any, Expr]
  244. def __new__(cls):
  245. obj = super().__new__(cls)
  246. # Pre-compute a specific hash value for the zero vector
  247. # Use the same one always
  248. obj._hash = tuple([S.Zero, cls]).__hash__()
  249. return obj
  250. def __hash__(self):
  251. return self._hash
  252. @call_highest_priority('__req__')
  253. def __eq__(self, other):
  254. return isinstance(other, self._zero_func)
  255. __req__ = __eq__
  256. @call_highest_priority('__radd__')
  257. def __add__(self, other):
  258. if isinstance(other, self._expr_type):
  259. return other
  260. else:
  261. raise TypeError("Invalid argument types for addition")
  262. @call_highest_priority('__add__')
  263. def __radd__(self, other):
  264. if isinstance(other, self._expr_type):
  265. return other
  266. else:
  267. raise TypeError("Invalid argument types for addition")
  268. @call_highest_priority('__rsub__')
  269. def __sub__(self, other):
  270. if isinstance(other, self._expr_type):
  271. return -other
  272. else:
  273. raise TypeError("Invalid argument types for subtraction")
  274. @call_highest_priority('__sub__')
  275. def __rsub__(self, other):
  276. if isinstance(other, self._expr_type):
  277. return other
  278. else:
  279. raise TypeError("Invalid argument types for subtraction")
  280. def __neg__(self):
  281. return self
  282. def normalize(self):
  283. """
  284. Returns the normalized version of this vector.
  285. """
  286. return self
  287. def _sympystr(self, printer):
  288. return '0'