subscheck.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. from sympy.core import S, Pow
  2. from sympy.core.function import (Derivative, AppliedUndef, diff)
  3. from sympy.core.relational import Equality, Eq
  4. from sympy.core.symbol import Dummy
  5. from sympy.core.sympify import sympify
  6. from sympy.logic.boolalg import BooleanAtom
  7. from sympy.functions import exp
  8. from sympy.series import Order
  9. from sympy.simplify.simplify import simplify, posify, besselsimp
  10. from sympy.simplify.trigsimp import trigsimp
  11. from sympy.simplify.sqrtdenest import sqrtdenest
  12. from sympy.solvers import solve
  13. from sympy.solvers.deutils import _preprocess, ode_order
  14. from sympy.utilities.iterables import iterable, is_sequence
  15. def sub_func_doit(eq, func, new):
  16. r"""
  17. When replacing the func with something else, we usually want the
  18. derivative evaluated, so this function helps in making that happen.
  19. Examples
  20. ========
  21. >>> from sympy import Derivative, symbols, Function
  22. >>> from sympy.solvers.ode.subscheck import sub_func_doit
  23. >>> x, z = symbols('x, z')
  24. >>> y = Function('y')
  25. >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
  26. 2
  27. >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
  28. ... 1/(x*(z + 1/x)))
  29. x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
  30. ...- 1/(x**2*(z + 1/x)**2)
  31. """
  32. reps= {func: new}
  33. for d in eq.atoms(Derivative):
  34. if d.expr == func:
  35. reps[d] = new.diff(*d.variable_count)
  36. else:
  37. reps[d] = d.xreplace({func: new}).doit(deep=False)
  38. return eq.xreplace(reps)
  39. def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
  40. r"""
  41. Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.
  42. This works when ``func`` is one function, like `f(x)` or a list of
  43. functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can
  44. be a single solution or a list of solutions. Each solution may be an
  45. :py:class:`~sympy.core.relational.Equality` that the solution satisfies,
  46. e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
  47. :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
  48. will not be necessary to explicitly identify the function, but if the
  49. function cannot be inferred from the original equation it can be supplied
  50. through the ``func`` argument.
  51. If a sequence of solutions is passed, the same sort of container will be
  52. used to return the result for each solution.
  53. It tries the following methods, in order, until it finds zero equivalence:
  54. 1. Substitute the solution for `f` in the original equation. This only
  55. works if ``ode`` is solved for `f`. It will attempt to solve it first
  56. unless ``solve_for_func == False``.
  57. 2. Take `n` derivatives of the solution, where `n` is the order of
  58. ``ode``, and check to see if that is equal to the solution. This only
  59. works on exact ODEs.
  60. 3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
  61. solving for the derivative of `f` of that order (this will always be
  62. possible because `f` is a linear operator). Then back substitute each
  63. derivative into ``ode`` in reverse order.
  64. This function returns a tuple. The first item in the tuple is ``True`` if
  65. the substitution results in ``0``, and ``False`` otherwise. The second
  66. item in the tuple is what the substitution results in. It should always
  67. be ``0`` if the first item is ``True``. Sometimes this function will
  68. return ``False`` even when an expression is identically equal to ``0``.
  69. This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
  70. reduce the expression to ``0``. If an expression returned by this
  71. function vanishes identically, then ``sol`` really is a solution to
  72. the ``ode``.
  73. If this function seems to hang, it is probably because of a hard
  74. simplification.
  75. To use this function to test, test the first item of the tuple.
  76. Examples
  77. ========
  78. >>> from sympy import (Eq, Function, checkodesol, symbols,
  79. ... Derivative, exp)
  80. >>> x, C1, C2 = symbols('x,C1,C2')
  81. >>> f, g = symbols('f g', cls=Function)
  82. >>> checkodesol(f(x).diff(x), Eq(f(x), C1))
  83. (True, 0)
  84. >>> assert checkodesol(f(x).diff(x), C1)[0]
  85. >>> assert not checkodesol(f(x).diff(x), x)[0]
  86. >>> checkodesol(f(x).diff(x, 2), x**2)
  87. (False, 2)
  88. >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
  89. >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
  90. >>> checkodesol(eqs, sol)
  91. (True, [0, 0])
  92. """
  93. if iterable(ode):
  94. return checksysodesol(ode, sol, func=func)
  95. if not isinstance(ode, Equality):
  96. ode = Eq(ode, 0)
  97. if func is None:
  98. try:
  99. _, func = _preprocess(ode.lhs)
  100. except ValueError:
  101. funcs = [s.atoms(AppliedUndef) for s in (
  102. sol if is_sequence(sol, set) else [sol])]
  103. funcs = set().union(*funcs)
  104. if len(funcs) != 1:
  105. raise ValueError(
  106. 'must pass func arg to checkodesol for this case.')
  107. func = funcs.pop()
  108. if not isinstance(func, AppliedUndef) or len(func.args) != 1:
  109. raise ValueError(
  110. "func must be a function of one variable, not %s" % func)
  111. if is_sequence(sol, set):
  112. return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol])
  113. if not isinstance(sol, Equality):
  114. sol = Eq(func, sol)
  115. elif sol.rhs == func:
  116. sol = sol.reversed
  117. if order == 'auto':
  118. order = ode_order(ode, func)
  119. solved = sol.lhs == func and not sol.rhs.has(func)
  120. if solve_for_func and not solved:
  121. rhs = solve(sol, func)
  122. if rhs:
  123. eqs = [Eq(func, t) for t in rhs]
  124. if len(rhs) == 1:
  125. eqs = eqs[0]
  126. return checkodesol(ode, eqs, order=order,
  127. solve_for_func=False)
  128. x = func.args[0]
  129. # Handle series solutions here
  130. if sol.has(Order):
  131. assert sol.lhs == func
  132. Oterm = sol.rhs.getO()
  133. solrhs = sol.rhs.removeO()
  134. Oexpr = Oterm.expr
  135. assert isinstance(Oexpr, Pow)
  136. sorder = Oexpr.exp
  137. assert Oterm == Order(x**sorder)
  138. odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand()
  139. neworder = Order(x**(sorder - order))
  140. odesubs = odesubs + neworder
  141. assert odesubs.getO() == neworder
  142. residual = odesubs.removeO()
  143. return (residual == 0, residual)
  144. s = True
  145. testnum = 0
  146. while s:
  147. if testnum == 0:
  148. # First pass, try substituting a solved solution directly into the
  149. # ODE. This has the highest chance of succeeding.
  150. ode_diff = ode.lhs - ode.rhs
  151. if sol.lhs == func:
  152. s = sub_func_doit(ode_diff, func, sol.rhs)
  153. s = besselsimp(s)
  154. else:
  155. testnum += 1
  156. continue
  157. ss = simplify(s.rewrite(exp))
  158. if ss:
  159. # with the new numer_denom in power.py, if we do a simple
  160. # expansion then testnum == 0 verifies all solutions.
  161. s = ss.expand(force=True)
  162. else:
  163. s = 0
  164. testnum += 1
  165. elif testnum == 1:
  166. # Second pass. If we cannot substitute f, try seeing if the nth
  167. # derivative is equal, this will only work for odes that are exact,
  168. # by definition.
  169. s = simplify(
  170. trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
  171. trigsimp(ode.lhs) + trigsimp(ode.rhs))
  172. # s2 = simplify(
  173. # diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \
  174. # ode.lhs + ode.rhs)
  175. testnum += 1
  176. elif testnum == 2:
  177. # Third pass. Try solving for df/dx and substituting that into the
  178. # ODE. Thanks to Chris Smith for suggesting this method. Many of
  179. # the comments below are his, too.
  180. # The method:
  181. # - Take each of 1..n derivatives of the solution.
  182. # - Solve each nth derivative for d^(n)f/dx^(n)
  183. # (the differential of that order)
  184. # - Back substitute into the ODE in decreasing order
  185. # (i.e., n, n-1, ...)
  186. # - Check the result for zero equivalence
  187. if sol.lhs == func and not sol.rhs.has(func):
  188. diffsols = {0: sol.rhs}
  189. elif sol.rhs == func and not sol.lhs.has(func):
  190. diffsols = {0: sol.lhs}
  191. else:
  192. diffsols = {}
  193. sol = sol.lhs - sol.rhs
  194. for i in range(1, order + 1):
  195. # Differentiation is a linear operator, so there should always
  196. # be 1 solution. Nonetheless, we test just to make sure.
  197. # We only need to solve once. After that, we automatically
  198. # have the solution to the differential in the order we want.
  199. if i == 1:
  200. ds = sol.diff(x)
  201. try:
  202. sdf = solve(ds, func.diff(x, i))
  203. if not sdf:
  204. raise NotImplementedError
  205. except NotImplementedError:
  206. testnum += 1
  207. break
  208. else:
  209. diffsols[i] = sdf[0]
  210. else:
  211. # This is what the solution says df/dx should be.
  212. diffsols[i] = diffsols[i - 1].diff(x)
  213. # Make sure the above didn't fail.
  214. if testnum > 2:
  215. continue
  216. else:
  217. # Substitute it into ODE to check for self consistency.
  218. lhs, rhs = ode.lhs, ode.rhs
  219. for i in range(order, -1, -1):
  220. if i == 0 and 0 not in diffsols:
  221. # We can only substitute f(x) if the solution was
  222. # solved for f(x).
  223. break
  224. lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
  225. rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
  226. ode_or_bool = Eq(lhs, rhs)
  227. ode_or_bool = simplify(ode_or_bool)
  228. if isinstance(ode_or_bool, (bool, BooleanAtom)):
  229. if ode_or_bool:
  230. lhs = rhs = S.Zero
  231. else:
  232. lhs = ode_or_bool.lhs
  233. rhs = ode_or_bool.rhs
  234. # No sense in overworking simplify -- just prove that the
  235. # numerator goes to zero
  236. num = trigsimp((lhs - rhs).as_numer_denom()[0])
  237. # since solutions are obtained using force=True we test
  238. # using the same level of assumptions
  239. ## replace function with dummy so assumptions will work
  240. _func = Dummy('func')
  241. num = num.subs(func, _func)
  242. ## posify the expression
  243. num, reps = posify(num)
  244. s = simplify(num).xreplace(reps).xreplace({_func: func})
  245. testnum += 1
  246. else:
  247. break
  248. if not s:
  249. return (True, s)
  250. elif s is True: # The code above never was able to change s
  251. raise NotImplementedError("Unable to test if " + str(sol) +
  252. " is a solution to " + str(ode) + ".")
  253. else:
  254. return (False, s)
  255. def checksysodesol(eqs, sols, func=None):
  256. r"""
  257. Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
  258. checks that the result of substitutions for each equation is ``0``. The
  259. equations and solutions passed can be any iterable.
  260. This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`.
  261. For each function, ``sols`` can have a single solution or a list of solutions.
  262. In most cases it will not be necessary to explicitly identify the function,
  263. but if the function cannot be inferred from the original equation it
  264. can be supplied through the ``func`` argument.
  265. When a sequence of equations is passed, the same sequence is used to return
  266. the result for each equation with each function substituted with corresponding
  267. solutions.
  268. It tries the following method to find zero equivalence for each equation:
  269. Substitute the solutions for functions, like `x(t)` and `y(t)` into the
  270. original equations containing those functions.
  271. This function returns a tuple. The first item in the tuple is ``True`` if
  272. the substitution results for each equation is ``0``, and ``False`` otherwise.
  273. The second item in the tuple is what the substitution results in. Each element
  274. of the ``list`` should always be ``0`` corresponding to each equation if the
  275. first item is ``True``. Note that sometimes this function may return ``False``,
  276. but with an expression that is identically equal to ``0``, instead of returning
  277. ``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot
  278. reduce the expression to ``0``. If an expression returned by each function
  279. vanishes identically, then ``sols`` really is a solution to ``eqs``.
  280. If this function seems to hang, it is probably because of a difficult simplification.
  281. Examples
  282. ========
  283. >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
  284. >>> from sympy.solvers.ode.subscheck import checksysodesol
  285. >>> C1, C2 = symbols('C1:3')
  286. >>> t = symbols('t')
  287. >>> x, y = symbols('x, y', cls=Function)
  288. >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
  289. >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
  290. ... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
  291. >>> checksysodesol(eq, sol)
  292. (True, [0, 0])
  293. >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
  294. >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
  295. ... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
  296. >>> checksysodesol(eq, sol)
  297. (True, [0, 0])
  298. """
  299. def _sympify(eq):
  300. return list(map(sympify, eq if iterable(eq) else [eq]))
  301. eqs = _sympify(eqs)
  302. for i in range(len(eqs)):
  303. if isinstance(eqs[i], Equality):
  304. eqs[i] = eqs[i].lhs - eqs[i].rhs
  305. if func is None:
  306. funcs = []
  307. for eq in eqs:
  308. derivs = eq.atoms(Derivative)
  309. func = set().union(*[d.atoms(AppliedUndef) for d in derivs])
  310. for func_ in func:
  311. funcs.append(func_)
  312. funcs = list(set(funcs))
  313. if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\
  314. and len({func.args for func in funcs})!=1:
  315. raise ValueError("func must be a function of one variable, not %s" % func)
  316. for sol in sols:
  317. if len(sol.atoms(AppliedUndef)) != 1:
  318. raise ValueError("solutions should have one function only")
  319. if len(funcs) != len({sol.lhs for sol in sols}):
  320. raise ValueError("number of solutions provided does not match the number of equations")
  321. dictsol = dict()
  322. for sol in sols:
  323. func = list(sol.atoms(AppliedUndef))[0]
  324. if sol.rhs == func:
  325. sol = sol.reversed
  326. solved = sol.lhs == func and not sol.rhs.has(func)
  327. if not solved:
  328. rhs = solve(sol, func)
  329. if not rhs:
  330. raise NotImplementedError
  331. else:
  332. rhs = sol.rhs
  333. dictsol[func] = rhs
  334. checkeq = []
  335. for eq in eqs:
  336. for func in funcs:
  337. eq = sub_func_doit(eq, func, dictsol[func])
  338. ss = simplify(eq)
  339. if ss != 0:
  340. eq = ss.expand(force=True)
  341. if eq != 0:
  342. eq = sqrtdenest(eq).simplify()
  343. else:
  344. eq = 0
  345. checkeq.append(eq)
  346. if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0:
  347. return (True, checkeq)
  348. else:
  349. return (False, checkeq)