r"""
This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`,
a function which gives all rational particular solutions to first order
Riccati ODEs. A general first order Riccati ODE is given by -

.. math:: w' = b_0(x) + b_1(x)w + b_2(x)w^2

where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x`
with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE
anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation
is a Bernoulli ODE. The algorithm presented below can find rational
solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution,
or prove that no rational solution exists for the equation.

Background
==========

A Riccati equation can be transformed to its normal form

.. math:: y' + y^2 = a(x)

using the transformation

.. math:: y = -b_2(x) w - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2}

where `a(x)` is given by

.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2

Thus, we can develop an algorithm to solve the Riccati equation
in its normal form, which in turn gives us the solution of
the original Riccati equation.
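
The transformation can be checked symbolically. The sketch below is only an illustration and not part of this module's API. It assumes the forward map used by ``riccati_normal`` and treats `b_0, b_1, b_2` and `w` as undefined SymPy functions. It confirms that `y' + y^2 - a = -b_2 (w' - b_0 - b_1 w - b_2 w^2)`, so `y` solves the normal form exactly when `w` solves the original equation. It also checks that the inverse map used by ``riccati_inverse_normal`` recovers `w`.

```python
import sympy as sp

x = sp.Symbol("x")
# Undefined functions stand in for arbitrary coefficients and for w(x)
b0, b1, b2, w = (sp.Function(name)(x) for name in ("b0", "b1", "b2", "w"))

# Forward transformation y = -b2*w - b2'/(2*b2) - b1/2
u = b2.diff(x)/(2*b2) + b1/2
y = -b2*w - u
# a(x) in the form given above
a = (sp.Rational(1, 4)*(b2.diff(x)/b2 + b1)**2
     - sp.Rational(1, 2)*(b2.diff(x)/b2 + b1).diff(x) - b0*b2)

# y' + y**2 - a should equal -b2 times the residual of the original ODE
identity = (y.diff(x) + y**2 - a) + b2*(w.diff(x) - b0 - b1*w - b2*w**2)
print(sp.expand(identity))  # 0

# The inverse transformation w = -y/b2 - b2'/(2*b2**2) - b1/(2*b2) gives back w
w_back = -y/b2 - b2.diff(x)/(2*b2**2) - b1/(2*b2)
print(sp.expand(w_back - w))  # 0
```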

Algorithm
=========

The algorithm implemented here is presented in the Ph.D. thesis
"Rational and Algebraic Solutions of First-Order Algebraic ODEs"
by N. Thieu Vo. The entire thesis can be found here -
https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf

We have only implemented the Rational Riccati solver (Algorithm 11,
Pg 78-82 in Thesis). Before we proceed towards the implementation
of the algorithm, a few definitions to understand are -

1. Valuation of a Rational Function at `\infty`:
    The valuation of a rational function `p(x)` at `\infty` is equal
    to the degree of the denominator of `p(x)` minus the degree of
    its numerator.

    NOTE: A general definition of valuation of a rational function
    at any value of `x` can be found in Pg 63 of the thesis, but
    it is not needed for this algorithm.

2. Zeros and Poles of a Rational Function:
    Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function
    of `x`, where `S(x)` and `T(x)` have no common factor. Then -

    a. The Zeros of `a(x)` are the roots of `S(x)`.
    b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty`
       can also be a pole of `a(x)`. We say that `a(x)` has a pole at
       `\infty` if `a(\frac{1}{x})` has a pole at 0.

Every pole is associated with an order that is equal to the multiplicity
of its appearance as a root of `T(x)`. A pole is called a simple pole if
it has order 1. Similarly, a pole is called a multiple pole if it has
order `\ge 2`.
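
As an illustration, the valuation at `\infty` and the finite poles with their orders can be read off a cancelled numerator and denominator. The sketch uses SymPy's public ``cancel``, ``degree`` and ``roots``. The helper names ``valuation_at_infinity`` and ``finite_poles`` are hypothetical and are separate from this module's own ``val_at_inf``.

```python
import sympy as sp

x = sp.Symbol("x")

def valuation_at_infinity(expr, x):
    # deg(T) - deg(S) once common factors are cancelled
    num, den = sp.cancel(sp.together(expr)).as_numer_denom()
    return sp.degree(den, x) - sp.degree(num, x)

def finite_poles(expr, x):
    # Roots of T(x) mapped to their multiplicities, i.e. the pole orders
    _, den = sp.cancel(sp.together(expr)).as_numer_denom()
    return sp.roots(sp.Poly(den, x))

a = (x + 1)/(x*(x - 2)**2)
poles = finite_poles(a, x)
print(valuation_at_infinity(a, x))         # 2
print(poles[0], poles[2])                  # 1 2  (simple pole at 0, double pole at 2)
print(valuation_at_infinity(x**3 + 1, x))  # -3  (pole of order 3 at infinity)
```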

Necessary Conditions
====================

For a Riccati equation in its normal form,

.. math:: y' + y^2 = a(x)

we can define

a. A pole is called a movable pole if it is a pole of `y(x)` and is not
   a pole of `a(x)`.
b. Similarly, a pole is called a non-movable pole if it is a pole of both
   `y(x)` and `a(x)`.

Then, the algorithm states that a rational solution exists only if -

a. Every pole of `a(x)` must be either a simple pole or a multiple pole
   of even order.
b. The valuation of `a(x)` at `\infty` must be even or be `\ge 2`.

This algorithm finds all possible rational solutions for the Riccati ODE.
If no rational solutions are found, it means that no rational solutions
exist.
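
Below is a minimal sketch of the two conditions applied to a few sample right-hand sides. It is written independently of this module's ``check_necessary_conds``, and the name ``has_rational_candidate`` is hypothetical.

```python
import sympy as sp

x = sp.Symbol("x")

def has_rational_candidate(a, x):
    num, den = sp.cancel(sp.together(a)).as_numer_denom()
    val_inf = sp.degree(den, x) - sp.degree(num, x)
    orders = sp.roots(sp.Poly(den, x)).values()
    # (a) every finite pole is simple or of even order
    poles_ok = all(k == 1 or k % 2 == 0 for k in orders)
    # (b) the valuation at infinity is even or at least 2
    inf_ok = val_inf % 2 == 0 or val_inf >= 2
    return bool(poles_ok and inf_ok)

print(has_rational_candidate(2/x**2, x))   # True: double pole, valuation 2
print(has_rational_candidate(1/x**3, x))   # False: pole of odd order 3
print(has_rational_candidate(x + 1/x, x))  # False: valuation -1 at infinity
```

Passing this test only means a rational solution may exist. The full search below is still needed to find one or to rule it out.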

The algorithm works for Riccati ODEs where the coefficients are rational
functions in the independent variable `x` with rational number coefficients,
i.e. in `Q(x)`. The coefficients in the rational function cannot be floats,
irrational numbers, symbols or any other kind of expression. The reasons
for this are -

1. When using symbols, different symbols could take the same value, and this
   would change the multiplicity of the poles.
2. An integer degree bound is required to calculate a polynomial solution
   to an auxiliary differential equation, which in turn gives the particular
   solution for the original ODE. If symbols/floats/irrational numbers are
   present, we cannot determine if the expression for the degree bound is an
   integer or not.
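
Reason 1 can be seen with a symbolic parameter `k`, an illustrative name that this module does not use. SymPy's square-free factorization ``sqf_list`` reports only simple roots of the denominator for generic `k`. Once `k = 1`, it reports a double root.

```python
import sympy as sp

x, k = sp.symbols("x k")
den = (x - 1)*(x - k)

def max_multiplicity(poly_expr):
    # sqf_list returns (content, [(factor, multiplicity), ...])
    _, factors = sp.sqf_list(poly_expr, x)
    return max(mult for _, mult in factors)

print(max_multiplicity(den))             # 1: two simple poles for generic k
print(max_multiplicity(den.subs(k, 1)))  # 2: a double pole when k = 1
```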

Solution
========

With these definitions, we can state a general form for the solution of
the equation. `y(x)` must have the form -

.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i

where `x_1, x_2, \dots, x_n` are the non-movable poles (poles of both
`y(x)` and `a(x)`), `\chi_1, \chi_2, \dots, \chi_m` are the movable poles
of `y(x)`, and the values of `N, n, r_1, r_2, \dots, r_n` can be determined
from `a(x)`. The coefficient vectors `(d_0, d_1, \dots, d_N)` and
`(c_{i1}, c_{i2}, \dots, c_{i r_i})` can also be determined from `a(x)`.
There are at most 2 choices for each of these vectors, and part of the
procedure is figuring out which of them should be used to get the
solution correctly.
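
The form can be checked on small examples. In the sketch below, the helper ``residual`` is hypothetical. `a(x) = 2/x^2` has a non-movable double pole at 0. Substituting `y = c/x` gives `c^2 - c = 2`, so `c_{11} = 2` or `c_{11} = -1`. `a(x) = x^2 + 3` is solved by `y = x + 1/x`. Here `d_0 = 0` and `d_1 = 1`, and the pole at 0 is movable because this `a(x)` has no finite poles.

```python
import sympy as sp

x = sp.Symbol("x")

def residual(y, a):
    # y' + y**2 - a, reduced to a single cancelled fraction
    return sp.cancel(y.diff(x) + y**2 - a)

# Non-movable pole at 0: y = c/x with c**2 - c = 2
print(residual(2/x, 2/x**2), residual(-1/x, 2/x**2))  # 0 0
# Movable pole at 0 (chi_1 = 0) plus the polynomial part d_1*x with d_1 = 1
print(residual(x + 1/x, x**2 + 3))                    # 0
```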

Implementation
==============

In this implementation, we use ``Poly`` to represent a rational function
rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot
represent rational functions directly using ``Poly``, we instead represent
a rational function with 2 ``Poly`` objects - one for its numerator and
the other for its denominator.
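
For example, the pair for `a(x) = (x^2 - 1)/(x - 1)^2` can be built and then reduced with ``Poly.cancel(..., include=True)``. The module uses this same call to keep the numerator and denominator coprime. The sketch below only illustrates that pattern.

```python
import sympy as sp

x = sp.Symbol("x")
expr = (x**2 - 1)/(x - 1)**2

# Split into numerator and denominator Polys
num, den = [sp.Poly(e, x) for e in sp.together(expr).as_numer_denom()]
# Remove the common factor (x - 1); include=True folds the constant
# factors back into the returned numerator and denominator
num, den = num.cancel(den, include=True)
print(num.as_expr(), "|", den.as_expr())  # x + 1 | x - 1
# Valuation at infinity from the pair, computed the same way as val_at_inf
print(den.degree() - num.degree())        # 0
```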

The code is written to match the steps given in the thesis (Pg 82).

Step 0 : Match the equation -
    Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise
    an error.

Step 1 : Transform the equation to its normal form as explained in the
    theory section.

Step 2 : Initialize an empty set of solutions, ``sol``.

Step 3 : If `a(x) = 0`, append `\frac{1}{x - C_1}` to ``sol``.

Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}`
    to ``sol``.

Step 5 : Find the poles of `a(x)` and their multiplicities. Let
    the number of poles be `n`. Also find the valuation of `a(x)` at
    `\infty` using ``val_at_inf``.

    NOTE: Although the algorithm considers `\infty` as a pole, the thesis
    does not say whether it is part of the set of finite poles. Here,
    `\infty` is NOT a part of the set of finite poles. If a pole exists at
    `\infty`, we use its multiplicity to find the Laurent series of
    `a(x)` about `\infty`.
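
The Laurent series of `a(x)` about `\infty` is the series of `a(1/t)` about `t = 0`, where `t^k` stands for `x^{-k}`. The sketch below checks this for `a(x) = x^2/(x - 1)` using SymPy's ``series``. This module instead uses its own ``rational_laurent_series``. This `a(x)` has valuation `-1`, i.e. a simple pole at `\infty`, and its expansion is `x + 1 + 1/x + 1/x^2 + \dots`.

```python
import sympy as sp

x, t = sp.symbols("x t")
a = x**2/(x - 1)

# Expand a(1/t) about t = 0; the power t**k corresponds to x**(-k)
s = sp.series(a.subs(x, 1/t), t, 0, 3).removeO()
coeffs = {-k: s.coeff(t, k) for k in range(-1, 3)}  # keyed by the power of x
print(coeffs)
```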

Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using
    ``construct_c`` and ``construct_d``. Now, determine all the (at most
    ``2**(n + 1)``) combinations obtained by picking one of the (at most
    2) choices for each of the `n` c-vectors and for the d-vector.

    NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printing
    mistake. The term `- d_N` must be replaced with `-N d_N`. The same
    has been explained in the code as well.
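
To see why the correction matters, take `a(x) = x^4 + 4x`, which is solved by `y = x^2 + 1/x`. Since this `a(x)` is a polynomial, its expansion about `\infty` is just its coefficients: `a_4 = 1` and `a_1 = 4`, so `N = 2`. The sketch below re-derives the d-vector with a standalone version of the Case 4 recurrence. The name ``d_vector_case4`` is hypothetical; it mirrors ``construct_d_case_4`` but is not that function. With the corrected term, `d_{-1} = 1`, which matches the `1/x` coefficient of `y` at `\infty`. The printed equation would give `d_{-1} = 3/2` instead. With no finite poles, `m = d_{-1}`, so `3/2` could never yield an integer degree bound.

```python
import sympy as sp

x = sp.Symbol("x")

def d_vector_case4(ser, N, corrected=True):
    # ser[k] is the coefficient of x**k in the expansion of a(x) about infinity
    d = {N: sp.sqrt(ser.get(2*N, 0))}
    for s in range(N - 1, -2, -1):
        sm = sum((d[j]*d[N + s - j] for j in range(s + 1, N)), sp.S.Zero)
        if s >= 0:
            d[s] = (ser.get(N + s, 0) - sm)/(2*d[N])
        else:
            # Corrected equation uses -N*d_N; the printed one uses -d_N
            dN_term = N*d[N] if corrected else d[N]
            d[-1] = (ser.get(N - 1, 0) - dN_term - sm)/(2*d[N])
    return d

a = x**4 + 4*x
ser = {monom[0]: coeff for monom, coeff in sp.Poly(a, x).terms()}
N = int(sp.degree(a, x)) // 2
print(d_vector_case4(ser, N))                       # {2: 1, 1: 0, 0: 0, -1: 1}
print(d_vector_case4(ser, N, corrected=False)[-1])  # 3/2
```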

Step 7 : For each of these combinations, perform Steps 8 to 12.

Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of
    the polynomial solution we must find for the auxiliary equation. It
    is given by `m = d_{-1} - \sum_{i=1}^{n} c_{i1}`.

Step 9 : In ``compute_m_ybar``, also compute ``ybar``, which is
    one part of `y(x)` -

    .. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i
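
For `a(x) = 2/x^2`, these steps are small enough to follow by hand. The pole at 0 has order 2, which is Case 1 with `r = 2`. The valuation at `\infty` is 2, which is Case 6 with `s_\infty = \lim_{x\to\infty} x^2 a(x) = 2`. So both `c_{11}` and `d_{-1}` take the values `(1 \pm 3)/2`. The sketch below uses plain SymPy rather than the module's helpers. It enumerates the four combinations and computes `m = d_{-1} - c_{11}` and `\overline{y} = c_{11}/x`. The pole has order 2, so `r_1 = 1`, and `N = -1`, so there is no polynomial part.

```python
import itertools
import sympy as sp

x = sp.Symbol("x")
# For a(x) = 2/x**2: r is the coefficient of 1/x**2 at the pole 0 (Case 1)
# and s_inf = lim_{x -> oo} x**2*a(x) (Case 6); both equal 2 here
r = s_inf = sp.Integer(2)

c_choices = [(1 + sp.sqrt(1 + 4*r))/2, (1 - sp.sqrt(1 + 4*r))/2]
d_choices = [(1 + sp.sqrt(1 + 4*s_inf))/2, (1 - sp.sqrt(1 + 4*s_inf))/2]

for c11, d_minus1 in itertools.product(c_choices, d_choices):
    m = d_minus1 - c11   # degree bound, as in compute_m_ybar
    ybar = c11/x         # a single pole of order 2 and no d_i x**i terms
    usable = bool(m.is_integer and m.is_nonnegative)
    print(f"c11={c11}, d_-1={d_minus1}, m={m}, ybar={ybar}, usable={usable}")
```

Three of the four combinations pass Step 10: `m = 0` with `\overline{y} = 2/x`, `m = 0` with `\overline{y} = -1/x`, and `m = 3` with `\overline{y} = -1/x`. The combination `c_{11} = 2, d_{-1} = -1` gives `m = -3` and is skipped.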

Step 10 : Check whether `m` is a non-negative integer. If it is not, the
    auxiliary equation cannot have a polynomial solution of degree `m`,
    and this combination is skipped.

Step 11 : Find a monic polynomial `p(x)` of degree `m` that solves the
    auxiliary equation, which is obtained by substituting
    `y = \overline{y} + \frac{p'}{p}` into `y' + y^2 = a(x)` -

    .. math:: p'' + 2\overline{y} p' + \left(\overline{y}' + \overline{y}^2 - a\right) p = 0

    The coefficients of `p(x)` are found using Undetermined Coefficients.
    When `m = 0`, this reduces to checking whether `p(x) = 1` is a solution.
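
Take `a(x) = 2/x^2` together with `\overline{y} = -1/x` and `m = 3`. For this `a(x)`, this is the combination `c_{11} = -1`, `d_{-1} = 2` produced by Steps 6 to 9. The helper ``solve_aux_poly`` in the sketch below is hypothetical. It works on plain expressions rather than the ``Poly`` pairs used by ``solve_aux_eq``, and it assumes `m \ge 1`. It writes `p` as a monic cubic with unknown lower coefficients and solves the resulting linear system with ``linsolve``.

```python
import sympy as sp

x = sp.Symbol("x")

def solve_aux_poly(a, ybar, x, m):
    # Monic candidate p(x) = C0 + C1*x + ... + C{m-1}*x**(m-1) + x**m  (m >= 1)
    cs = sp.symbols(f"C0:{m}")
    p = sum(c*x**i for i, c in enumerate(cs)) + x**m
    # Auxiliary equation: p'' + 2*ybar*p' + (ybar' + ybar**2 - a)*p = 0
    aux = p.diff(x, 2) + 2*ybar*p.diff(x) + (ybar.diff(x) + ybar**2 - a)*p
    numer, _ = sp.together(aux).as_numer_denom()
    # Every coefficient of the numerator (as a polynomial in x) must vanish
    eqs = sp.Poly(sp.expand(numer), x).all_coeffs()
    solutions = sp.linsolve(eqs, cs)
    if solutions == sp.S.EmptySet:
        return None
    (values,) = solutions
    return p.subs(dict(zip(cs, values)))

a, ybar = 2/x**2, -1/x
p = solve_aux_poly(a, ybar, x, 3)
y = ybar + p.diff(x)/p
print(p)                                # a monic cubic with C0 left free
print(sp.cancel(y.diff(x) + y**2 - a))  # 0
```

Here ``linsolve`` leaves `C_0` free, so `p(x) = x^3 + C_0`. Then `y = -1/x + 3x^2/(x^3 + C_0)` is a one-parameter family of rational solutions, and `C_0 = 0` recovers `y = 2/x`.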

Step 12 : For each `p(x)` that exists, append `\overline{y} + \frac{p'(x)}{p(x)}`
    to ``sol``.

Step 13 : For each solution in ``sol``, apply the inverse transformation

    .. math:: w = -\frac{y}{b_2} - \frac{b'_2}{2 b_2^2} - \frac{b_1}{2 b_2}

    so that the solutions of the original equation are found using the
    solutions of the equation in its normal form.
"""

from itertools import product
from sympy.core import S
from sympy.core.add import Add
from sympy.core.numbers import oo, Float
from sympy.core.function import count_ops
from sympy.core.relational import Eq
from sympy.core.symbol import symbols, Symbol, Dummy
from sympy.functions import sqrt, exp
from sympy.functions.elementary.complexes import sign
from sympy.integrals.integrals import Integral
from sympy.polys.domains import ZZ
from sympy.polys.polytools import Poly
from sympy.polys.polyroots import roots
from sympy.solvers.solveset import linsolve


def riccati_normal(w, x, b1, b2):
    """
    Given a solution `w(x)` to the equation

    .. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2

    and rational function coefficients `b_1(x)` and
    `b_2(x)`, this function transforms the solution to
    give a solution `y(x)` for its corresponding normal
    Riccati ODE

    .. math:: y'(x) + y(x)^2 = a(x)

    using the transformation

    .. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2
    """
    return -b2*w - b2.diff(x)/(2*b2) - b1/2


def riccati_inverse_normal(y, x, b1, b2, bp=None):
    """
    Inverse transform a solution of the normal Riccati
    ODE to get a solution of the original Riccati ODE.
    """
    # bp is the expression which is independent of the solution
    # and hence, it need not be computed again
    if bp is None:
        bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
    # w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x))
    return -y/b2 + bp


def riccati_reduced(eq, f, x):
    """
    Convert a Riccati ODE into its corresponding
    normal Riccati ODE.
    """
    match, funcs = match_riccati(eq, f, x)
    # If equation is not a Riccati ODE, exit
    if not match:
        return False
    # Using the rational functions, find the expression for a(x)
    b0, b1, b2 = funcs
    a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
        b2.diff(x, 2)/(2*b2)
    # Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x)
    return f(x).diff(x) + f(x)**2 - a


def linsolve_dict(eq, syms):
    """
    Get the output of linsolve as a dict
    """
    # Convert tuple type return value of linsolve
    # to a dictionary for ease of use
    sol = linsolve(eq, syms)
    if not sol:
        return {}
    return {k: v for k, v in zip(syms, list(sol)[0])}


def match_riccati(eq, f, x):
    """
    A function that matches and returns the coefficients
    if an equation is a Riccati ODE

    Parameters
    ==========

    eq: Equation to be matched
    f: Dependent variable
    x: Independent variable

    Returns
    =======

    match: True if equation is a Riccati ODE, False otherwise
    funcs: [b0, b1, b2] if match is True, [] otherwise. Here,
    b0, b1 and b2 are rational functions which match the equation.
    """
    # Group terms based on f(x)
    if isinstance(eq, Eq):
        eq = eq.lhs - eq.rhs
    eq = eq.expand().collect(f(x))
    cf = eq.coeff(f(x).diff(x))

    # There must be an f(x).diff(x) term.
    # eq must be an Add object since we are using the expanded
    # equation and it must have at least 2 terms (b2 != 0)
    if cf != 0 and isinstance(eq, Add):

        # Divide all coefficients by the coefficient of f(x).diff(x)
        # and add the terms again to get the same equation
        eq = Add(*((term/cf).cancel() for term in eq.args)).collect(f(x))

        # Match the equation with the pattern
        b1 = -eq.coeff(f(x))
        b2 = -eq.coeff(f(x)**2)
        b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand()
        funcs = [b0, b1, b2]

        # Reject coefficients that contain symbols other than x, or floats
        if any(len(coeff.atoms(Symbol)) > 1 or len(coeff.atoms(Float)) for coeff in funcs):
            return False, []

        # If b_0(x) contains f(x), it is not a Riccati ODE
        if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x),
            b1.is_rational_function(x), b2.is_rational_function(x))):
            return False, []
        return True, funcs

    return False, []


def val_at_inf(num, den, x):
    # Valuation of a rational function at oo = deg(denom) - deg(numer)
    return den.degree(x) - num.degree(x)


def check_necessary_conds(val_inf, muls):
    """
    The necessary conditions for a rational solution
    to exist are as follows -

    i) Every pole of a(x) must be either a simple pole
    or a multiple pole of even order.

    ii) The valuation of a(x) at infinity must be even
    or be greater than or equal to 2.

    Here, a simple pole is a pole with multiplicity 1
    and a multiple pole is a pole with multiplicity
    greater than 1.
    """
    return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \
        all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls)


def inverse_transform_poly(num, den, x):
    """
    A function to make the substitution
    x -> 1/x in a rational function that
    is represented using Poly objects for
    numerator and denominator.
    """
    # Declare for reuse
    one = Poly(1, x)
    xpoly = Poly(x, x)

    # Compare the degrees of the numerator and denominator
    pwr = val_at_inf(num, den, x)
    if pwr >= 0:
        # Denominator has degree at least that of the numerator.
        # Substituting x with 1/x would make the extra power go
        # to the numerator
        if num.expr != 0:
            num = num.transform(one, xpoly) * x**pwr
            den = den.transform(one, xpoly)
    else:
        # Numerator has greater degree. Substituting x with
        # 1/x would make the extra power go to the denominator
        num = num.transform(one, xpoly)
        den = den.transform(one, xpoly) * x**(-pwr)
    return num.cancel(den, include=True)


def limit_at_inf(num, den, x):
    """
    Find the limit of a rational function
    at oo
    """
    # pwr = degree(num) - degree(den)
    pwr = -val_at_inf(num, den, x)
    # Numerator has a greater degree than denominator
    # Limit at infinity would depend on the sign of the
    # leading coefficients of numerator and denominator
    if pwr > 0:
        return oo*sign(num.LC()/den.LC())
    # Degree of numerator is equal to that of denominator
    # Limit at infinity is just the ratio of leading coeffs
    elif pwr == 0:
        return num.LC()/den.LC()
    # Degree of numerator is less than that of denominator
    # Limit at infinity is just 0
    else:
        return 0


def construct_c_case_1(num, den, x, pole):
    # Find the coefficient of 1/(x - pole)**2 in the
    # Laurent series expansion of a(x) about pole.
    num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True)
    r = (num1.subs(x, pole))/(den1.subs(x, pole))

    # If multiplicity is 2, the coefficient to be added
    # in the c-vector is c = (1 +- sqrt(1 + 4*r))/2
    if r != -S(1)/4:
        return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]]
    return [[S.Half]]


def construct_c_case_2(num, den, x, pole, mul):
    # Generate the coefficients using the recurrence
    # relation mentioned in (5.14) in the thesis (Pg 80)

    # r_i = mul/2
    ri = mul//2

    # Find the Laurent series coefficients about the pole
    ser = rational_laurent_series(num, den, x, pole, mul, 6)

    # Start with an empty memo to store the coefficients
    # This is for the plus case
    cplus = [0 for i in range(ri)]

    # Base Case
    cplus[ri-1] = sqrt(ser[2*ri])

    # Iterate backwards to find all coefficients
    s = ri - 1
    sm = 0
    for s in range(ri-1, 0, -1):
        sm = 0
        for j in range(s+1, ri):
            sm += cplus[j-1]*cplus[ri+s-j-1]
        if s != 1:
            cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1])

    # Memo for the minus case
    cminus = [-c for c in cplus]

    # Find the 0th coefficient in the recurrence
    # (s == 1 here and sm holds the sum from the last iteration)
    cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1])
    cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1])

    # Add both the plus and minus cases' coefficients. A single
    # choice is still wrapped in a list, so that callers always
    # receive a list of choices.
    if cplus != cminus:
        return [cplus, cminus]
    return [cplus]


def construct_c_case_3():
    # If multiplicity is 1, the coefficient to be added
    # in the c-vector is 1 (no choice)
    return [[1]]


def construct_c(num, den, x, poles, muls):
    """
    Helper function to calculate the coefficients
    in the c-vector for each pole.
    """
    c = []
    for pole, mul in zip(poles, muls):
        c.append([])

        # Case 3
        if mul == 1:
            # Add the coefficients from Case 3
            c[-1].extend(construct_c_case_3())

        # Case 1
        elif mul == 2:
            # Add the coefficients from Case 1
            c[-1].extend(construct_c_case_1(num, den, x, pole))

        # Case 2
        else:
            # Add the coefficients from Case 2
            c[-1].extend(construct_c_case_2(num, den, x, pole, mul))

    return c


def construct_d_case_4(ser, N):
    # Initialize an empty vector. The last entry (index N+1,
    # i.e. dplus[-1]) holds d_{-1}.
    dplus = [0 for i in range(N+2)]
    # d_N = sqrt(a_{2*N})
    dplus[N] = sqrt(ser[2*N])

    # Use the recurrence relations to find
    # the value of d_s
    for s in range(N-1, -2, -1):
        sm = 0
        for j in range(s+1, N):
            sm += dplus[j]*dplus[N+s-j]
        if s != -1:
            dplus[s] = (ser[N+s] - sm)/(2*dplus[N])

    # Coefficients for the case of d_N = -sqrt(a_{2*N})
    dminus = [-d for d in dplus]

    # The third equation in Eq 5.15 of the thesis is WRONG!
    # d_N must be replaced with N*d_N in that equation.
    # (s == -1 here and sm holds the sum from the last iteration)
    dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N])
    dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N])

    # A single choice is still wrapped in a list of choices
    if dplus != dminus:
        return [dplus, dminus]
    return [dplus]


def construct_d_case_5(ser):
    # List to store coefficients for plus case
    dplus = [0, 0]

    # d_0 = sqrt(a_0)
    dplus[0] = sqrt(ser[0])

    # d_(-1) = a_(-1)/(2*d_0)
    dplus[-1] = ser[-1]/(2*dplus[0])

    # Coefficients for the minus case are just the negative
    # of the coefficients for the positive case.
    dminus = [-d for d in dplus]

    # A single choice is still wrapped in a list of choices
    if dplus != dminus:
        return [dplus, dminus]
    return [dplus]


def construct_d_case_6(num, den, x):
    # s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to
    # s_oo = lim x->oo x**2 * a(x)
    s_inf = limit_at_inf(Poly(x**2, x)*num, den, x)

    # d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2
    if s_inf != -S(1)/4:
        return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]]
    return [[S.Half]]


def construct_d(num, den, x, val_inf):
    """
    Helper function to calculate the coefficients
    in the d-vector based on the valuation of the
    function at oo.
    """
    N = -val_inf//2
    # Multiplicity of oo as a pole
    mul = -val_inf if val_inf < 0 else 0
    ser = rational_laurent_series(num, den, x, oo, mul, 1)

    # Case 4
    if val_inf < 0:
        d = construct_d_case_4(ser, N)

    # Case 5
    elif val_inf == 0:
        d = construct_d_case_5(ser)

    # Case 6
    else:
        d = construct_d_case_6(num, den, x)

    return d


def rational_laurent_series(num, den, x, r, m, n):
    r"""
    The function computes the Laurent series coefficients
    of a rational function.

    Parameters
    ==========

    num: A Poly object that is the numerator of `f(x)`.
    den: A Poly object that is the denominator of `f(x)`.
    x: The variable of expansion of the series.
    r: The point of expansion of the series.
    m: Multiplicity of r if r is a pole of `f(x)`. Should
    be zero otherwise.
    n: Order of the term up to which the series is expanded.

    Returns
    =======

    series: A dictionary that maps an integer `k` to the coefficient
    of `(x - r)^{-k}` in the expansion (or of `x^k` when `r = \infty`).

    Below is a basic outline of how the Laurent series of a
    rational function `f(x)` about `x_0` is being calculated -

    1. Substitute `x + x_0` in place of `x`. If `x_0`
    is a pole of `f(x)`, multiply the expression by `x^m`
    where `m` is the multiplicity of `x_0`. Denote the
    resulting expression as `g(x)`. We do this substitution
    so that we can now find the Laurent series of `g(x)` about
    `x = 0`.

    2. We can then assume that the Laurent series of `g(x)`
    takes the following form -

    .. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{k = 0}^{\infty} a_k x^k

    where `a_k` denotes the Laurent series coefficients.

    3. Multiply the denominator to the RHS of the equation
    and form a recurrence relation for the coefficients `a_k`.
    """
    one = Poly(1, x, extension=True)
    if r == oo:
        # Series at x = oo is equal to first transforming
        # the function from x -> 1/x and finding the
        # series at x = 0
        num, den = inverse_transform_poly(num, den, x)
        r = S(0)

    if r:
        # For an expansion about a non-zero point, a
        # transformation from x -> x + r must be made
        num = num.transform(Poly(x + r, x, extension=True), one)
        den = den.transform(Poly(x + r, x, extension=True), one)

    # Remove the pole from the denominator if the series
    # expansion is about one of the poles
    num, den = (num*x**m).cancel(den, include=True)

    # Equate coefficients for the first terms (base case)
    maxdegree = 1 + max(num.degree(), den.degree())
    syms = symbols(f'a:{maxdegree}', cls=Dummy)
    diff = num - den * Poly(syms[::-1], x)
    coeff_diffs = diff.all_coeffs()[::-1][:maxdegree]
    (coeffs, ) = linsolve(coeff_diffs, syms)

    # Use the recursion relation for the rest
    recursion = den.all_coeffs()[::-1]
    div, rec_rhs = recursion[0], recursion[1:]
    series = list(coeffs)
    while len(series) < n:
        next_coeff = Add(*(c*series[-1-k] for k, c in enumerate(rec_rhs))) / div
        series.append(-next_coeff)
    series = {m - i: val for i, val in enumerate(series)}
    return series


def compute_m_ybar(x, poles, choice, N):
    """
    Helper function to calculate -

    1. m - The degree bound for the polynomial
    solution that must be found for the auxiliary
    differential equation.

    2. ybar - Part of the solution which can be
    computed using the poles, c and d vectors.
    """
    ybar = 0
    m = Poly(choice[-1][-1], x, extension=True)

    # Calculate the first (nested) summation for ybar
    # as given in Step 9 of the Thesis (Pg 82)
    for i in range(len(poles)):
        for j in range(len(choice[i])):
            ybar += choice[i][j]/(x - poles[i])**(j+1)
        m -= Poly(choice[i][0], x, extension=True)

    # Calculate the second summation for ybar
    for i in range(N+1):
        ybar += choice[-1][i]*x**i
    return (m.expr, ybar)


def solve_aux_eq(numa, dena, numy, deny, x, m):
    """
    Helper function to find a polynomial solution
    of degree m for the auxiliary differential
    equation.
    """
    # Assume that the solution is of the type
    # p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m
    psyms = symbols(f'C0:{m}', cls=Dummy)
    K = ZZ[psyms]
    psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K)

    # Eq (5.16) in Thesis - Pg 81
    auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol
    if m >= 1:
        px = psol.diff(x)
        auxeq += px*(2*numy*deny*dena)
    if m >= 2:
        auxeq += px.diff(x)*(deny**2*dena)
    if m != 0:
        # m is a positive integer. Find the unknown coefficients
        # using undetermined coefficients
        return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True
    else:
        # m == 0. Check if 1 (x**0) is a solution to the auxiliary equation
        return S.One, auxeq, auxeq == 0


def remove_redundant_sols(sol1, sol2, x):
    """
    Helper function to remove redundant
    solutions to the differential equation.

    Returns the solution that should be removed, or None
    if the two solutions are not redundant.
    """
    # If y1 and y2 are redundant solutions, there is
    # some value of the arbitrary constant for which
    # they will be equal

    syms1 = sol1.atoms(Symbol, Dummy)
    syms2 = sol2.atoms(Symbol, Dummy)
    num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()]
    num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()]
    # Cross multiply
    e = num1*den2 - den1*num2
    # Check if there are any constants
    syms = list(e.atoms(Symbol, Dummy))
    if len(syms):
        # Find values of constants for which solutions are equal
        redn = linsolve(e.all_coeffs(), syms)
        if len(redn):
            # Keep the general solution (more constants) and
            # remove the particular one
            if len(syms1) > len(syms2):
                return sol2
            # If both have the same number of constants,
            # remove the more complex solution
            elif len(syms1) == len(syms2):
                return sol1 if count_ops(sol1) >= count_ops(sol2) else sol2
            else:
                return sol1


def get_gen_sol_from_part_sol(part_sols, a, x):
    """
    Helper function which computes the general
    solution for a Riccati ODE from its particular
    solutions.

    There are 3 cases to find the general solution
    from the particular solutions for a Riccati ODE
    depending on the number of particular solution(s)
    we have - 1, 2 or 3.

    For more information, see Section 6 of
    "Methods of Solution of the Riccati Differential Equation"
    by D. R. Haaheim and F. M. Stein
    """

    # If no particular solutions are found, a general
    # solution cannot be found
    if len(part_sols) == 0:
        return []

    # In case of a single particular solution, the general
    # solution can be found by using the substitution
    # y = y1 + 1/z and solving a Bernoulli ODE to find z.
    elif len(part_sols) == 1:
        y1 = part_sols[0]
        if a == 0:
            return y1
        # Substituting y = y1 + 1/z into y' + y**2 = a gives the
        # linear ODE z' = 2*y1*z + 1, whose general solution is
        # z = i*(Integral(1/i) + C1) with i = exp(Integral(2*y1))
        i = exp(Integral(2*y1, x))
        z = (i*(Integral(1/i, x) + Dummy('C1'))).doit()
        return y1 + 1/z

    # In case of 2 particular solutions, the general solution
    # can be found by solving a separable equation. This is
    # the most common case, i.e. most Riccati ODEs have 2
    # rational particular solutions.
    elif len(part_sols) == 2:
        y1, y2 = part_sols
        # One of them already has a constant
        if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0:
            u = exp(Integral(y2 - y1, x)).doit()
        # Introduce a constant
        else:
            C1 = Dummy('C1')
            u = C1*exp(Integral(y2 - y1, x)).doit()
        if u == 1:
            return y2
        return (y2*u - y1)/(u - 1)

    # In case of 3 particular solutions, a closed form
    # of the general solution can be obtained directly
    else:
        y1, y2, y3 = part_sols[:3]
        C1 = Dummy('C1')
        return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3)


def solve_riccati(fx, x, b0, b1, b2, gensol=False):
    """
    The main function that gives particular/general
    solutions to Riccati ODEs that have at least 1
    rational particular solution.
    """
    # Step 1 : Convert to Normal Form
    a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
        b2.diff(x, 2)/(2*b2)
    a_t = a.together()
    num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()]
    num, den = num.cancel(den, include=True)

    # Step 2
    presol = []

    # Step 3 : a(x) is 0
    if num == 0:
        presol.append(1/(x + Dummy('C1')))

    # Step 4 : a(x) is a non-zero constant
    elif x not in num.free_symbols.union(den.free_symbols):
        presol.extend([sqrt(a), -sqrt(a)])

    # Step 5 : Find poles and valuation at infinity
    poles = roots(den, x)
    poles, muls = list(poles.keys()), list(poles.values())
    val_inf = val_at_inf(num, den, x)

    if len(poles):
        # Check necessary conditions (outlined in the module docstring)
        if not check_necessary_conds(val_inf, muls):
            raise ValueError("Rational Solution doesn't exist")

        # Step 6
        # Construct c-vectors for each singular point
        c = construct_c(num, den, x, poles, muls)

        # Construct the d-vector from the valuation at infinity
        d = construct_d(num, den, x, val_inf)

        # Step 7 : Iterate over all possible combinations and return solutions
        # Each combination picks one choice from every c-vector and one
        # choice from the d-vector (itertools.product enumerates them).

        # NOTE: We could exit from the loop if we find 3 particular solutions,
        # but it is not implemented here as -
        #   a. Finding 3 particular solutions is very rare. Most of the time,
        #      only 2 particular solutions are found.
        #   b. In case we exit after finding 3 particular solutions, it might
        #      happen that 1 or 2 of them are redundant solutions. So, instead of
        #      spending some more time in computing the particular solutions,
        #      we will end up computing the general solution from a single
        #      particular solution which is usually slower than computing the
        #      general solution from 2 or 3 particular solutions.
        c.append(d)
        choices = product(*c)
        for choice in choices:
            m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2)
            numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()]
            # Step 10 : Check if a valid solution exists. If yes, also check
            # if m is a non-negative integer
            if m.is_nonnegative == True and m.is_integer == True:

                # Step 11 : Find polynomial solutions of degree m for the auxiliary equation
                psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m)

                # Step 12 : If valid polynomial solution exists, append solution.
                if exists:
                    # m == 0 case
                    if psol == 1 and coeffs == 0:
                        # p(x) = 1, so p'(x)/p(x) term need not be added
                        presol.append(ybar)
                    # m is a positive integer and there are valid coefficients
                    elif len(coeffs):
                        # Substitute the valid coefficients to get p(x)
                        psol = psol.xreplace(coeffs)
                        # y(x) = ybar(x) + p'(x)/p(x)
                        presol.append(ybar + psol.diff(x)/psol)

    # Remove redundant solutions from the list of existing solutions
    remove = set()
    for i in range(len(presol)):
        for j in range(i+1, len(presol)):
            rem = remove_redundant_sols(presol[i], presol[j], x)
            if rem is not None:
                remove.add(rem)
    sols = [sol for sol in presol if sol not in remove]

    # Step 13 : Inverse transform the solutions of the equation in normal form
    bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)

    # If general solution is required, compute it from the particular solutions
    if gensol:
        sols = [get_gen_sol_from_part_sol(sols, a, x)]

    # Inverse transform the particular solutions
    presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols]
    return presol