123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499 |
r"""
This file contains helper functions for nth_linear_constant_coeff_undetermined_coefficients,
nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients,
nth_linear_constant_coeff_variation_of_parameters,
and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters.
All the functions in this file are used by more than one solver, so instead of
duplicating them in each solver class they are kept here as separate helpers.
"""
- from collections import defaultdict
- from sympy.core import Add, S
- from sympy.core.function import diff, expand, _mexpand, expand_mul
- from sympy.core.relational import Eq
- from sympy.core.sorting import default_sort_key
- from sympy.core.symbol import Dummy, Wild
- from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \
- atan2, conjugate
- from sympy.integrals import Integral
- from sympy.polys import (Poly, RootOf, rootof, roots)
- from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore
- from sympy.utilities import numbered_symbols
- from sympy.solvers.solvers import solve
- from sympy.matrices import wronskian
- from .subscheck import sub_func_doit
- from sympy.solvers.ode.ode import get_numbered_constants
def _test_term(coeff, func, order):
    r"""
    Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x),
    where K is independent of x and y(x), order >= 0.
    So we need to check that for each term, coeff == K*x**order for
    some K that does not depend on x.

    Parameters
    ==========

    coeff : Expr
        The coefficient of the ``order``-th derivative of ``func``.
    func : applied function, e.g. ``f(x)``
        The unknown function; its (single) argument is the independent
        variable ``x``.
    order : int
        Order of the derivative that ``coeff`` multiplies. Must be >= 0.

    Returns
    =======

    bool
        True if ``coeff`` is zero or equals ``K*x**order`` with ``K`` free of
        ``x`` (and therefore also free of ``func``), False otherwise.

    Raises
    ======

    ValueError
        If ``order`` is negative.
    """
    x = func.args[0]
    if order < 0:
        raise ValueError("order should be greater than or equal to 0")
    if coeff == 0:
        return True
    # Divide out the expected power of x; what remains is K, which must not
    # depend on x. Merely checking that x**order is one of the factors of a
    # Mul would wrongly accept e.g. x**2*sin(x). Because func depends on x,
    # a coefficient containing f(x) is rejected by the same test.
    return x not in (coeff*x**-order).free_symbols
def _get_euler_characteristic_eq_sols(eq, func, match_obj):
    r"""
    Returns the general solution of the homogeneous part of the linear Euler
    ODE together with a list of basis solutions of that homogeneous equation.

    The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order
    of the derivative on each term, and coeff is the coefficient of that derivative.
    Only entries with a non-negative order contribute to the characteristic
    equation.

    Returns
    =======

    (gsol, gensols)
        ``gsol`` is ``Eq(f(x), ...)`` built from numbered constants obtained
        via ``get_numbered_constants(eq, ...)``. ``gensols`` lists basis
        functions of the form ``log(x)**i*x**root`` for real roots and
        ``log(x)**i*x**re*sin(abs(im)*log(x))`` or
        ``log(x)**i*x**re*cos(im*log(x))`` for complex roots.

    Raises
    ======

    ValueError
        If a root that is only available as ``RootOf`` is repeated.
    """
    x = func.args[0]
    f = func.func
    # First, set up characteristic equation.
    chareq, symbol = S.Zero, Dummy('x')
    for i in match_obj:
        # Negative keys are skipped (presumably the nonhomogeneous term,
        # stored under -1 -- confirm against callers).
        if i >= 0:
            # diff(x**s, x, i)*x**-s == s*(s-1)*...*(s-i+1)*x**-i; the x**-i
            # is expected to cancel the x**i inside match_obj[i] for an
            # Euler equation, leaving a polynomial in ``symbol``.
            chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand()
    chareq = Poly(chareq, symbol)
    chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
    collectterms = []
    # Pool of fresh constants (a list, not a generator). Each root consumes at
    # most two of them (complex roots need one for sin and one for cos), hence
    # degree*2. Reversed so that pop() hands them out starting from the first
    # one returned.
    constants = list(get_numbered_constants(eq, num=chareq.degree()*2))
    constants.reverse()
    # Create a dict root: multiplicity or charroots
    charroots = defaultdict(int)
    for root in chareqroots:
        charroots[root] += 1
    gsol = S.Zero
    ln = log
    for root, multiplicity in charroots.items():
        for i in range(multiplicity):
            if isinstance(root, RootOf):
                # Root could not be written explicitly; repeated RootOf roots
                # are not supported.
                gsol += (x**root) * constants.pop()
                if multiplicity != 1:
                    raise ValueError("Value should be 1")
                collectterms = [(0, root, 0)] + collectterms
            elif root.is_real:
                gsol += ln(x)**i*(x**root) * constants.pop()
                collectterms = [(i, root, 0)] + collectterms
            else:
                # Each member of a conjugate pair adds its own sin/cos terms
                # with separate constants; NOTE(review): presumably these are
                # merged later (e.g. by constant simplification) -- confirm.
                reroot = re(root)
                imroot = im(root)
                gsol += ln(x)**i * (x**reroot) * (
                    constants.pop() * sin(abs(imroot)*ln(x))
                    + constants.pop() * cos(imroot*ln(x)))
                collectterms = [(i, reroot, imroot)] + collectterms
    gsol = Eq(f(x), gsol)
    gensols = []
    # Keep track of when to use sin or cos for nonzero imroot: the first
    # occurrence of a (i, reroot, |imroot|) form contributes the sin term and
    # the second (its conjugate) contributes the cos term.
    for i, reroot, imroot in collectterms:
        if imroot == 0:
            gensols.append(ln(x)**i*x**reroot)
        else:
            sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x))
            if sin_form in gensols:
                cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x))
                gensols.append(cos_form)
            else:
                gensols.append(sin_form)
    return gsol, gensols
def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True):
    r"""
    Helper function for the method of variation of parameters and nonhomogeneous euler eq.

    See the
    :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters`
    docstring for more information on this method.

    Parameters
    ==========

    eq : Expr
        The ODE being solved; only used in error messages.
    func : applied function, e.g. ``f(x)``
        The unknown function.
    roots : list
        Linearly independent solutions of the homogeneous equation (despite
        the name these are solutions, not characteristic roots). Exactly
        ``order`` of them are required.
    homogen_sol : Eq
        The general homogeneous solution ``Eq(f(x), ...)``; its right-hand
        side is added to the particular solution.
    order : int
        Order of the ODE.
    match_obj : dict
        Maps derivative order to coefficient: ``match_obj[order]`` is the
        coefficient of the highest derivative and ``match_obj[-1]`` is the
        term that does not involve ``f``.
    simplify_flag : bool
        Whether to simplify the Wronskian and the particular solution.

    Returns
    =======

    Eq
        ``Eq(f(x), homogen_sol.rhs + particular_solution)``, where the
        particular solution may contain unevaluated ``Integral`` objects.

    Raises
    ======

    NotImplementedError
        If ``len(roots) != order`` or the Wronskian of ``roots`` is zero.
    """
    f = func.func
    x = func.args[0]
    r = match_obj
    # Check the count first: it is cheap, while computing (and simplifying)
    # the Wronskian of the wrong number of functions is wasted work and could
    # report a misleading "Wronskian == 0" reason instead.
    if len(roots) != order:
        raise NotImplementedError("Cannot find " + str(order) +
        " solutions to the homogeneous equation necessary to apply " +
        "variation of parameters to " +
        str(eq) + " (number of terms != order)")
    wr = wronskian(roots, x)
    if simplify_flag:
        wr = simplify(wr)  # We need much better simplification for
                           # some ODEs. See issue 4662, for example.
        # To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1
        wr = trigsimp(wr, deep=True, recursive=True)
    if not wr:
        # The wronskian will be 0 iff the solutions are not linearly
        # independent.
        raise NotImplementedError("Cannot find " + str(order) +
        " solutions to the homogeneous equation necessary to apply " +
        "variation of parameters to " + str(eq) + " (Wronskian == 0)")
    psol = 0
    # Cramer's-rule form of variation of parameters: the k-th solution is
    # weighted by the Wronskian of the remaining solutions, with alternating
    # sign starting at (-1)**order.
    negoneterm = S.NegativeOne**(order)
    for sol_i in roots:
        others = [sol for sol in roots if sol != sol_i]
        psol += negoneterm*Integral(wronskian(others, x)*r[-1]/wr, x)*sol_i/r[order]
        negoneterm *= -1
    if simplify_flag:
        psol = simplify(psol)
        psol = trigsimp(psol, deep=True)
    return Eq(f(x), homogen_sol.rhs + psol)
def _get_const_characteristic_eq_sols(r, func, order):
    r"""
    Returns basis solutions of the homogeneous part of a constant coefficient
    linear ODE, together with a list of collectterms which is later used by
    simplification to ``collect`` the solution.

    The parameter `r` is a dict of order:coeff terms, where order is the order of the
    derivative on each term, and coeff is the coefficient of that derivative.
    String keys and negative orders are ignored.

    Returns
    =======

    (gensols, collectterms)
        ``gensols`` lists basis functions ``x**i*exp(root*x)`` (or, for real
        coefficients, ``x**i*exp(re*x)*sin(abs(im)*x)`` and
        ``x**i*exp(re*x)*cos(im*x)`` for each complex-conjugate pair).
        ``collectterms`` is a list of ``(i, reroot, imroot)`` tuples.
    """
    x = func.args[0]
    # First, set up characteristic equation.
    chareq, symbol = S.Zero, Dummy('x')
    for i, coeff in r.items():
        # Only derivative orders contribute; skip string keys and negative
        # orders.
        if isinstance(i, str) or i < 0:
            continue
        chareq += coeff*symbol**i
    chareq = Poly(chareq, symbol)
    # Can't just call roots because it doesn't return rootof for unsolvable
    # polynomials.
    chareqroots = roots(chareq, multiple=True)
    if len(chareqroots) != order:
        chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
    # With any non-real (or not-known-real) coefficient, roots do not come in
    # conjugate pairs, so the plain exponential form is used for every root.
    chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs())
    # Create a dict root: multiplicity or charroots
    charroots = defaultdict(int)
    for root in chareqroots:
        charroots[root] += 1
    # We need to keep track of terms so we can run collect() at the end.
    # This is necessary for constantsimp to work properly.
    collectterms = []
    gensols = []
    conjugate_roots = []  # used to prevent double-use of conjugate roots
    # Loop over roots in the order provided by roots/rootof...
    for root in chareqroots:
        # ...but don't repeat multiple roots.
        if root not in charroots:
            continue
        multiplicity = charroots.pop(root)
        for i in range(multiplicity):
            if chareq_is_complex:
                gensols.append(x**i*exp(root*x))
                collectterms.insert(0, (i, root, 0))
                continue
            reroot = re(root)
            imroot = im(root)
            if imroot.has(atan2) and reroot.has(atan2):
                # Remove this condition when re and im stop returning
                # circular atan2 usages.
                gensols.append(x**i*exp(root*x))
                collectterms.insert(0, (i, root, 0))
            else:
                if root in conjugate_roots:
                    # Its conjugate already contributed the sin/cos pair.
                    collectterms.insert(0, (i, reroot, imroot))
                    continue
                if imroot == 0:
                    gensols.append(x**i*exp(reroot*x))
                    collectterms.insert(0, (i, reroot, 0))
                    continue
                conjugate_roots.append(conjugate(root))
                gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x))
                gensols.append(x**i*exp(reroot*x) * cos(imroot * x))
                # This ordering is important
                collectterms.insert(0, (i, reroot, imroot))
    return gensols, collectterms
# Ideally these kinds of simplification functions shouldn't be part of the solvers;
# odesimp should be improved to handle these kinds of specific simplifications.
def _get_simplified_sol(sol, func, collectterms):
    r"""
    Helper function which collects the solution on collectterms. Ideally
    this should be handled by odesimp. It is used only when simplify is set
    to True in dsolve.

    The parameter ``collectterms`` is a list of tuples (i, reroot, imroot) where `i` is
    the power of `x` attached to the term (0 up to the root's multiplicity
    minus one), reroot is the real part and imroot the imaginary part of the
    root. The list passed in by the caller is not modified.

    ``sol`` must be a one-element sequence holding ``Eq(f(x), rhs)``; the
    collected and ``powsimp``-ed ``Eq(f(x), ...)`` is returned.
    """
    f = func.func
    x = func.args[0]
    # Work on a copy in descending default_sort_key order instead of sorting
    # (and reversing) the caller's list in place.
    collectterms = sorted(collectterms, key=default_sort_key)[::-1]
    assert len(sol) == 1 and sol[0].lhs == f(x)
    rhs = expand_mul(sol[0].rhs)
    # Collect the trigonometric factors first, then the bare exponentials.
    for i, reroot, imroot in collectterms:
        rhs = collect(rhs, x**i*exp(reroot*x)*sin(abs(imroot)*x))
        rhs = collect(rhs, x**i*exp(reroot*x)*cos(imroot*x))
    for i, reroot, imroot in collectterms:
        rhs = collect(rhs, x**i*exp(reroot*x))
    rhs = powsimp(rhs)
    return Eq(f(x), rhs)
def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero):
    r"""
    Returns a dict describing whether undetermined coefficients can be
    applied to ``expr``. Its ``'test'`` entry is True or False; when it is
    True, the ``'trialset'`` entry holds the set of trial functions.

    A trial expression can be found for an expression for use with the method
    of undetermined coefficients if the expression is an
    additive/multiplicative combination of constants, polynomials in `x` (the
    independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and
    `e^{a x}` terms (in other words, it has a finite number of linearly
    independent derivatives). ``sinh``/``cosh`` of a linear argument and a
    numeric base raised to a linear power are accepted as well.

    If ``eq_homogeneous`` (an expression in ``func`` that vanishes for
    solutions of the homogeneous equation) is given, the trial terms
    generated from each summand of ``expr`` are multiplied by `x` until none
    of them solves the homogeneous equation. Otherwise you may still need to
    multiply each term returned here by sufficient `x` to make it linearly
    independent with the solutions to the homogeneous equation.

    This is intended for internal use by ``undetermined_coefficients`` hints.

    SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of
    only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So,
    for example, you will need to manually convert `\sin^2(x)` into `[1 +
    \cos(2 x)]/2` to properly apply the method of undetermined coefficients on
    it.

    Examples
    ========

    >>> from sympy import log, exp
    >>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match
    >>> from sympy.abc import x
    >>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x)
    {'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}}
    >>> _undetermined_coefficients_match(log(x), x)
    {'test': False}

    """
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    expr = powsimp(expr, combine='exp')  # exp(x)*exp(2*x + 1) => exp(3*x + 1)
    retdict = {}

    def _test_term(expr, x):
        r"""
        Test if ``expr`` fits the proper form for undetermined coefficients.
        """
        # Note: this nested helper shadows the module-level Euler _test_term.
        if not expr.has(x):
            return True
        elif expr.is_Add:
            return all(_test_term(i, x) for i in expr.args)
        elif expr.is_Mul:
            if expr.has(sin, cos):
                foundtrig = False
                # Make sure that there is only one trig function in the args.
                # See the docstring.
                for i in expr.args:
                    if i.has(sin, cos):
                        if foundtrig:
                            return False
                        else:
                            foundtrig = True
            return all(_test_term(i, x) for i in expr.args)
        elif expr.is_Function:
            if expr.func in (sin, cos, exp, sinh, cosh):
                if expr.args[0].match(a*x + b):
                    return True
                else:
                    return False
            else:
                return False
        elif expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \
                expr.exp >= 0:
            return True
        elif expr.is_Pow and expr.base.is_number:
            if expr.exp.match(a*x + b):
                return True
            else:
                return False
        elif expr.is_Symbol or expr.is_number:
            return True
        else:
            return False

    def _get_trial_set(expr, x, exprs=None):
        r"""
        Returns a set of trial terms for undetermined coefficients.

        The idea behind undetermined coefficients is that the terms of the
        expression repeat themselves after a finite number of derivatives,
        except for the coefficients (they are linearly dependent). So if we
        collect these, we should have the terms of our trial function.

        ``exprs`` is the set of terms collected so far; the Add branch below
        extends it in place. A new empty set is used when it is omitted.
        """
        # A ``set()`` default would be shared by every top-level call made
        # during one call of the enclosing function. Since the Add branch
        # mutates it, trial terms of one summand of ``expr`` would leak into
        # the next summand's trial set (and could then be needlessly
        # multiplied by x).
        if exprs is None:
            exprs = set()

        def _remove_coefficient(expr, x):
            r"""
            Returns the expression without a coefficient.

            Similar to expr.as_independent(x)[1], except it only works
            multiplicatively.
            """
            term = S.One
            if expr.is_Mul:
                for i in expr.args:
                    if i.has(x):
                        term *= i
            elif expr.has(x):
                term = expr
            return term

        expr = expand_mul(expr)
        if expr.is_Add:
            for term in expr.args:
                stripped = _remove_coefficient(term, x)
                if stripped not in exprs:
                    exprs.add(stripped)
                    exprs = exprs.union(_get_trial_set(term, x, exprs))
        else:
            term = _remove_coefficient(expr, x)
            tmpset = exprs.union({term})
            oldset = set()
            while tmpset != oldset:
                # If you get stuck in this loop, then _test_term is probably
                # broken
                oldset = tmpset.copy()
                expr = expr.diff(x)
                term = _remove_coefficient(expr, x)
                if term.is_Add:
                    tmpset = tmpset.union(_get_trial_set(term, x, tmpset))
                else:
                    tmpset.add(term)
            exprs = tmpset
        return exprs

    def is_homogeneous_solution(term):
        r"""Return True if ``term`` solves the homogeneous equation
        ``eq_homogeneous`` (written in terms of ``func``)."""
        return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero

    retdict['test'] = _test_term(expr, x)
    if retdict['test']:
        # Try to generate a list of trial solutions that will have the
        # undetermined coefficients. When eq_homogeneous is given, the trial
        # terms of each summand are multiplied by x until none of them solves
        # the homogeneous equation; otherwise no such adjustment is made here.
        temp_set = set()
        for i in Add.make_args(expr):
            act = _get_trial_set(i, x)
            if eq_homogeneous is not S.Zero:
                while any(is_homogeneous_solution(ts) for ts in act):
                    act = {x*ts for ts in act}
            temp_set = temp_set.union(act)

        retdict['trialset'] = temp_set
    return retdict
def _solve_undetermined_coefficients(eq, func, order, match, trialset):
    r"""
    Helper function for the method of undetermined coefficients.

    See the
    :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients`
    docstring for more information on this method.

    The parameter ``trialset`` is the set of trial functions as returned by
    ``_undetermined_coefficients_match()['trialset']``.

    The parameter ``match`` should be a dictionary that has the following
    keys:

    ``list``
      A list of solutions to the homogeneous equation.

    ``sol``
      The general solution.

    Returns ``Eq(f(x), match['sol'].rhs + particular_solution)``. Trial
    coefficients that the linear system leaves undetermined are set to 0,
    which still yields a valid particular solution.

    Raises ``NotImplementedError`` if ``match['list']`` does not contain
    ``order`` solutions or if the coefficients cannot be solved for.
    """
    r = match
    coeffs = numbered_symbols('a', cls=Dummy)
    coefflist = []
    gensols = r['list']
    gsol = r['sol']
    f = func.func
    x = func.args[0]

    if len(gensols) != order:
        raise NotImplementedError("Cannot find " + str(order) +
        " solutions to the homogeneous equation necessary to apply" +
        " undetermined coefficients to " + str(eq) +
        " (number of terms != order)")

    # Build the trial function sum(a_k * trial_k) with fresh Dummy unknowns.
    trialfunc = S.Zero
    for term in trialset:
        c = next(coeffs)
        coefflist.append(c)
        trialfunc += c*term

    eqs = sub_func_doit(eq, f(x), trialfunc)

    # One bucket per trial function so every trial function yields an
    # equation (just 0 if it never appears after substitution).
    coeffsdict = dict.fromkeys(trialset, 0)

    eqs = _mexpand(eqs)

    # Group the expanded terms by their x-dependent factor and sum the
    # x-independent coefficients; each group must vanish.
    for term in Add.make_args(eqs):
        s = separatevars(term, dict=True, symbols=[x])
        coeffsdict[s[x]] = coeffsdict.get(s[x], 0) + s['coeff']

    coeffvals = solve(list(coeffsdict.values()), coefflist)

    if not coeffvals:
        raise NotImplementedError(
            "Could not solve `%s` using the "
            "method of undetermined coefficients "
            "(unable to solve for coefficients)." % eq)

    psol = trialfunc.subs(coeffvals)
    # Any coefficient solve() left free may take an arbitrary value; choose 0
    # so that no internal Dummy symbol leaks into the returned solution.
    psol = psol.subs({c: 0 for c in coefflist})

    return Eq(f(x), gsol.rhs + psol)
|