123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- from sympy.core import S, Pow
- from sympy.core.function import (Derivative, AppliedUndef, diff)
- from sympy.core.relational import Equality, Eq
- from sympy.core.symbol import Dummy
- from sympy.core.sympify import sympify
- from sympy.logic.boolalg import BooleanAtom
- from sympy.functions import exp
- from sympy.series import Order
- from sympy.simplify.simplify import simplify, posify, besselsimp
- from sympy.simplify.trigsimp import trigsimp
- from sympy.simplify.sqrtdenest import sqrtdenest
- from sympy.solvers import solve
- from sympy.solvers.deutils import _preprocess, ode_order
- from sympy.utilities.iterables import iterable, is_sequence
def sub_func_doit(eq, func, new):
    r"""
    Replace ``func`` by ``new`` in ``eq`` and evaluate the affected
    derivatives.

    Every ``Derivative`` atom of ``eq`` receives its own replacement: a
    derivative whose differentiated expression is exactly ``func`` becomes the
    matching derivative of ``new``; any other derivative has ``func`` swapped
    for ``new`` inside it and is then evaluated one level deep
    (``doit(deep=False)``).  All replacements are applied to ``eq`` in a
    single ``xreplace`` pass.

    Examples
    ========

    >>> from sympy import Derivative, symbols, Function
    >>> from sympy.solvers.ode.subscheck import sub_func_doit
    >>> x, z = symbols('x, z')
    >>> y = Function('y')

    >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
    2

    >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
    ... 1/(x*(z + 1/x)))
    x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
    ...- 1/(x**2*(z + 1/x)**2)
    """
    def _evaluated(deriv):
        # A derivative taken directly of ``func`` is recomputed from ``new``
        # with the same variables and counts.
        if deriv.expr == func:
            return new.diff(*deriv.variable_count)
        # Otherwise substitute inside the derivative and evaluate only the
        # outermost level.
        return deriv.xreplace({func: new}).doit(deep=False)

    substitutions = {func: new}
    # Entries produced for derivative atoms intentionally take precedence over
    # the plain ``func -> new`` entry if the keys coincide.
    substitutions.update(
        (deriv, _evaluated(deriv)) for deriv in eq.atoms(Derivative))
    return eq.xreplace(substitutions)
def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
    r"""
    Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.

    This works when ``func`` is one function, like `f(x)` or a list of
    functions like `[f(x), g(x)]` when `ode` is a system of ODEs.  ``sol`` can
    be a single solution or a list of solutions.  Each solution may be an
    :py:class:`~sympy.core.relational.Equality` that the solution satisfies,
    e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
    :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
    will not be necessary to explicitly identify the function, but if the
    function cannot be inferred from the original equation it can be supplied
    through the ``func`` argument.

    If a sequence of solutions is passed, the same sort of container will be
    used to return the result for each solution.

    It tries the following methods, in order, until it finds zero equivalence:

    1. Substitute the solution for `f` in the original equation.  This only
       works if ``ode`` is solved for `f`.  It will attempt to solve it first
       unless ``solve_for_func == False``.
    2. Take `n` derivatives of the solution, where `n` is the order of
       ``ode``, and check to see if that is equal to the solution.  This only
       works on exact ODEs.
    3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
       solving for the derivative of `f` of that order (this will always be
       possible because `f` is a linear operator). Then back substitute each
       derivative into ``ode`` in reverse order.

    This function returns a tuple.  The first item in the tuple is ``True`` if
    the substitution results in ``0``, and ``False`` otherwise. The second
    item in the tuple is what the substitution results in.  It should always
    be ``0`` if the first item is ``True``. Sometimes this function will
    return ``False`` even when an expression is identically equal to ``0``.
    This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
    reduce the expression to ``0``.  If an expression returned by this
    function vanishes identically, then ``sol`` really is a solution to
    the ``ode``.

    If this function seems to hang, it is probably because of a hard
    simplification.

    To use this function to test, test the first item of the tuple.

    Parameters
    ==========

    ode : expression, Equality or iterable
        The ODE to check against.  A non-``Equality`` is treated as
        ``Eq(ode, 0)``; an iterable is handed to ``checksysodesol``.
    sol : expression, Equality or sequence of those
        The candidate solution(s).  A non-``Equality`` ``sol`` is treated as
        ``Eq(func, sol)``.
    func : applied undefined function of one variable, optional
        The unknown function.  When omitted it is inferred from ``ode`` and,
        failing that, from the single undefined function found in ``sol``.
        The resolved function is forwarded to every recursive call, so it
        does not have to be inferred again for each individual solution.
    order : int or ``'auto'``
        Order of the ODE; ``'auto'`` computes it with ``ode_order``.
    solve_for_func : bool
        Whether to try solving ``sol`` for ``func`` before testing.

    Raises
    ======

    ValueError
        If ``func`` cannot be inferred, or is not an applied undefined
        function of exactly one variable.
    NotImplementedError
        If none of the tests was able to produce a result.

    Examples
    ========

    >>> from sympy import (Eq, Function, checkodesol, symbols,
    ...     Derivative, exp)
    >>> x, C1, C2 = symbols('x,C1,C2')
    >>> f, g = symbols('f g', cls=Function)
    >>> checkodesol(f(x).diff(x), Eq(f(x), C1))
    (True, 0)
    >>> assert checkodesol(f(x).diff(x), C1)[0]
    >>> assert not checkodesol(f(x).diff(x), x)[0]
    >>> checkodesol(f(x).diff(x, 2), x**2)
    (False, 2)

    >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
    >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
    >>> checkodesol(eqs, sol)
    (True, [0, 0])

    """
    # Systems of ODEs are handled by the dedicated checker.
    if iterable(ode):
        return checksysodesol(ode, sol, func=func)

    if not isinstance(ode, Equality):
        ode = Eq(ode, 0)
    if func is None:
        try:
            _, func = _preprocess(ode.lhs)
        except ValueError:
            # Fall back to the solution(s): usable only if exactly one
            # undefined function appears in them.
            funcs = [s.atoms(AppliedUndef) for s in (
                sol if is_sequence(sol, set) else [sol])]
            funcs = set().union(*funcs)
            if len(funcs) != 1:
                raise ValueError(
                    'must pass func arg to checkodesol for this case.')
            func = funcs.pop()
    if not isinstance(func, AppliedUndef) or len(func.args) != 1:
        raise ValueError(
            "func must be a function of one variable, not %s" % func)
    if is_sequence(sol, set):
        # Check each solution separately and return the results in the same
        # kind of container.  ``func`` is forwarded so that a function which
        # was supplied by the caller (or inferred from all solutions together)
        # is not lost for the individual checks.
        return type(sol)([
            checkodesol(ode, i, func=func, order=order,
                        solve_for_func=solve_for_func)
            for i in sol])

    if not isinstance(sol, Equality):
        sol = Eq(func, sol)
    elif sol.rhs == func:
        # Normalise to the form ``func == ...``.
        sol = sol.reversed

    if order == 'auto':
        order = ode_order(ode, func)
    solved = sol.lhs == func and not sol.rhs.has(func)
    if solve_for_func and not solved:
        rhs = solve(sol, func)
        if rhs:
            eqs = [Eq(func, t) for t in rhs]
            if len(rhs) == 1:
                eqs = eqs[0]
            # Re-run on the explicit solution(s), keeping the same ``func``.
            return checkodesol(ode, eqs, func=func, order=order,
                               solve_for_func=False)

    x = func.args[0]

    # Handle series solutions here
    if sol.has(Order):
        assert sol.lhs == func
        Oterm = sol.rhs.getO()
        solrhs = sol.rhs.removeO()

        Oexpr = Oterm.expr
        assert isinstance(Oexpr, Pow)
        sorder = Oexpr.exp
        assert Oterm == Order(x**sorder)

        odesubs = (ode.lhs - ode.rhs).subs(func, solrhs).doit().expand()

        # Adding an order term lowered by the ODE order absorbs the terms
        # that the truncated series cannot be expected to cancel.
        neworder = Order(x**(sorder - order))
        odesubs = odesubs + neworder
        assert odesubs.getO() == neworder
        residual = odesubs.removeO()

        return (residual == 0, residual)

    # ``s`` holds the latest residual; the loop stops as soon as a test
    # yields a falsy (zero) residual or all tests have been tried.
    s = True
    testnum = 0
    while s:
        if testnum == 0:
            # First pass, try substituting a solved solution directly into the
            # ODE. This has the highest chance of succeeding.
            ode_diff = ode.lhs - ode.rhs

            if sol.lhs == func:
                s = sub_func_doit(ode_diff, func, sol.rhs)
                s = besselsimp(s)
            else:
                testnum += 1
                continue
            ss = simplify(s.rewrite(exp))
            if ss:
                # with the new numer_denom in power.py, if we do a simple
                # expansion then testnum == 0 verifies all solutions.
                s = ss.expand(force=True)
            else:
                s = 0
            testnum += 1
        elif testnum == 1:
            # Second pass. If we cannot substitute f, try seeing if the nth
            # derivative is equal, this will only work for odes that are exact,
            # by definition.
            s = simplify(
                trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
                trigsimp(ode.lhs) + trigsimp(ode.rhs))
            testnum += 1
        elif testnum == 2:
            # Third pass. Try solving for df/dx and substituting that into the
            # ODE. Thanks to Chris Smith for suggesting this method.  Many of
            # the comments below are his, too.
            # The method:
            # - Take each of 1..n derivatives of the solution.
            # - Solve each nth derivative for d^(n)f/dx^(n)
            #   (the differential of that order)
            # - Back substitute into the ODE in decreasing order
            #   (i.e., n, n-1, ...)
            # - Check the result for zero equivalence
            if sol.lhs == func and not sol.rhs.has(func):
                diffsols = {0: sol.rhs}
            elif sol.rhs == func and not sol.lhs.has(func):
                diffsols = {0: sol.lhs}
            else:
                diffsols = {}
            # From here on ``sol`` is the expression ``lhs - rhs`` rather than
            # an Equality.
            sol = sol.lhs - sol.rhs
            for i in range(1, order + 1):
                # Differentiation is a linear operator, so there should always
                # be 1 solution. Nonetheless, we test just to make sure.
                # We only need to solve once.  After that, we automatically
                # have the solution to the differential in the order we want.
                if i == 1:
                    ds = sol.diff(x)
                    try:
                        sdf = solve(ds, func.diff(x, i))
                        if not sdf:
                            raise NotImplementedError
                    except NotImplementedError:
                        # Signal failure of this pass to the code below.
                        testnum += 1
                        break
                    else:
                        diffsols[i] = sdf[0]
                else:
                    # This is what the solution says df/dx should be.
                    diffsols[i] = diffsols[i - 1].diff(x)

            # Make sure the above didn't fail.
            if testnum > 2:
                continue
            else:
                # Substitute it into ODE to check for self consistency.
                lhs, rhs = ode.lhs, ode.rhs
                for i in range(order, -1, -1):
                    if i == 0 and 0 not in diffsols:
                        # We can only substitute f(x) if the solution was
                        # solved for f(x).
                        break
                    lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
                    rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
                    ode_or_bool = Eq(lhs, rhs)
                    ode_or_bool = simplify(ode_or_bool)

                    if isinstance(ode_or_bool, (bool, BooleanAtom)):
                        if ode_or_bool:
                            lhs = rhs = S.Zero
                    else:
                        lhs = ode_or_bool.lhs
                        rhs = ode_or_bool.rhs
                # No sense in overworking simplify -- just prove that the
                # numerator goes to zero
                num = trigsimp((lhs - rhs).as_numer_denom()[0])
                # since solutions are obtained using force=True we test
                # using the same level of assumptions
                ## replace function with dummy so assumptions will work
                _func = Dummy('func')
                num = num.subs(func, _func)
                ## posify the expression
                num, reps = posify(num)
                s = simplify(num).xreplace(reps).xreplace({_func: func})
                testnum += 1
        else:
            break

    if not s:
        return (True, s)
    elif s is True:  # The code above never was able to change s
        raise NotImplementedError("Unable to test if " + str(sol) +
            " is a solution to " + str(ode) + ".")
    else:
        return (False, s)
def checksysodesol(eqs, sols, func=None):
    r"""
    Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
    checks that the result of substitutions for each equation is ``0``. The
    equations and solutions passed can be any iterable.

    This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`.
    For each function, ``sols`` can have a single solution or a list of solutions.
    In most cases it will not be necessary to explicitly identify the function,
    but if the function cannot be inferred from the original equation it
    can be supplied through the ``func`` argument.

    When a sequence of equations is passed, the same sequence is used to return
    the result for each equation with each function substituted with corresponding
    solutions.

    It tries the following method to find zero equivalence for each equation:

    Substitute the solutions for functions, like `x(t)` and `y(t)` into the
    original equations containing those functions.
    This function returns a tuple.  The first item in the tuple is ``True`` if
    the substitution results for each equation is ``0``, and ``False`` otherwise.
    The second item in the tuple is what the substitution results in.  Each element
    of the ``list`` should always be ``0`` corresponding to each equation if the
    first item is ``True``. Note that sometimes this function may return ``False``,
    but with an expression that is identically equal to ``0``, instead of returning
    ``True``.  This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot
    reduce the expression to ``0``.  If an expression returned by each function
    vanishes identically, then ``sols`` really is a solution to ``eqs``.

    If this function seems to hang, it is probably because of a difficult simplification.

    Parameters
    ==========

    eqs : expression, Equality or iterable of those
        The system of ODEs.  Each ``Equality`` is reduced to ``lhs - rhs``.
    sols : Equality or iterable of Equality
        The candidate solutions, each containing exactly one undefined
        function.  If several solutions are given for the same function, the
        last one is the one substituted.
    func : applied undefined function or iterable of them, optional
        The unknown functions.  When omitted, they are collected from the
        derivatives that appear in ``eqs``.

    Raises
    ======

    ValueError
        If a function is not an applied undefined function of one variable,
        if the functions do not share the same independent variable, if a
        solution does not contain exactly one function, or if the number of
        distinct solution left-hand sides differs from the number of
        functions.
    NotImplementedError
        If a solution that is not already solved for its function cannot be
        solved uniquely for it.

    Examples
    ========

    >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
    >>> from sympy.solvers.ode.subscheck import checksysodesol
    >>> C1, C2 = symbols('C1:3')
    >>> t = symbols('t')
    >>> x, y = symbols('x, y', cls=Function)
    >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
    >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
    ... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
    >>> checksysodesol(eq, sol)
    (True, [0, 0])
    >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
    >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
    ... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
    >>> checksysodesol(eq, sol)
    (True, [0, 0])

    """
    def _sympify(eq):
        return list(map(sympify, eq if iterable(eq) else [eq]))

    # Work with plain expressions: ``lhs - rhs`` for every Equality.
    eqs = [eq.lhs - eq.rhs if isinstance(eq, Equality) else eq
           for eq in _sympify(eqs)]

    # ``sols`` is traversed several times below, so materialise it once; this
    # also allows a single Equality or a one-shot iterable to be passed.
    sols = list(sols) if iterable(sols) else [sols]

    if func is None:
        # Infer the unknown functions from the derivatives in the system.
        found = set()
        for eq in eqs:
            for deriv in eq.atoms(Derivative):
                found |= deriv.atoms(AppliedUndef)
        funcs = list(found)
    else:
        # Explicitly supplied function(s); previously ``funcs`` was left
        # undefined on this path.  Duplicates are dropped, order is kept.
        funcs = list(dict.fromkeys(func if iterable(func) else [func]))

    # Each function must be an applied undefined function of one variable...
    if not all(isinstance(f, AppliedUndef) and len(f.args) == 1 for f in funcs):
        raise ValueError(
            "func must be a function of one variable, not %s" % (funcs,))
    # ...and all of them must depend on the same independent variable.
    if len({f.args for f in funcs}) > 1:
        raise ValueError(
            "all functions must depend on the same independent variable, "
            "got %s" % (funcs,))

    for sol in sols:
        if len(sol.atoms(AppliedUndef)) != 1:
            raise ValueError("solutions should have one function only")
    if len(funcs) != len({sol.lhs for sol in sols}):
        raise ValueError("number of solutions provided does not match the number of equations")

    # Map every function to the explicit expression that replaces it.
    dictsol = {}
    for sol in sols:
        fn = next(iter(sol.atoms(AppliedUndef)))
        if sol.rhs == fn:
            sol = sol.reversed
        if sol.lhs == fn and not sol.rhs.has(fn):
            value = sol.rhs
        else:
            # ``solve`` returns a list of candidates; only a unique candidate
            # can be substituted as an expression.
            candidates = solve(sol, fn)
            if len(candidates) != 1:
                raise NotImplementedError(
                    "cannot solve %s uniquely for %s" % (sol, fn))
            value = candidates[0]
        dictsol[fn] = value

    checkeq = []
    for eq in eqs:
        for fn in funcs:
            eq = sub_func_doit(eq, fn, dictsol[fn])
        ss = simplify(eq)
        if ss != 0:
            eq = ss.expand(force=True)
            if eq != 0:
                # Last resort before reporting a non-zero residual.
                eq = sqrtdenest(eq).simplify()
        else:
            eq = 0
        checkeq.append(eq)

    # Success only if there is at least one residual and all of them are zero.
    distinct = set(checkeq)
    all_zero = len(distinct) == 1 and next(iter(distinct)) == 0
    return (all_zero, checkeq)
|