deutils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """Utility functions for classifying and solving
  2. ordinary and partial differential equations.
  3. Contains
  4. ========
  5. _preprocess
  6. ode_order
  7. _desolve
  8. """
  9. from sympy.core import Pow
  10. from sympy.core.function import Derivative, AppliedUndef
  11. from sympy.core.relational import Equality
  12. from sympy.core.symbol import Wild
  13. def _preprocess(expr, func=None, hint='_Integral'):
  14. """Prepare expr for solving by making sure that differentiation
  15. is done so that only func remains in unevaluated derivatives and
  16. (if hint doesn't end with _Integral) that doit is applied to all
  17. other derivatives. If hint is None, don't do any differentiation.
  18. (Currently this may cause some simple differential equations to
  19. fail.)
  20. In case func is None, an attempt will be made to autodetect the
  21. function to be solved for.
  22. >>> from sympy.solvers.deutils import _preprocess
  23. >>> from sympy import Derivative, Function
  24. >>> from sympy.abc import x, y, z
  25. >>> f, g = map(Function, 'fg')
  26. If f(x)**p == 0 and p>0 then we can solve for f(x)=0
  27. >>> _preprocess((f(x).diff(x)-4)**5, f(x))
  28. (Derivative(f(x), x) - 4, f(x))
  29. Apply doit to derivatives that contain more than the function
  30. of interest:
  31. >>> _preprocess(Derivative(f(x) + x, x))
  32. (Derivative(f(x), x) + 1, f(x))
  33. Do others if the differentiation variable(s) intersect with those
  34. of the function of interest or contain the function of interest:
  35. >>> _preprocess(Derivative(g(x), y, z), f(y))
  36. (0, f(y))
  37. >>> _preprocess(Derivative(f(y), z), f(y))
  38. (0, f(y))
  39. Do others if the hint doesn't end in '_Integral' (the default
  40. assumes that it does):
  41. >>> _preprocess(Derivative(g(x), y), f(x))
  42. (Derivative(g(x), y), f(x))
  43. >>> _preprocess(Derivative(f(x), y), f(x), hint='')
  44. (0, f(x))
  45. Don't do any derivatives if hint is None:
  46. >>> eq = Derivative(f(x) + 1, x) + Derivative(f(x), y)
  47. >>> _preprocess(eq, f(x), hint=None)
  48. (Derivative(f(x) + 1, x) + Derivative(f(x), y), f(x))
  49. If it's not clear what the function of interest is, it must be given:
  50. >>> eq = Derivative(f(x) + g(x), x)
  51. >>> _preprocess(eq, g(x))
  52. (Derivative(f(x), x) + Derivative(g(x), x), g(x))
  53. >>> try: _preprocess(eq)
  54. ... except ValueError: print("A ValueError was raised.")
  55. A ValueError was raised.
  56. """
  57. if isinstance(expr, Pow):
  58. # if f(x)**p=0 then f(x)=0 (p>0)
  59. if (expr.exp).is_positive:
  60. expr = expr.base
  61. derivs = expr.atoms(Derivative)
  62. if not func:
  63. funcs = set().union(*[d.atoms(AppliedUndef) for d in derivs])
  64. if len(funcs) != 1:
  65. raise ValueError('The function cannot be '
  66. 'automatically detected for %s.' % expr)
  67. func = funcs.pop()
  68. fvars = set(func.args)
  69. if hint is None:
  70. return expr, func
  71. reps = [(d, d.doit()) for d in derivs if not hint.endswith('_Integral') or
  72. d.has(func) or set(d.variables) & fvars]
  73. eq = expr.subs(reps)
  74. return eq, func
  75. def ode_order(expr, func):
  76. """
  77. Returns the order of a given differential
  78. equation with respect to func.
  79. This function is implemented recursively.
  80. Examples
  81. ========
  82. >>> from sympy import Function
  83. >>> from sympy.solvers.deutils import ode_order
  84. >>> from sympy.abc import x
  85. >>> f, g = map(Function, ['f', 'g'])
  86. >>> ode_order(f(x).diff(x, 2) + f(x).diff(x)**2 +
  87. ... f(x).diff(x), f(x))
  88. 2
  89. >>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), f(x))
  90. 2
  91. >>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), g(x))
  92. 3
  93. """
  94. a = Wild('a', exclude=[func])
  95. if expr.match(a):
  96. return 0
  97. if isinstance(expr, Derivative):
  98. if expr.args[0] == func:
  99. return len(expr.variables)
  100. else:
  101. order = 0
  102. for arg in expr.args[0].args:
  103. order = max(order, ode_order(arg, func) + len(expr.variables))
  104. return order
  105. else:
  106. order = 0
  107. for arg in expr.args:
  108. order = max(order, ode_order(arg, func))
  109. return order
  110. def _desolve(eq, func=None, hint="default", ics=None, simplify=True, *, prep=True, **kwargs):
  111. """This is a helper function to dsolve and pdsolve in the ode
  112. and pde modules.
  113. If the hint provided to the function is "default", then a dict with
  114. the following keys are returned
  115. 'func' - It provides the function for which the differential equation
  116. has to be solved. This is useful when the expression has
  117. more than one function in it.
  118. 'default' - The default key as returned by classifier functions in ode
  119. and pde.py
  120. 'hint' - The hint given by the user for which the differential equation
  121. is to be solved. If the hint given by the user is 'default',
  122. then the value of 'hint' and 'default' is the same.
  123. 'order' - The order of the function as returned by ode_order
  124. 'match' - It returns the match as given by the classifier functions, for
  125. the default hint.
  126. If the hint provided to the function is not "default" and is not in
  127. ('all', 'all_Integral', 'best'), then a dict with the above mentioned keys
  128. is returned along with the keys which are returned when dict in
  129. classify_ode or classify_pde is set True
  130. If the hint given is in ('all', 'all_Integral', 'best'), then this function
  131. returns a nested dict, with the keys, being the set of classified hints
  132. returned by classifier functions, and the values being the dict of form
  133. as mentioned above.
  134. Key 'eq' is a common key to all the above mentioned hints which returns an
  135. expression if eq given by user is an Equality.
  136. See Also
  137. ========
  138. classify_ode(ode.py)
  139. classify_pde(pde.py)
  140. """
  141. if isinstance(eq, Equality):
  142. eq = eq.lhs - eq.rhs
  143. # preprocess the equation and find func if not given
  144. if prep or func is None:
  145. eq, func = _preprocess(eq, func)
  146. prep = False
  147. # type is an argument passed by the solve functions in ode and pde.py
  148. # that identifies whether the function caller is an ordinary
  149. # or partial differential equation. Accordingly corresponding
  150. # changes are made in the function.
  151. type = kwargs.get('type', None)
  152. xi = kwargs.get('xi')
  153. eta = kwargs.get('eta')
  154. x0 = kwargs.get('x0', 0)
  155. terms = kwargs.get('n')
  156. if type == 'ode':
  157. from sympy.solvers.ode import classify_ode, allhints
  158. classifier = classify_ode
  159. string = 'ODE '
  160. dummy = ''
  161. elif type == 'pde':
  162. from sympy.solvers.pde import classify_pde, allhints
  163. classifier = classify_pde
  164. string = 'PDE '
  165. dummy = 'p'
  166. # Magic that should only be used internally. Prevents classify_ode from
  167. # being called more than it needs to be by passing its results through
  168. # recursive calls.
  169. if kwargs.get('classify', True):
  170. hints = classifier(eq, func, dict=True, ics=ics, xi=xi, eta=eta,
  171. n=terms, x0=x0, hint=hint, prep=prep)
  172. else:
  173. # Here is what all this means:
  174. #
  175. # hint: The hint method given to _desolve() by the user.
  176. # hints: The dictionary of hints that match the DE, along with other
  177. # information (including the internal pass-through magic).
  178. # default: The default hint to return, the first hint from allhints
  179. # that matches the hint; obtained from classify_ode().
  180. # match: Dictionary containing the match dictionary for each hint
  181. # (the parts of the DE for solving). When going through the
  182. # hints in "all", this holds the match string for the current
  183. # hint.
  184. # order: The order of the DE, as determined by ode_order().
  185. hints = kwargs.get('hint',
  186. {'default': hint,
  187. hint: kwargs['match'],
  188. 'order': kwargs['order']})
  189. if not hints['default']:
  190. # classify_ode will set hints['default'] to None if no hints match.
  191. if hint not in allhints and hint != 'default':
  192. raise ValueError("Hint not recognized: " + hint)
  193. elif hint not in hints['ordered_hints'] and hint != 'default':
  194. raise ValueError(string + str(eq) + " does not match hint " + hint)
  195. # If dsolve can't solve the purely algebraic equation then dsolve will raise
  196. # ValueError
  197. elif hints['order'] == 0:
  198. raise ValueError(
  199. str(eq) + " is not a solvable differential equation in " + str(func))
  200. else:
  201. raise NotImplementedError(dummy + "solve" + ": Cannot solve " + str(eq))
  202. if hint == 'default':
  203. return _desolve(eq, func, ics=ics, hint=hints['default'], simplify=simplify,
  204. prep=prep, x0=x0, classify=False, order=hints['order'],
  205. match=hints[hints['default']], xi=xi, eta=eta, n=terms, type=type)
  206. elif hint in ('all', 'all_Integral', 'best'):
  207. retdict = {}
  208. gethints = set(hints) - {'order', 'default', 'ordered_hints'}
  209. if hint == 'all_Integral':
  210. for i in hints:
  211. if i.endswith('_Integral'):
  212. gethints.remove(i[:-len('_Integral')])
  213. # special cases
  214. for k in ["1st_homogeneous_coeff_best", "1st_power_series",
  215. "lie_group", "2nd_power_series_ordinary", "2nd_power_series_regular"]:
  216. if k in gethints:
  217. gethints.remove(k)
  218. for i in gethints:
  219. sol = _desolve(eq, func, ics=ics, hint=i, x0=x0, simplify=simplify, prep=prep,
  220. classify=False, n=terms, order=hints['order'], match=hints[i], type=type)
  221. retdict[i] = sol
  222. retdict['all'] = True
  223. retdict['eq'] = eq
  224. return retdict
  225. elif hint not in allhints: # and hint not in ('default', 'ordered_hints'):
  226. raise ValueError("Hint not recognized: " + hint)
  227. elif hint not in hints:
  228. raise ValueError(string + str(eq) + " does not match hint " + hint)
  229. else:
  230. # Key added to identify the hint needed to solve the equation
  231. hints['hint'] = hint
  232. hints.update({'func': func, 'eq': eq})
  233. return hints