|
- """
- This module contains solvers for all kinds of equations:
- - algebraic or transcendental, use solve()
- - recurrence, use rsolve()
- - differential, use dsolve()
- - nonlinear (numerically), use nsolve()
- (you will need a good starting point)
- """
- from sympy.core import (S, Add, Symbol, Dummy, Expr, Mul)
- from sympy.core.assumptions import check_assumptions
- from sympy.core.exprtools import factor_terms
- from sympy.core.function import (expand_mul, expand_log, Derivative,
- AppliedUndef, UndefinedFunction, nfloat,
- Function, expand_power_exp, _mexpand, expand,
- expand_func)
- from sympy.core.logic import fuzzy_not
- from sympy.core.numbers import ilcm, Float, Rational
- from sympy.core.power import integer_log, Pow
- from sympy.core.relational import Relational, Eq, Ne
- from sympy.core.sorting import ordered, default_sort_key
- from sympy.core.sympify import sympify
- from sympy.core.traversal import preorder_traversal
- from sympy.logic.boolalg import And, Or, BooleanAtom
- from sympy.functions import (log, exp, LambertW, cos, sin, tan, acos, asin, atan,
- Abs, re, im, arg, sqrt, atan2)
- from sympy.functions.combinatorial.factorials import binomial
- from sympy.functions.elementary.trigonometric import (TrigonometricFunction,
- HyperbolicFunction)
- from sympy.functions.elementary.piecewise import piecewise_fold, Piecewise
- from sympy.ntheory.factor_ import divisors
- from sympy.simplify import (simplify, collect, powsimp, posify, # type: ignore
- powdenest, nsimplify, denom, logcombine, sqrtdenest, fraction,
- separatevars)
- from sympy.simplify.sqrtdenest import sqrt_depth
- from sympy.simplify.fu import TR1, TR2i
- from sympy.matrices.common import NonInvertibleMatrixError
- from sympy.matrices import Matrix, zeros
- from sympy.polys import roots, cancel, factor, Poly
- from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError
- from sympy.polys.solvers import sympy_eqs_to_ring, solve_lin_sys
- from sympy.utilities.lambdify import lambdify
- from sympy.utilities.misc import filldedent, debug
- from sympy.utilities.iterables import (connected_components,
- generate_bell, uniq, iterable, is_sequence, subsets)
- from sympy.utilities.decorator import conserve_mpmath_dps
- from mpmath import findroot
- from sympy.solvers.polysys import solve_poly_system
- from types import GeneratorType
- from collections import defaultdict
- from itertools import product
- import warnings
def recast_to_symbols(eqs, symbols):
    """
    Return (e, s, d) where e and s are versions of *eqs* and
    *symbols* in which any non-Symbol objects in *symbols* have
    been replaced with generic Dummy symbols and d is a dictionary
    that can be used to restore the original expressions.

    Parameters
    ==========

    eqs : iterable
        The equations. Items that have a ``subs`` method get the
        replacements applied; any other item is passed through unchanged.
    symbols : iterable
        The objects to solve for. Every occurrence of a non-Symbol entry
        is replaced by the Dummy that was created for it.

    Raises
    ======

    ValueError
        If either *eqs* or *symbols* is not iterable.

    Examples
    ========

    >>> from sympy.solvers.solvers import recast_to_symbols
    >>> from sympy import symbols, Function
    >>> x, y = symbols('x y')
    >>> fx = Function('f')(x)
    >>> eqs, syms = [fx + 1, x, y], [fx, y]
    >>> e, s, d = recast_to_symbols(eqs, syms); (e, s, d)
    ([_X0 + 1, x, y], [_X0, y], {_X0: f(x)})

    The original equations and symbols can be restored using d:

    >>> assert [i.xreplace(d) for i in e] == eqs
    >>> assert [d.get(i, i) for i in s] == syms

    """
    # ``not a and b`` binds as ``(not a) and b``; the parentheses make the
    # check reject the case where *either* argument is not iterable.
    if not (iterable(eqs) and iterable(symbols)):
        raise ValueError('Both eqs and symbols must be iterable')

    # materialize once so the same sequence is used for both passes
    symbols = list(symbols)

    # one Dummy per distinct non-Symbol object; the Dummy name carries the
    # position of the first occurrence of that object
    swap_sym = {}
    for i, s in enumerate(symbols):
        if not isinstance(s, Symbol) and s not in swap_sym:
            swap_sym[s] = Dummy('X%d' % i)

    # replace every occurrence (not only the first) of a recast object
    new_symbols = [swap_sym.get(s, s) for s in symbols]

    new_f = []
    for eq in eqs:
        isubs = getattr(eq, 'subs', None)
        if isubs is not None:
            new_f.append(isubs(swap_sym))
        else:
            # objects without ``subs`` cannot contain the targets
            new_f.append(eq)

    # invert the mapping so that it restores the original objects
    restore = {v: k for k, v in swap_sym.items()}
    return new_f, new_symbols, restore
def _ispow(e):
    """Return True if e is a Pow or is exp."""
    # anything that is not an Expr is rejected before touching ``is_Pow``
    if not isinstance(e, Expr):
        return False
    return e.is_Pow or isinstance(e, exp)
def _simple_dens(f, symbols):
    """Return the denominators of *f* (as found by ``denoms(f, symbols)``)
    with powers having a numeric exponent reduced to their base.
    """
    # A power with a nonzero exponent is zero only when its base is zero,
    # so checking the base is sufficient. To stay simple and fast this
    # reduction is applied only when the exponent is a Number.
    bases = set()
    for den in denoms(f, symbols):
        numeric_power = den.is_Pow and den.exp.is_Number
        if numeric_power and den.exp.is_zero:
            # foo**0 is never 0, so it cannot contribute a zero denominator
            continue
        bases.add(den.base if numeric_power else den)
    return bases
def denoms(eq, *symbols):
    """
    Return (recursively) set of all denominators that appear in *eq*
    that contain any symbol in *symbols*; if *symbols* are not
    provided then all denominators will be returned.

    Examples
    ========

    >>> from sympy.solvers.solvers import denoms
    >>> from sympy.abc import x, y, z

    >>> denoms(x/y)
    {y}

    >>> denoms(x/(y*z))
    {y, z}

    >>> denoms(3/x + y/z)
    {x, z}

    >>> denoms(x/2 + y/z)
    {2, z}

    If *symbols* are provided then only denominators containing
    those symbols will be returned:

    >>> denoms(1/x + 1/y + 1/z, y, z)
    {y, z}

    """
    found = set()
    for node in preorder_traversal(eq):
        # The traversal may yield containers such as Tuple or Relational;
        # their Expr parts (e.g. lhs and rhs) are visited later on.
        if not isinstance(node, Expr):
            continue
        den = denom(node)
        if den is not S.One:
            # record each factor of the denominator separately
            found.update(Mul.make_args(den))

    if not symbols:
        return found

    # a single iterable argument stands for the symbols themselves
    if len(symbols) == 1 and iterable(symbols[0]):
        symbols = symbols[0]

    def depends_on_requested(den):
        free = den.free_symbols
        return any(s in free for s in symbols)

    return set(filter(depends_on_requested, found))
def checksol(f, symbol, sol=None, **flags):
    """
    Checks whether sol is a solution of equation f == 0.

    Explanation
    ===========

    Input can be either a single symbol and corresponding value
    or a dictionary of symbols and values. When given as a dictionary
    and flag ``simplify=True``, the values in the dictionary will be
    simplified (the dictionary that was passed in is updated in place).
    *f* can be a single equation or an iterable of equations.
    A solution must satisfy all equations in *f* to be considered valid;
    if a solution does not satisfy any equation, False is returned; if one or
    more checks are inconclusive (and none are False) then None is returned.

    Examples
    ========

    >>> from sympy import checksol, symbols
    >>> x, y = symbols('x,y')
    >>> checksol(x**4 - 1, x, 1)
    True
    >>> checksol(x**4 - 1, x, 0)
    False
    >>> checksol(x**2 + y**2 - 5**2, {x: 3, y: 4})
    True

    To check if an expression is zero using ``checksol()``, pass it
    as *f* and send an empty dictionary for *symbol*:

    >>> checksol(x**2 + x - x*(x + 1), {})
    True

    None is returned if ``checksol()`` could not conclude.

    flags:
        'numerical=True (default)'
           decide numerically as soon as the substituted expression
           evaluates to a number.
        'minimal=True (default is False)'
           a very fast, minimal testing.
        'warn=True (default is False)'
           show a warning if checksol() could not conclude.
        'simplify=True (default)'
           simplify solution before substituting into function and
           simplify the function before trying specific simplifications
        'force=True (this function reads the flag with a default of True)'
           make positive all symbols without assumptions regarding sign.

    """
    from sympy.physics.units import Unit

    minimal = flags.get('minimal', False)
    numerical = flags.get('numerical', True)

    # normalize the candidate into a {symbol: value} mapping
    if sol is not None:
        sol = {symbol: sol}
    elif isinstance(symbol, dict):
        sol = symbol
    else:
        msg = 'Expecting (sym, val) or ({sym: val}, None) but got (%s, %s)'
        raise ValueError(msg % (symbol, sol))

    if iterable(f):
        if not f:
            raise ValueError('no functions to check')
        outcome = True
        for eq in f:
            verdict = checksol(eq, sol, **flags)
            if verdict is False:
                return False
            if not verdict:
                # inconclusive: keep going in case a later check is False
                outcome = None
        return outcome

    if isinstance(f, Poly):
        f = f.as_expr()
    elif isinstance(f, (Eq, Ne)):
        if f.rhs in (S.true, S.false):
            # put the boolean constant on the left
            f = f.reversed
        if isinstance(f.args[0], BooleanAtom):
            f = f.subs(sol)
            if not f.is_Boolean:
                return
        else:
            f = f.rewrite(Add, evaluate=False)

    if isinstance(f, BooleanAtom):
        return bool(f)
    if not f.is_Relational and not f:
        return True

    if sol and not f.free_symbols & set(sol):
        # if f(y) == 0, x=3 does not set f(y) to zero...nor does it not
        return None

    forbidden = {S.NaN,
                 S.ComplexInfinity,
                 S.Infinity,
                 S.NegativeInfinity}
    if any(sympify(value).atoms() & forbidden for value in sol.values()):
        return False

    previous = f
    # Four increasingly expensive strategies are tried in order; the last
    # one always returns or leaves the loop.
    for attempt in range(4):
        if attempt == 0:
            # plain substitution
            val = f.subs(sol)
            if isinstance(val, Mul):
                val = val.as_independent(Unit)[0]
            if val.atoms() & forbidden:
                return False
        elif attempt == 1:
            if not val.is_number:
                if not val.is_constant(*sol.keys(), simplify=not minimal):
                    return False
                # there are free symbols -- simple expansion might work
                _, val = val.as_content_primitive()
                val = _mexpand(val.as_numer_denom()[0], recursive=True)
        elif attempt == 2:
            if minimal:
                return
            if flags.get('simplify', True):
                for key in sol:
                    sol[key] = simplify(sol[key])
            # start over without the failed expanded form, possibly
            # with a simplified solution
            val = simplify(f.subs(sol))
            if flags.get('force', True):
                val, reps = posify(val)
                # expansion may work now, so try again and check
                expanded = _mexpand(val, recursive=True)
                if expanded.is_number:
                    # we can decide now
                    val = expanded
        else:
            # if there are no radicals and no functions then this can't be
            # zero anymore -- can it?
            visited = set()
            for node in preorder_traversal(expand_mul(val)):
                if node in visited:
                    continue
                visited.add(node)
                if (node.is_Pow and not node.exp.is_Integer
                        or node.is_Function
                        or isinstance(node, UndefinedFunction)):
                    break
            else:
                # neither a non-integer power nor a function was seen
                return False
            if flags.get('force', True):
                # don't do a zero check with the positive assumptions in place
                val = val.subs(reps)
            nz = fuzzy_not(val.is_zero)
            if nz is not None:
                # issue 5673: nz may be True even when False
                # so these are just hacks to keep a false positive
                # from being returned

                # HACK 1: LambertW (issue 5673)
                if val.is_number and val.has(LambertW):
                    # don't eval this to verify solution since if we got here,
                    # numerical must be False
                    return None

                # add other HACKs here if necessary, otherwise we assume
                # the nz value is correct
                return not nz
            break

        if val == previous:
            # this strategy changed nothing; move on to the next one
            continue
        if val.is_Rational:
            return val == 0
        if numerical and val.is_number:
            return (abs(val.n(18).n(12, chop=True)) < 1e-9) is S.true
        previous = val

    if flags.get('warn', False):
        warnings.warn("\n\tWarning: could not verify solution %s." % sol)
    # returns None if it can't conclude
    # TODO: improve solution testing
    return None
- def solve(f, *symbols, **flags):
- r"""
- Algebraically solves equations and systems of equations.
- Explanation
- ===========
- Currently supported:
- - polynomial
- - transcendental
- - piecewise combinations of the above
- - systems of linear and polynomial equations
- - systems containing relational expressions
- Examples
- ========
- The output varies according to the input and can be seen by example:
- >>> from sympy import solve, Poly, Eq, Function, exp
- >>> from sympy.abc import x, y, z, a, b
- >>> f = Function('f')
- Boolean or univariate Relational:
- >>> solve(x < 3)
- (-oo < x) & (x < 3)
- To always get a list of solution mappings, use flag dict=True:
- >>> solve(x - 3, dict=True)
- [{x: 3}]
- >>> sol = solve([x - 3, y - 1], dict=True)
- >>> sol
- [{x: 3, y: 1}]
- >>> sol[0][x]
- 3
- >>> sol[0][y]
- 1
- To get a list of *symbols* and set of solution(s) use flag set=True:
- >>> solve([x**2 - 3, y - 1], set=True)
- ([x, y], {(-sqrt(3), 1), (sqrt(3), 1)})
- Single expression and single symbol that is in the expression:
- >>> solve(x - y, x)
- [y]
- >>> solve(x - 3, x)
- [3]
- >>> solve(Eq(x, 3), x)
- [3]
- >>> solve(Poly(x - 3), x)
- [3]
- >>> solve(x**2 - y**2, x, set=True)
- ([x], {(-y,), (y,)})
- >>> solve(x**4 - 1, x, set=True)
- ([x], {(-1,), (1,), (-I,), (I,)})
- Single expression with no symbol that is in the expression:
- >>> solve(3, x)
- []
- >>> solve(x - 3, y)
- []
- Single expression with no symbol given. In this case, all free *symbols*
- will be selected as potential *symbols* to solve for. If the equation is
- univariate then a list of solutions is returned; otherwise - as is the case
- when *symbols* are given as an iterable of length greater than 1 - a list of
- mappings will be returned:
- >>> solve(x - 3)
- [3]
- >>> solve(x**2 - y**2)
- [{x: -y}, {x: y}]
- >>> solve(z**2*x**2 - z**2*y**2)
- [{x: -y}, {x: y}, {z: 0}]
- >>> solve(z**2*x - z**2*y**2)
- [{x: y**2}, {z: 0}]
- When an object other than a Symbol is given as a symbol, it is
- isolated algebraically and an implicit solution may be obtained.
- This is mostly provided as a convenience to save you from replacing
- the object with a Symbol and solving for that Symbol. It will only
- work if the specified object can be replaced with a Symbol using the
- subs method:
- >>> solve(f(x) - x, f(x))
- [x]
- >>> solve(f(x).diff(x) - f(x) - x, f(x).diff(x))
- [x + f(x)]
- >>> solve(f(x).diff(x) - f(x) - x, f(x))
- [-x + Derivative(f(x), x)]
- >>> solve(x + exp(x)**2, exp(x), set=True)
- ([exp(x)], {(-sqrt(-x),), (sqrt(-x),)})
- >>> from sympy import Indexed, IndexedBase, Tuple, sqrt
- >>> A = IndexedBase('A')
- >>> eqs = Tuple(A[1] + A[2] - 3, A[1] - A[2] + 1)
- >>> solve(eqs, eqs.atoms(Indexed))
- {A[1]: 1, A[2]: 2}
- * To solve for a symbol implicitly, use implicit=True:
- >>> solve(x + exp(x), x)
- [-LambertW(1)]
- >>> solve(x + exp(x), x, implicit=True)
- [-exp(x)]
- * It is possible to solve for anything that can be targeted with
- subs:
- >>> solve(x + 2 + sqrt(3), x + 2)
- [-sqrt(3)]
- >>> solve((x + 2 + sqrt(3), x + 4 + y), y, x + 2)
- {y: -2 + sqrt(3), x + 2: -sqrt(3)}
- * Nothing heroic is done in this implicit solving so you may end up
- with a symbol still in the solution:
- >>> eqs = (x*y + 3*y + sqrt(3), x + 4 + y)
- >>> solve(eqs, y, x + 2)
- {y: -sqrt(3)/(x + 3), x + 2: -2*x/(x + 3) - 6/(x + 3) + sqrt(3)/(x + 3)}
- >>> solve(eqs, y*x, x)
- {x: -y - 4, x*y: -3*y - sqrt(3)}
- * If you attempt to solve for a number remember that the number
- you have obtained does not necessarily mean that the value is
- equivalent to the expression obtained:
- >>> solve(sqrt(2) - 1, 1)
- [sqrt(2)]
- >>> solve(x - y + 1, 1) # /!\ -1 is targeted, too
- [x/(y - 1)]
- >>> [_.subs(z, -1) for _ in solve((x - y + 1).subs(-1, z), 1)]
- [-x + y]
- * To solve for a function within a derivative, use ``dsolve``.
- Single expression and more than one symbol:
- * When there is a linear solution:
- >>> solve(x - y**2, x, y)
- [(y**2, y)]
- >>> solve(x**2 - y, x, y)
- [(x, x**2)]
- >>> solve(x**2 - y, x, y, dict=True)
- [{y: x**2}]
- * When undetermined coefficients are identified:
- * That are linear:
- >>> solve((a + b)*x - b + 2, a, b)
- {a: -2, b: 2}
- * That are nonlinear:
- >>> solve((a + b)*x - b**2 + 2, a, b, set=True)
- ([a, b], {(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))})
- * If there is no linear solution, then the first successful
- attempt for a nonlinear solution will be returned:
- >>> solve(x**2 - y**2, x, y, dict=True)
- [{x: -y}, {x: y}]
- >>> solve(x**2 - y**2/exp(x), x, y, dict=True)
- [{x: 2*LambertW(-y/2)}, {x: 2*LambertW(y/2)}]
- >>> solve(x**2 - y**2/exp(x), y, x)
- [(-x*sqrt(exp(x)), x), (x*sqrt(exp(x)), x)]
- Iterable of one or more of the above:
- * Involving relationals or bools:
- >>> solve([x < 3, x - 2])
- Eq(x, 2)
- >>> solve([x > 3, x - 2])
- False
- * When the system is linear:
- * With a solution:
- >>> solve([x - 3], x)
- {x: 3}
- >>> solve((x + 5*y - 2, -3*x + 6*y - 15), x, y)
- {x: -3, y: 1}
- >>> solve((x + 5*y - 2, -3*x + 6*y - 15), x, y, z)
- {x: -3, y: 1}
- >>> solve((x + 5*y - 2, -3*x + 6*y - z), z, x, y)
- {x: 2 - 5*y, z: 21*y - 6}
- * Without a solution:
- >>> solve([x + 3, x - 3])
- []
- * When the system is not linear:
- >>> solve([x**2 + y -2, y**2 - 4], x, y, set=True)
- ([x, y], {(-2, -2), (0, 2), (2, -2)})
- * If no *symbols* are given, all free *symbols* will be selected and a
- list of mappings returned:
- >>> solve([x - 2, x**2 + y])
- [{x: 2, y: -4}]
- >>> solve([x - 2, x**2 + f(x)], {f(x), x})
- [{x: 2, f(x): -4}]
- * If any equation does not depend on the symbol(s) given, it will be
- eliminated from the equation set and an answer may be given
- implicitly in terms of variables that were not of interest:
- >>> solve([x - y, y - 3], x)
- {x: y}
- **Additional Examples**
- ``solve()`` with check=True (default) will run through the symbol tags to
- eliminate unwanted solutions. If no assumptions are included, all possible
- solutions will be returned:
- >>> from sympy import Symbol, solve
- >>> x = Symbol("x")
- >>> solve(x**2 - 1)
- [-1, 1]
- By using the positive tag, only one solution will be returned:
- >>> pos = Symbol("pos", positive=True)
- >>> solve(pos**2 - 1)
- [1]
- Assumptions are not checked when ``solve()`` input involves
- relationals or bools.
- When the solutions are checked, those that make any denominator zero
- are automatically excluded. If you do not want to exclude such solutions,
- then use the check=False option:
- >>> from sympy import sin, limit
- >>> solve(sin(x)/x) # 0 is excluded
- [pi]
- If check=False, then a solution to the numerator being zero is found: x = 0.
- In this case, this is a spurious solution since $\sin(x)/x$ has the well
- known limit (without discontinuity) of 1 at x = 0:
- >>> solve(sin(x)/x, check=False)
- [0, pi]
- In the following case, however, the limit exists and is equal to the
- value of x = 0 that is excluded when check=True:
- >>> eq = x**2*(1/x - z**2/x)
- >>> solve(eq, x)
- []
- >>> solve(eq, x, check=False)
- [0]
- >>> limit(eq, x, 0, '-')
- 0
- >>> limit(eq, x, 0, '+')
- 0
- **Disabling High-Order Explicit Solutions**
- When solving polynomial expressions, you might not want explicit solutions
- (which can be quite long). If the expression is univariate, ``CRootOf``
- instances will be returned instead:
- >>> solve(x**3 - x + 1)
- [-1/((-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)) -
- (-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3,
- -(-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3 -
- 1/((-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)),
- -(3*sqrt(69)/2 + 27/2)**(1/3)/3 -
- 1/(3*sqrt(69)/2 + 27/2)**(1/3)]
- >>> solve(x**3 - x + 1, cubics=False)
- [CRootOf(x**3 - x + 1, 0),
- CRootOf(x**3 - x + 1, 1),
- CRootOf(x**3 - x + 1, 2)]
- If the expression is multivariate, no solution might be returned:
- >>> solve(x**3 - x + a, x, cubics=False)
- []
- Sometimes solutions will be obtained even when a flag is False because the
- expression could be factored. In the following example, the equation can
- be factored as the product of a linear and a quadratic factor so explicit
- solutions (which did not require solving a cubic expression) are obtained:
- >>> eq = x**3 + 3*x**2 + x - 1
- >>> solve(eq, cubics=False)
- [-1, -1 + sqrt(2), -sqrt(2) - 1]
- **Solving Equations Involving Radicals**
- Because of SymPy's use of the principal root, some solutions
- to radical equations will be missed unless check=False:
- >>> from sympy import root
- >>> eq = root(x**3 - 3*x**2, 3) + 1 - x
- >>> solve(eq)
- []
- >>> solve(eq, check=False)
- [1/3]
- In the above example, there is only a single solution to the
- equation. Other expressions will yield spurious roots which
- must be checked manually; roots which give a negative argument
- to odd-powered radicals will also need special checking:
- >>> from sympy import real_root, S
- >>> eq = root(x, 3) - root(x, 5) + S(1)/7
- >>> solve(eq) # this gives 2 solutions but misses a 3rd
- [CRootOf(7*x**5 - 7*x**3 + 1, 1)**15,
- CRootOf(7*x**5 - 7*x**3 + 1, 2)**15]
- >>> sol = solve(eq, check=False)
- >>> [abs(eq.subs(x,i).n(2)) for i in sol]
- [0.48, 0.e-110, 0.e-110, 0.052, 0.052]
- The first solution is negative so ``real_root`` must be used to see that it
- satisfies the expression:
- >>> abs(real_root(eq.subs(x, sol[0])).n(2))
- 0.e-110
- If the roots of the equation are not real then more care will be
- necessary to find the roots, especially for higher order equations.
- Consider the following expression:
- >>> expr = root(x, 3) - root(x, 5)
- We will construct a known value for this expression at x = -3 by selecting
- the 1-th root for each radical:
- >>> expr1 = root(x, 3, 1) - root(x, 5, 1)
- >>> v = expr1.subs(x, -3)
- The ``solve`` function is unable to find any exact roots to this equation:
- >>> eq = Eq(expr, v); eq1 = Eq(expr1, v)
- >>> solve(eq, check=False), solve(eq1, check=False)
- ([], [])
- The function ``unrad``, however, can be used to get a form of the equation
- for which numerical roots can be found:
- >>> from sympy.solvers.solvers import unrad
- >>> from sympy import nroots
- >>> e, (p, cov) = unrad(eq)
- >>> pvals = nroots(e)
- >>> inversion = solve(cov, x)[0]
- >>> xvals = [inversion.subs(p, i) for i in pvals]
- Although ``eq`` or ``eq1`` could have been used to find ``xvals``, the
- solution can only be verified with ``expr1``:
- >>> z = expr - v
- >>> [xi.n(chop=1e-9) for xi in xvals if abs(z.subs(x, xi).n()) < 1e-9]
- []
- >>> z1 = expr1 - v
- >>> [xi.n(chop=1e-9) for xi in xvals if abs(z1.subs(x, xi).n()) < 1e-9]
- [-3.0]
- Parameters
- ==========
- f :
- - a single Expr or Poly that must be zero
- - an Equality
- - a Relational expression
- - a Boolean
- - iterable of one or more of the above
- symbols : (object(s) to solve for) specified as
- - none given (other non-numeric objects will be used)
- - single symbol
- - denested list of symbols
- (e.g., ``solve(f, x, y)``)
- - ordered iterable of symbols
- (e.g., ``solve(f, [x, y])``)
- flags :
- dict=True (default is False)
- Return list (perhaps empty) of solution mappings.
- set=True (default is False)
- Return list of symbols and set of tuple(s) of solution(s).
- exclude=[] (default)
- Do not try to solve for any of the free symbols in exclude;
- if expressions are given, the free symbols in them will
- be extracted automatically.
- check=True (default)
- If False, do not do any testing of solutions. This can be
- useful if you want to include solutions that make any
- denominator zero.
- numerical=True (default)
- Do a fast numerical check if *f* has only one symbol.
- minimal=True (default is False)
- A very fast, minimal testing.
- warn=True (default is False)
- Show a warning if ``checksol()`` could not conclude.
- simplify=True (default)
- Simplify all but polynomials of order 3 or greater before
- returning them and (if check is not False) use the
- general simplify function on the solutions and the
- expression obtained when they are substituted into the
- function which should be zero.
- force=True (default is False)
- Make positive all symbols without assumptions regarding sign.
- rational=True (default)
- Recast Floats as Rational; if this option is not used, the
- system containing Floats may fail to solve because of issues
- with polys. If rational=None, Floats will be recast as
- rationals but the answer will be recast as Floats. If the
- flag is False then nothing will be done to the Floats.
- manual=True (default is False)
- Do not use the polys/matrix method to solve a system of
- equations, solve them one at a time as you might "manually."
- implicit=True (default is False)
- Allows ``solve`` to return a solution for a pattern in terms of
- other functions that contain that pattern; this is only
- needed if the pattern is inside of some invertible function
- like cos, exp, etc.
- particular=True (default is False)
- Instructs ``solve`` to try to find a particular solution to a linear
- system with as many zeros as possible; this is very expensive.
- quick=True (default is False)
- When using particular=True, use a fast heuristic to find a
- solution with many zeros (instead of using the very slow method
- guaranteed to find the largest number of zeros possible).
- cubics=True (default)
- Return explicit solutions when cubic expressions are encountered.
- When False, quartics and quintics are disabled, too.
- quartics=True (default)
- Return explicit solutions when quartic expressions are encountered.
- When False, quintics are disabled, too.
- quintics=True (default)
- Return explicit solutions (if possible) when quintic expressions
- are encountered.
- See Also
- ========
- rsolve: For solving recurrence relationships
- dsolve: For solving differential equations
- """
- from .inequalities import reduce_inequalities
- # set solver types explicitly; as soon as one is False
- # all the rest will be False
- ###########################################################################
- hints = ('cubics', 'quartics', 'quintics')
- default = True
- for k in hints:
- default = flags.setdefault(k, bool(flags.get(k, default)))
- # keeping track of how f was passed since if it is a list
- # a dictionary of results will be returned.
- ###########################################################################
- def _sympified_list(w):
- return list(map(sympify, w if iterable(w) else [w]))
- bare_f = not iterable(f)
- ordered_symbols = (symbols and
- symbols[0] and
- (isinstance(symbols[0], Symbol) or
- is_sequence(symbols[0],
- include=GeneratorType)
- )
- )
- f, symbols = (_sympified_list(w) for w in [f, symbols])
- if isinstance(f, list):
- f = [s for s in f if s is not S.true and s is not True]
- implicit = flags.get('implicit', False)
- # preprocess symbol(s)
- ###########################################################################
- if not symbols:
- # get symbols from equations
- symbols = set().union(*[fi.free_symbols for fi in f])
- if len(symbols) < len(f):
- for fi in f:
- pot = preorder_traversal(fi)
- for p in pot:
- if isinstance(p, AppliedUndef):
- flags['dict'] = True # better show symbols
- symbols.add(p)
- pot.skip() # don't go any deeper
- symbols = list(symbols)
- ordered_symbols = False
- elif len(symbols) == 1 and iterable(symbols[0]):
- symbols = symbols[0]
- # remove symbols the user is not interested in
- exclude = flags.pop('exclude', set())
- if exclude:
- if isinstance(exclude, Expr):
- exclude = [exclude]
- exclude = set().union(*[e.free_symbols for e in sympify(exclude)])
- symbols = [s for s in symbols if s not in exclude]
- # preprocess equation(s)
- ###########################################################################
- for i, fi in enumerate(f):
- if isinstance(fi, (Eq, Ne)):
- if 'ImmutableDenseMatrix' in [type(a).__name__ for a in fi.args]:
- fi = fi.lhs - fi.rhs
- else:
- L, R = fi.args
- if isinstance(R, BooleanAtom):
- L, R = R, L
- if isinstance(L, BooleanAtom):
- if isinstance(fi, Ne):
- L = ~L
- if R.is_Relational:
- fi = ~R if L is S.false else R
- elif R.is_Symbol:
- return L
- elif R.is_Boolean and (~R).is_Symbol:
- return ~L
- else:
- raise NotImplementedError(filldedent('''
- Unanticipated argument of Eq when other arg
- is True or False.
- '''))
- else:
- fi = fi.rewrite(Add, evaluate=False)
- f[i] = fi
- if fi.is_Relational:
- return reduce_inequalities(f, symbols=symbols)
- if isinstance(fi, Poly):
- f[i] = fi.as_expr()
- # rewrite hyperbolics in terms of exp
- f[i] = f[i].replace(lambda w: isinstance(w, HyperbolicFunction) and \
- (len(w.free_symbols & set(symbols)) > 0), lambda w: w.rewrite(exp))
- # if we have a Matrix, we need to iterate over its elements again
- if f[i].is_Matrix:
- bare_f = False
- f.extend(list(f[i]))
- f[i] = S.Zero
- # if we can split it into real and imaginary parts then do so
- freei = f[i].free_symbols
- if freei and all(s.is_extended_real or s.is_imaginary for s in freei):
- fr, fi = f[i].as_real_imag()
- # accept as long as new re, im, arg or atan2 are not introduced
- had = f[i].atoms(re, im, arg, atan2)
- if fr and fi and fr != fi and not any(
- i.atoms(re, im, arg, atan2) - had for i in (fr, fi)):
- if bare_f:
- bare_f = False
- f[i: i + 1] = [fr, fi]
- # real/imag handling -----------------------------
- if any(isinstance(fi, (bool, BooleanAtom)) for fi in f):
- if flags.get('set', False):
- return [], set()
- return []
- for i, fi in enumerate(f):
- # Abs
- while True:
- was = fi
- fi = fi.replace(Abs, lambda arg:
- separatevars(Abs(arg)).rewrite(Piecewise) if arg.has(*symbols)
- else Abs(arg))
- if was == fi:
- break
- for e in fi.find(Abs):
- if e.has(*symbols):
- raise NotImplementedError('solving %s when the argument '
- 'is not real or imaginary.' % e)
- # arg
- fi = fi.replace(arg, lambda a: arg(a).rewrite(atan2).rewrite(atan))
- # save changes
- f[i] = fi
- # see if re(s) or im(s) appear
- freim = [fi for fi in f if fi.has(re, im)]
- if freim:
- irf = []
- for s in symbols:
- if s.is_real or s.is_imaginary:
- continue # neither re(x) nor im(x) will appear
- # if re(s) or im(s) appear, the auxiliary equation must be present
- if any(fi.has(re(s), im(s)) for fi in freim):
- irf.append((s, re(s) + S.ImaginaryUnit*im(s)))
- if irf:
- for s, rhs in irf:
- for i, fi in enumerate(f):
- f[i] = fi.xreplace({s: rhs})
- f.append(s - rhs)
- symbols.extend([re(s), im(s)])
- if bare_f:
- bare_f = False
- flags['dict'] = True
- # end of real/imag handling -----------------------------
- symbols = list(uniq(symbols))
- if not ordered_symbols:
- # we do this to make the results returned canonical in case f
- # contains a system of nonlinear equations; all other cases should
- # be unambiguous
- symbols = sorted(symbols, key=default_sort_key)
- # we can solve for non-symbol entities by replacing them with Dummy symbols
- f, symbols, swap_sym = recast_to_symbols(f, symbols)
- # this is needed in the next two events
- symset = set(symbols)
- # get rid of equations that have no symbols of interest; we don't
- # try to solve them because the user didn't ask and they might be
- # hard to solve; this means that solutions may be given in terms
- # of the eliminated equations e.g. solve((x-y, y-3), x) -> {x: y}
- newf = []
- for fi in f:
- # let the solver handle equations that..
- # - have no symbols but are expressions
- # - have symbols of interest
- # - have no symbols of interest but are constant
- # but when an expression is not constant and has no symbols of
- # interest, it can't change what we obtain for a solution from
- # the remaining equations so we don't include it; and if it's
- # zero it can be removed and if it's not zero, there is no
- # solution for the equation set as a whole
- #
- # The reason for doing this filtering is to allow an answer
- # to be obtained to queries like solve((x - y, y), x); without
- # this mod the return value is []
- ok = False
- if fi.free_symbols & symset:
- ok = True
- else:
- if fi.is_number:
- if fi.is_Number:
- if fi.is_zero:
- continue
- return []
- ok = True
- else:
- if fi.is_constant():
- ok = True
- if ok:
- newf.append(fi)
- if not newf:
- return []
- f = newf
- del newf
- # mask off any Object that we aren't going to invert: Derivative,
- # Integral, etc... so that solving for anything that they contain will
- # give an implicit solution
- seen = set()
- non_inverts = set()
- for fi in f:
- pot = preorder_traversal(fi)
- for p in pot:
- if not isinstance(p, Expr) or isinstance(p, Piecewise):
- pass
- elif (isinstance(p, bool) or
- not p.args or
- p in symset or
- p.is_Add or p.is_Mul or
- p.is_Pow and not implicit or
- p.is_Function and not implicit) and p.func not in (re, im):
- continue
- elif p not in seen:
- seen.add(p)
- if p.free_symbols & symset:
- non_inverts.add(p)
- else:
- continue
- pot.skip()
- del seen
- non_inverts = dict(list(zip(non_inverts, [Dummy() for _ in non_inverts])))
- f = [fi.subs(non_inverts) for fi in f]
- # Both xreplace and subs are needed below: xreplace to force substitution
- # inside Derivative, subs to handle non-straightforward substitutions
- non_inverts = [(v, k.xreplace(swap_sym).subs(swap_sym)) for k, v in non_inverts.items()]
- # rationalize Floats
- floats = False
- if flags.get('rational', True) is not False:
- for i, fi in enumerate(f):
- if fi.has(Float):
- floats = True
- f[i] = nsimplify(fi, rational=True)
- # capture any denominators before rewriting since
- # they may disappear after the rewrite, e.g. issue 14779
- flags['_denominators'] = _simple_dens(f[0], symbols)
- # Any embedded piecewise functions need to be brought out to the
- # top level so that the appropriate strategy gets selected.
- # However, this is necessary only if one of the piecewise
- # functions depends on one of the symbols we are solving for.
- def _has_piecewise(e):
- if e.is_Piecewise:
- return e.has(*symbols)
- return any(_has_piecewise(a) for a in e.args)
- for i, fi in enumerate(f):
- if _has_piecewise(fi):
- f[i] = piecewise_fold(fi)
- #
- # try to get a solution
- ###########################################################################
- if bare_f:
- solution = _solve(f[0], *symbols, **flags)
- else:
- solution = _solve_system(f, symbols, **flags)
- #
- # postprocessing
- ###########################################################################
- # Restore masked-off objects
- if non_inverts:
- def _do_dict(solution):
- return {k: v.subs(non_inverts) for k, v in
- solution.items()}
- for i in range(1):
- if isinstance(solution, dict):
- solution = _do_dict(solution)
- break
- elif solution and isinstance(solution, list):
- if isinstance(solution[0], dict):
- solution = [_do_dict(s) for s in solution]
- break
- elif isinstance(solution[0], tuple):
- solution = [tuple([v.subs(non_inverts) for v in s]) for s
- in solution]
- break
- else:
- solution = [v.subs(non_inverts) for v in solution]
- break
- elif not solution:
- break
- else:
- raise NotImplementedError(filldedent('''
- no handling of %s was implemented''' % solution))
- # Restore original "symbols" if a dictionary is returned.
- # This is not necessary for
- # - the single univariate equation case
- # since the symbol will have been removed from the solution;
- # - the nonlinear poly_system since that only supports zero-dimensional
- # systems and those results come back as a list
- #
- # ** unless there were Derivatives with the symbols, but those were handled
- # above.
- if swap_sym:
- symbols = [swap_sym.get(k, k) for k in symbols]
- if isinstance(solution, dict):
- solution = {swap_sym.get(k, k): v.subs(swap_sym)
- for k, v in solution.items()}
- elif solution and isinstance(solution, list) and isinstance(solution[0], dict):
- for i, sol in enumerate(solution):
- solution[i] = {swap_sym.get(k, k): v.subs(swap_sym)
- for k, v in sol.items()}
- # undo the dictionary solutions returned when the system was only partially
- # solved with poly-system if all symbols are present
- if (
- not flags.get('dict', False) and
- solution and
- ordered_symbols and
- not isinstance(solution, dict) and
- all(isinstance(sol, dict) for sol in solution)
- ):
- solution = [tuple([r.get(s, s) for s in symbols]) for r in solution]
- # Get assumptions about symbols, to filter solutions.
- # Note that if assumptions about a solution can't be verified, it is still
- # returned.
- check = flags.get('check', True)
- # restore floats
- if floats and solution and flags.get('rational', None) is None:
- solution = nfloat(solution, exponent=False)
- if check and solution: # assumption checking
- warn = flags.get('warn', False)
- got_None = [] # solutions for which one or more symbols gave None
- no_False = [] # solutions for which no symbols gave False
- if isinstance(solution, tuple):
- # this has already been checked and is in as_set form
- return solution
- elif isinstance(solution, list):
- if isinstance(solution[0], tuple):
- for sol in solution:
- for symb, val in zip(symbols, sol):
- test = check_assumptions(val, **symb.assumptions0)
- if test is False:
- break
- if test is None:
- got_None.append(sol)
- else:
- no_False.append(sol)
- elif isinstance(solution[0], dict):
- for sol in solution:
- a_None = False
- for symb, val in sol.items():
- test = check_assumptions(val, **symb.assumptions0)
- if test:
- continue
- if test is False:
- break
- a_None = True
- else:
- no_False.append(sol)
- if a_None:
- got_None.append(sol)
- else: # list of expressions
- for sol in solution:
- test = check_assumptions(sol, **symbols[0].assumptions0)
- if test is False:
- continue
- no_False.append(sol)
- if test is None:
- got_None.append(sol)
- elif isinstance(solution, dict):
- a_None = False
- for symb, val in solution.items():
- test = check_assumptions(val, **symb.assumptions0)
- if test:
- continue
- if test is False:
- no_False = None
- break
- a_None = True
- else:
- no_False = solution
- if a_None:
- got_None.append(solution)
- elif isinstance(solution, (Relational, And, Or)):
- if len(symbols) != 1:
- raise ValueError("Length should be 1")
- if warn and symbols[0].assumptions0:
- warnings.warn(filldedent("""
- \tWarning: assumptions about variable '%s' are
- not handled currently.""" % symbols[0]))
- # TODO: check also variable assumptions for inequalities
- else:
- raise TypeError('Unrecognized solution') # improve the checker
- solution = no_False
- if warn and got_None:
- warnings.warn(filldedent("""
- \tWarning: assumptions concerning following solution(s)
- cannot be checked:""" + '\n\t' +
- ', '.join(str(s) for s in got_None)))
- #
- # done
- ###########################################################################
- as_dict = flags.get('dict', False)
- as_set = flags.get('set', False)
- if not as_set and isinstance(solution, list):
- # Make sure that a list of solutions is ordered in a canonical way.
- solution.sort(key=default_sort_key)
- if not as_dict and not as_set:
- return solution or []
- # return a list of mappings or []
- if not solution:
- solution = []
- else:
- if isinstance(solution, dict):
- solution = [solution]
- elif iterable(solution[0]):
- solution = [dict(list(zip(symbols, s))) for s in solution]
- elif isinstance(solution[0], dict):
- pass
- else:
- if len(symbols) != 1:
- raise ValueError("Length should be 1")
- solution = [{symbols[0]: s} for s in solution]
- if as_dict:
- return solution
- assert as_set
- if not solution:
- return [], set()
- k = list(ordered(solution[0].keys()))
- return k, {tuple([s[ki] for ki in k]) for s in solution}
def _solve(f, *symbols, **flags):
    """
    Return a checked solution for *f* in terms of one or more of the
    symbols. A list should be returned except for the case when a linear
    undetermined-coefficients equation is encountered (in which case
    a dictionary is returned).

    If no method is implemented to solve the equation, a NotImplementedError
    will be raised. In the case that conversion of an expression to a Poly
    gives None a ValueError will be raised.

    Explanation
    ===========

    With more than one symbol, ``solve_undetermined_coeffs`` is tried first
    (when exactly one "extra" free symbol can be identified) and a truthy
    result from it is returned directly, simplified unless ``simplify`` is
    False. Otherwise each symbol is tried with ``solve_linear`` and, for
    those where that did not succeed, with a recursive single-symbol call;
    the return value is then a list of one-item dictionaries
    ``{symbol: solution}``.

    With a single symbol the expression is handled according to its form:

    - ``Mul``: the solutions of each factor are collected;
    - ``Piecewise``: each piece is solved and the candidate is kept
      (possibly wrapped in a ``Piecewise``) according to the condition of
      that piece;
    - anything else: ``solve_linear`` is tried, then a polynomial approach
      based on the generators reported by ``Poly``, then ``unrad`` and
      finally ``_tsolve``.

    Parameters
    ==========

    f : the expression to solve.
    symbols : one or more symbols to solve for.
    flags : keyword options. Those read directly here are ``simplify``,
        ``check``, ``_denominators``, ``tsolve``, ``_unrad``,
        ``incomplete`` and the ``roots`` hints ``cubics``, ``quartics``
        and ``quintics``. The flags are also forwarded to the recursive
        calls and to the helpers called below, and several of them are
        modified in this function's own ``flags`` dictionary as the
        computation proceeds.

    Raises
    ======

    NotImplementedError
        If none of the strategies produced a result.
    ValueError
        If ``Poly`` conversion gives None or the expression appears to be
        a constant.
    TypeError
        If ``solve_undetermined_coeffs`` returns a container that is not
        recognized.
    """
    not_impl_msg = "No algorithms are implemented to solve equation %s"

    if len(symbols) != 1:
        # ---- multiple symbols --------------------------------------------
        soln = None
        free = f.free_symbols
        # look for a single symbol that is not being solved for; if the
        # plain set difference does not give exactly one, use the symbols
        # shared by the independent and dependent parts of f
        ex = free - set(symbols)
        if len(ex) != 1:
            ind, dep = f.as_independent(*symbols)
            ex = ind.free_symbols & dep.free_symbols
        if len(ex) == 1:
            ex = ex.pop()
            try:
                # soln may come back as dict, list of dicts or tuples, or
                # tuple of symbol list and set of solution tuples
                soln = solve_undetermined_coeffs(f, symbols, ex, **flags)
            except NotImplementedError:
                pass
        if soln:
            if flags.get('simplify', True):
                # simplify the values in whichever container came back
                if isinstance(soln, dict):
                    for k in soln:
                        soln[k] = simplify(soln[k])
                elif isinstance(soln, list):
                    if isinstance(soln[0], dict):
                        for d in soln:
                            for k in d:
                                d[k] = simplify(d[k])
                    elif isinstance(soln[0], tuple):
                        soln = [tuple(simplify(i) for i in j) for j in soln]
                    else:
                        raise TypeError('unrecognized args in list')
                elif isinstance(soln, tuple):
                    sym, sols = soln
                    soln = sym, {tuple(simplify(i) for i in j) for j in sols}
                else:
                    raise TypeError('unrecognized solution type')
            return soln
        # find first successful solution
        failed = []  # symbols for which no linear solution was found
        got_s = set()  # symbols for which a solution has been recorded
        result = []
        for s in symbols:
            xi, v = solve_linear(f, symbols=[s])
            if xi == s:
                # no need to check but we should simplify if desired
                if flags.get('simplify', True):
                    v = simplify(v)
                vfree = v.free_symbols
                if got_s and any(ss in vfree for ss in got_s):
                    # sol depends on previously solved symbols: discard it
                    continue
                got_s.add(xi)
                result.append({xi: v})
            elif xi:  # there might be a non-linear solution if xi is not 0
                failed.append(s)
        if not failed:
            return result
        for s in failed:
            try:
                # single-symbol recursion for the symbols that were not
                # linear
                soln = _solve(f, s, **flags)
                for sol in soln:
                    if got_s and any(ss in sol.free_symbols for ss in got_s):
                        # sol depends on previously solved symbols: discard it
                        continue
                    got_s.add(s)
                    result.append({s: sol})
            except NotImplementedError:
                continue
        if got_s:
            return result
        else:
            raise NotImplementedError(not_impl_msg % f)

    # ---- single symbol ---------------------------------------------------
    symbol = symbols[0]

    # expand binomials only if it has the unknown symbol
    f = f.replace(lambda e: isinstance(e, binomial) and e.has(symbol),
        lambda e: expand_func(e))

    # /!\ capture this flag then set it to False so that no checking in
    # recursive calls will be done; only the final answer is checked
    # NOTE(review): as written, the popped value is stored back into
    # flags['check'] unchanged (it is not set to False here); confirm
    # which of the comment or the code reflects the intent.
    flags['check'] = checkdens = check = flags.pop('check', True)

    # build up solutions if f is a Mul
    if f.is_Mul:
        result = set()
        for m in f.args:
            if m in {S.NegativeInfinity, S.ComplexInfinity, S.Infinity}:
                # an infinite factor: discard everything found so far
                result = set()
                break
            soln = _solve(m, symbol, **flags)
            result.update(set(soln))
        result = list(result)
        if check:
            # all solutions have been checked but now we must
            # check that the solutions do not set denominators
            # in any factor to zero
            dens = flags.get('_denominators', _simple_dens(f, symbols))
            result = [s for s in result if
                not any(checksol(den, {symbol: s}, **flags) for den in
                dens)]
        # set flags for quick exit at end; solutions for each
        # factor were already checked and simplified
        check = False
        flags['simplify'] = False

    elif f.is_Piecewise:
        result = set()
        for i, (expr, cond) in enumerate(f.args):
            if expr.is_zero:
                raise NotImplementedError(
                    'solve cannot represent interval solutions')
            candidates = _solve(expr, symbol, **flags)
            # the explicit condition for this expr is the current cond
            # and none of the previous conditions
            args = [~c for _, c in f.args[:i]] + [cond]
            cond = And(*args)
            for candidate in candidates:
                if candidate in result:
                    # an unconditional value was already there
                    continue
                try:
                    v = cond.subs(symbol, candidate)
                    _eval_simplify = getattr(v, '_eval_simplify', None)
                    if _eval_simplify is not None:
                        # unconditionally take the simplification of v
                        v = _eval_simplify(ratio=2, measure=lambda x: 1)
                except TypeError:
                    # incompatible type with condition(s)
                    continue
                if v == False:
                    continue
                if v == True:
                    result.add(candidate)
                else:
                    # the condition could not be decided: return the
                    # candidate guarded by the remaining condition
                    result.add(Piecewise(
                        (candidate, v),
                        (S.NaN, True)))
        # set flags for quick exit at end; solutions for each
        # piece were already checked and simplified
        check = False
        flags['simplify'] = False
    else:
        # first see if it really depends on symbol and whether there
        # is only a linear solution
        f_num, sol = solve_linear(f, symbols=symbols)
        if f_num.is_zero or sol is S.NaN:
            return []
        elif f_num.is_Symbol:
            # no need to check but simplify if desired
            if flags.get('simplify', True):
                sol = simplify(sol)
            return [sol]

        poly = None
        # check for a single Add generator
        if not f_num.is_Add:
            add_args = [i for i in f_num.atoms(Add)
                if symbol in i.free_symbols]
            if len(add_args) == 1:
                gen = add_args[0]
                spart = gen.as_independent(symbol)[1].as_base_exp()[0]
                if spart == symbol:
                    try:
                        poly = Poly(f_num, spart)
                    except PolynomialError:
                        pass

        result = False  # no solution was obtained
        msg = ''  # there is no failure message

        # Poly is generally robust enough to convert anything to
        # a polynomial and tell us the different generators that it
        # contains, so we will inspect the generators identified by
        # polys to figure out what to do.

        # try to identify a single generator that will allow us to solve this
        # as a polynomial, followed (perhaps) by a change of variables if the
        # generator is not a symbol

        try:
            if poly is None:
                poly = Poly(f_num)
            if poly is None:
                raise ValueError('could not convert %s to Poly' % f_num)
        except GeneratorsNeeded:
            # retry once on the simplified expression if simplification
            # changed anything
            simplified_f = simplify(f_num)
            if simplified_f != f_num:
                return _solve(simplified_f, symbol, **flags)
            raise ValueError('expression appears to be a constant')

        # only the generators that involve the symbol are of interest
        gens = [g for g in poly.gens if g.has(symbol)]

        def _as_base_q(x):
            """Return (b**e, q) for x = b**(p*e/q) where p/q is the leading
            Rational of the exponent of x, e.g. exp(-2*x/3) -> (exp(x), 3)
            """
            b, e = x.as_base_exp()
            if e.is_Rational:
                return b, e.q
            if not e.is_Mul:
                return x, 1
            c, ee = e.as_coeff_Mul()
            if c.is_Rational and c is not S.One:  # c could be a Float
                return b**ee, c.q
            return x, 1

        if len(gens) > 1:
            # If there is more than one generator, it could be that the
            # generators have the same base but different powers, e.g.
            #   >>> Poly(exp(x) + 1/exp(x))
            #   Poly(exp(-x) + exp(x), exp(-x), exp(x), domain='ZZ')
            #
            # If unrad was not disabled then there should be no rational
            # exponents appearing as in
            #   >>> Poly(sqrt(x) + sqrt(sqrt(x)))
            #   Poly(sqrt(x) + x**(1/4), sqrt(x), x**(1/4), domain='ZZ')

            bases, qs = list(zip(*[_as_base_q(g) for g in gens]))
            bases = set(bases)

            if len(bases) > 1 or not all(q == 1 for q in qs):
                funcs = {b for b in bases if b.is_Function}

                trig = {_ for _ in funcs if
                    isinstance(_, TrigonometricFunction)}
                other = funcs - trig
                if not other and len(funcs.intersection(trig)) > 1:
                    # only trigonometric functions: try rewriting in
                    # terms of tan
                    newf = None
                    if f_num.is_Add and len(f_num.args) == 2:
                        # check for sin(x)**p = cos(x)**p
                        _args = f_num.args
                        t = a, b = [i.atoms(Function).intersection(
                            trig) for i in _args]
                        if all(len(i) == 1 for i in t):
                            a, b = [i.pop() for i in t]
                            if isinstance(a, cos):
                                # put the sin term first
                                a, b = b, a
                                _args = _args[::-1]
                            if isinstance(a, sin) and isinstance(b, cos
                                    ) and a.args[0] == b.args[0]:
                                # sin(x) + cos(x) = 0 -> tan(x) + 1 = 0
                                newf, _d = (TR2i(_args[0]/_args[1]) + 1
                                    ).as_numer_denom()
                                if not _d.is_Number:
                                    newf = None
                    if newf is None:
                        newf = TR1(f_num).rewrite(tan)
                    if newf != f_num:
                        # don't check the rewritten form --check
                        # solutions in the un-rewritten form below
                        flags['check'] = False
                        result = _solve(newf, symbol, **flags)
                        flags['check'] = check

                # just a simple case - see if replacement of single function
                # clears all symbol-dependent functions, e.g.
                # log(x) - log(log(x) - 1) - 3 can be solved even though it has
                # two generators.

                if result is False and funcs:
                    funcs = list(ordered(funcs))  # put shallowest function first
                    f1 = funcs[0]
                    t = Dummy('t')
                    # perform the substitution
                    ftry = f_num.subs(f1, t)

                    # if no Functions left, we can proceed with usual solve
                    if not ftry.has(symbol):
                        cv_sols = _solve(ftry, t, **flags)
                        # only the first solution of t = f1 is used to
                        # map back to the original symbol
                        cv_inv = _solve(t - f1, symbol, **flags)[0]
                        sols = list()
                        for sol in cv_sols:
                            sols.append(cv_inv.subs(t, sol))
                        result = list(ordered(sols))

                if result is False:
                    msg = 'multiple generators %s' % gens

            else:
                # e.g. case where gens are exp(x), exp(-x)
                u = bases.pop()
                t = Dummy('t')
                inv = _solve(u - t, symbol, **flags)
                if isinstance(u, (Pow, exp)):
                    # this will be resolved by factor in _tsolve but we might
                    # as well try a simple expansion here to get things in
                    # order so something like the following will work now without
                    # having to factor:
                    #
                    # >>> eq = (exp(I*(-x-2))+exp(I*(x+2)))
                    # >>> eq.subs(exp(x),y)  # fails
                    # exp(I*(-x - 2)) + exp(I*(x + 2))
                    # >>> eq.expand().subs(exp(x),y)  # works
                    # y**I*exp(2*I) + y**(-I)*exp(-2*I)
                    def _expand(p):
                        b, e = p.as_base_exp()
                        e = expand_mul(e)
                        return expand_power_exp(b**e)
                    ftry = f_num.replace(
                        lambda w: w.is_Pow or isinstance(w, exp),
                        _expand).subs(u, t)
                    if not ftry.has(symbol):
                        soln = _solve(ftry, t, **flags)
                        sols = list()
                        # every inverse is combined with every solution
                        # for t
                        for sol in soln:
                            for i in inv:
                                sols.append(i.subs(t, sol))
                        result = list(ordered(sols))

        elif len(gens) == 1:

            # There is only one generator that we are interested in, but
            # there may have been more than one generator identified by
            # polys (e.g. for symbols other than the one we are interested
            # in) so recast the poly in terms of our generator of interest.
            # Also use composite=True with f_num since Poly won't update
            # poly as documented in issue 8810.

            poly = Poly(f_num, gens[0], composite=True)

            # if we aren't on the tsolve-pass, use roots
            if not flags.pop('tsolve', False):
                soln = None
                deg = poly.degree()
                flags['tsolve'] = True
                hints = ('cubics', 'quartics', 'quintics')
                solvers = {h: flags.get(h) for h in hints}
                soln = roots(poly, **solvers)
                if sum(soln.values()) < deg:
                    # e.g. roots(32*x**5 + 400*x**4 + 2032*x**3 +
                    #            5000*x**2 + 6250*x + 3189) -> {}
                    # so all_roots is used and RootOf instances are
                    # returned *unless* the system is multivariate
                    # or high-order EX domain.
                    try:
                        soln = poly.all_roots()
                    except NotImplementedError:
                        if not flags.get('incomplete', True):
                            raise NotImplementedError(
                                filldedent('''
    Neither high-order multivariate polynomials
    nor sorting of EX-domain polynomials is supported.
    If you want to see any results, pass keyword incomplete=True to
    solve; to see numerical values of roots
    for univariate expressions, use nroots.
    '''))
                        else:
                            # keep the incomplete result from roots()
                            pass
                else:
                    soln = list(soln.keys())

                if soln is not None:
                    u = poly.gen
                    if u != symbol:
                        # the generator is not the symbol itself: invert
                        # u = t for the symbol and substitute each root
                        try:
                            t = Dummy('t')
                            iv = _solve(u - t, symbol, **flags)
                            soln = list(ordered({i.subs(t, s) for i in iv for s in soln}))
                        except NotImplementedError:
                            # perhaps _tsolve can handle f_num
                            soln = None
                    else:
                        check = False  # only dens need to be checked
                    if soln is not None:
                        if len(soln) > 2:
                            # if the flag wasn't set then unset it since high-order
                            # results are quite long. Perhaps one could base this
                            # decision on a certain critical length of the
                            # roots. In addition, wester test M2 has an expression
                            # whose roots can be shown to be real with the
                            # unsimplified form of the solution whereas only one of
                            # the simplified forms appears to be real.
                            flags['simplify'] = flags.get('simplify', False)
                        result = soln

    # fallback if above fails
    # -----------------------
    # (the Mul and Piecewise branches always leave a container in result,
    # so f_num is only referenced below when the generic branch ran)
    if result is False:
        # try unrad
        if flags.pop('_unrad', True):
            try:
                u = unrad(f_num, symbol)
            except (ValueError, NotImplementedError):
                u = False
            if u:
                eq, cov = u
                if cov:
                    # a change of variables was introduced: solve in the
                    # new variable and map back with the first inverse
                    isym, ieq = cov
                    inv = _solve(ieq, symbol, **flags)[0]
                    rv = {inv.subs(isym, xi) for xi in _solve(eq, isym, **flags)}
                else:
                    try:
                        rv = set(_solve(eq, symbol, **flags))
                    except NotImplementedError:
                        rv = None
                if rv is not None:
                    result = list(ordered(rv))
                    # if the flag wasn't set then unset it since unrad results
                    # can be quite long or of very high order
                    flags['simplify'] = flags.get('simplify', False)
            else:
                pass  # for coverage

    # try _tsolve
    if result is False:
        flags.pop('tsolve', None)  # allow tsolve to be used on next pass
        try:
            soln = _tsolve(f_num, symbol, **flags)
            if soln is not None:
                result = soln
        except PolynomialError:
            pass
    # ----------- end of fallback ----------------------------

    if result is False:
        raise NotImplementedError('\n'.join([msg, not_impl_msg % f]))

    if flags.get('simplify', True):
        result = list(map(simplify, result))
        # we just simplified the solution so we now set the flag to
        # False so the simplification doesn't happen again in checksol()
        flags['simplify'] = False

    if checkdens:
        # reject any result that makes any denom. affirmatively 0;
        # if in doubt, keep it
        dens = _simple_dens(f, symbols)
        result = [s for s in result if
                  not any(checksol(d, {symbol: s}, **flags)
                          for d in dens)]
    if check:
        # keep only results if the check is not False
        result = [r for r in result if
                  checksol(f_num, {symbol: r}, **flags) is not False]
    return result
def _solve_system(exprs, symbols, **flags):
    """
    Solve the system of expressions *exprs* for *symbols*.

    Explanation
    ===========

    The work is done in the following stages:

    1. Unless the ``_split`` flag is False, the expressions are grouped
       into connected components (two expressions are connected when they
       share one of the *symbols*). If there is more than one component,
       each is solved by a recursive call and the solutions are combined
       as a product of the per-component solutions.
    2. Each expression is rewritten with ``_invert`` and reduced to its
       numerator. Those that can be converted to a polynomial in
       *symbols* are collected in ``polys``; the rest (or all of them if
       the ``manual`` flag is set) are collected in ``failed``.
    3. If all ``polys`` are linear, ``solve_linear_system`` (or
       ``minsolve_linear_system`` when the ``particular`` flag is set) is
       used; otherwise ``solve_poly_system`` is used.
    4. Each expression in ``failed`` is then solved, one at a time, for
       one of the remaining symbols with ``_solve``, extending the
       solutions found so far.
    5. The solutions are optionally simplified and then filtered with
       ``checksol`` against the collected denominators and, for
       non-linear systems, against the original expressions.

    Parameters
    ==========

    exprs : list of expressions.
    symbols : the symbols to solve for.
    flags : keyword options. Those read directly here are ``_split``,
        ``manual``, ``check``, ``particular`` and ``simplify``; the flags
        are also forwarded to the helpers called below.

    Returns
    =======

    ``[]`` if *exprs* is empty. For a system that was split, a single
    dictionary if there is exactly one combined solution, otherwise a list
    of dictionaries. For a purely linear system with a solution, a single
    dictionary. In all other cases a (possibly empty) list of non-empty
    dictionaries mapping symbols to values.

    Raises
    ======

    NotImplementedError
        If no subset of symbols could be solved for by
        ``solve_poly_system`` in the underdetermined case, or if one of
        the ``failed`` expressions could not be solved for any symbol.
    """
    if not exprs:
        return []

    if flags.pop('_split', True):
        # Split the system into connected components
        V = exprs
        symsset = set(symbols)
        # symbols of interest that appear in each expression
        exprsyms = {e: e.free_symbols & symsset for e in exprs}
        E = []
        # position of each symbol in the order given, used to restore that
        # order for the symbols of a component
        sym_indices = {sym: i for i, sym in enumerate(symbols)}
        for n, e1 in enumerate(exprs):
            for e2 in exprs[:n]:
                # Equations are connected if they share a symbol
                if exprsyms[e1] & exprsyms[e2]:
                    E.append((e1, e2))
        G = V, E
        subexprs = connected_components(G)
        if len(subexprs) > 1:
            subsols = []
            for subexpr in subexprs:
                subsyms = set()
                for e in subexpr:
                    subsyms |= exprsyms[e]
                subsyms = list(sorted(subsyms, key = lambda x: sym_indices[x]))
                flags['_split'] = False  # skip split step
                subsol = _solve_system(subexpr, subsyms, **flags)
                # a single dict may come back; normalize to a list
                if not isinstance(subsol, list):
                    subsol = [subsol]
                subsols.append(subsol)
            # Full solution is Cartesian product of subsystems
            sols = []
            for soldicts in product(*subsols):
                # merge one solution from each subsystem into one dict
                sols.append(dict(item for sd in soldicts
                    for item in sd.items()))
            # Return a list with one dict as just the dict
            if len(sols) == 1:
                return sols[0]
            return sols

    polys = []  # expressions that converted to polynomials
    dens = set()  # denominators gathered from all expressions
    failed = []  # expressions that must be handled one at a time
    result = False
    linear = False  # True only if every expression was a linear poly
    manual = flags.get('manual', False)
    checkdens = check = flags.get('check', True)

    for j, g in enumerate(exprs):
        dens.update(_simple_dens(g, symbols))
        # rewrite g as d - i using the pair returned by _invert and keep
        # only the numerator
        i, d = _invert(g, *symbols)
        g = d - i
        g = g.as_numer_denom()[0]
        if manual:
            # the polynomial solvers are bypassed entirely
            failed.append(g)
            continue

        poly = g.as_poly(*symbols, extension=True)

        if poly is not None:
            polys.append(poly)
        else:
            failed.append(g)

    if not polys:
        solved_syms = []
    else:
        if all(p.is_linear for p in polys):
            n, m = len(polys), len(symbols)
            # one row per polynomial, one column per symbol plus a final
            # column for the negated constant term
            matrix = zeros(n, m + 1)

            for i, poly in enumerate(polys):
                for monom, coeff in poly.terms():
                    try:
                        # column of the symbol whose exponent is 1
                        j = monom.index(1)
                        matrix[i, j] = coeff
                    except ValueError:
                        # no exponent equal to 1: this is the constant term
                        matrix[i, m] = -coeff

            # returns a dictionary ({symbols: values}) or None
            if flags.pop('particular', False):
                result = minsolve_linear_system(matrix, *symbols, **flags)
            else:
                result = solve_linear_system(matrix, *symbols, **flags)
            if failed:
                if result:
                    solved_syms = list(result.keys())
                else:
                    solved_syms = []
            else:
                linear = True

        else:
            if len(symbols) > len(polys):
                # more symbols than polynomial equations: try subsets of
                # the symbols that are the same size as the system

                free = set().union(*[p.free_symbols for p in polys])
                free = list(ordered(free.intersection(symbols)))
                got_s = set()
                result = []
                for syms in subsets(free, len(polys)):
                    try:
                        # returns [] or list of tuples of solutions for syms
                        res = solve_poly_system(polys, *syms)
                        if res:
                            for r in res:
                                skip = False
                                for r1 in r:
                                    if got_s and any(ss in r1.free_symbols
                                           for ss in got_s):
                                        # sol depends on previously
                                        # solved symbols: discard it
                                        skip = True
                                if not skip:
                                    got_s.update(syms)
                                    result.extend([dict(list(zip(syms, r)))])
                    except NotImplementedError:
                        pass
                if got_s:
                    solved_syms = list(got_s)
                else:
                    raise NotImplementedError('no valid subset found')
            else:
                try:
                    result = solve_poly_system(polys, *symbols)
                    if result:
                        solved_syms = symbols
                        # we don't know here if the symbols provided
                        # were given or not, so let solve resolve that.
                        # A list of dictionaries is going to always be
                        # returned from here.
                        result = [dict(list(zip(solved_syms, r))) for r in result]
                except NotImplementedError:
                    # hand the polynomials over to the one-at-a-time
                    # solving of failed expressions below
                    failed.extend([g.as_expr() for g in polys])
                    solved_syms = []
                    result = None

    # normalize result to a list of dicts; a single empty dict serves as
    # the starting point when nothing has been solved yet
    if result:
        if isinstance(result, dict):
            result = [result]
    else:
        result = [{}]

    if failed:
        # For each failed equation, see if we can solve for one of the
        # remaining symbols from that equation. If so, we update the
        # solution set and continue with the next failed equation,
        # repeating until we are done or we get an equation that can't
        # be solved.
        def _ok_syms(e, sort=False):
            # symbols of interest in e that have not been solved for yet;
            # solved_syms and legal are looked up when this is called, so
            # they refer to the sets assigned below
            rv = (e.free_symbols - solved_syms) & legal

            # Solve first for symbols that have lower degree in the equation.
            # Ideally we want to solve firstly for symbols that appear linearly
            # with rational coefficients e.g. if e = x*y + z then we should
            # solve for z first.
            def key(sym):
                ep = e.as_poly(sym)
                if ep is None:
                    complexity = (S.Infinity, S.Infinity, S.Infinity)
                else:
                    coeff_syms = ep.LC().free_symbols
                    complexity = (ep.degree(), len(coeff_syms & rv), len(coeff_syms))
                return complexity + (default_sort_key(sym),)

            if sort:
                rv = sorted(rv, key=key)
            return rv

        solved_syms = set(solved_syms)  # set of symbols we have solved for
        legal = set(symbols)  # what we are interested in
        # sort so equation with the fewest potential symbols is first
        u = Dummy()  # used in solution checking
        for eq in ordered(failed, lambda _: len(_ok_syms(_))):
            newresult = []
            bad_results = []
            got_s = set()
            hit = False
            for r in result:
                # update eq with everything that is known so far
                eq2 = eq.subs(r)
                # if check is True then we see if it satisfies this
                # equation, otherwise we just accept it
                if check and r:
                    b = checksol(u, u, eq2, minimal=True)
                    if b is not None:
                        # this solution is sufficient to know whether
                        # it is valid or not so we either accept or
                        # reject it, then continue
                        if b:
                            newresult.append(r)
                        else:
                            bad_results.append(r)
                        continue
                # search for a symbol amongst those available that
                # can be solved for
                ok_syms = _ok_syms(eq2, sort=True)
                if not ok_syms:
                    if r:
                        newresult.append(r)
                    # leaving the loop with break also bypasses the
                    # for-else below, so result is not replaced for
                    # this equation
                    break  # skip as it's independent of desired symbols
                for s in ok_syms:
                    try:
                        soln = _solve(eq2, s, **flags)
                    except NotImplementedError:
                        continue
                    # put each solution in r and append the now-expanded
                    # result in the new result list; use copy since the
                    # solution for s is being added in-place
                    for sol in soln:
                        if got_s and any(ss in sol.free_symbols for ss in got_s):
                            # sol depends on previously solved symbols: discard it
                            continue
                        rnew = r.copy()
                        for k, v in r.items():
                            rnew[k] = v.subs(s, sol)
                        # and add this new solution
                        rnew[s] = sol
                        # check that it is independent of previous solutions
                        iset = set(rnew.items())
                        for i in newresult:
                            if len(i) < len(iset) and not set(i.items()) - iset:
                                # this is a superset of a known solution that
                                # is smaller
                                break
                        else:
                            # keep it
                            newresult.append(rnew)
                    hit = True
                    got_s.add(s)
                if not hit:
                    raise NotImplementedError('could not solve %s' % eq2)
            else:
                # every r was processed without a break: adopt the new
                # results and drop those that failed the check
                result = newresult
                for b in bad_results:
                    if b in result:
                        result.remove(b)

    default_simplify = bool(failed)  # rely on system-solvers to simplify
    if flags.get('simplify', default_simplify):
        for r in result:
            for k in r:
                r[k] = simplify(r[k])
        flags['simplify'] = False  # don't need to do so in checksol now

    if checkdens:
        # drop solutions for which checksol gives a truthy answer for any
        # of the collected denominators
        result = [r for r in result
            if not any(checksol(d, r, **flags) for d in dens)]

    if check and not linear:
        # drop solutions that definitely fail one of the original
        # expressions; undecided (None) checks are kept
        result = [r for r in result
            if not any(checksol(e, r, **flags) is False for e in exprs)]

    # discard empty solutions (e.g. the empty starting dict)
    result = [r for r in result if r]
    if linear and result:
        # a linear system is reported as a single dict
        result = result[0]
    return result
def solve_linear(lhs, rhs=0, symbols=[], exclude=[]):
    r"""
    Look for a symbol that appears linearly in ``f = lhs - rhs`` and return
    a tuple that is one of the following: ``(0, 1)``, ``(0, 0)``,
    ``(symbol, solution)``, ``(n, d)``.

    Explanation
    ===========

    ``(0, 1)`` meaning that ``f`` is independent of the symbols in *symbols*
    that are not in *exclude*.

    ``(0, 0)`` meaning that there is no solution to the equation amongst the
    symbols given. If the first element of the tuple is not zero, then the
    function is guaranteed to be dependent on a symbol in *symbols*.

    ``(symbol, solution)`` where symbol appears linearly in the numerator of
    ``f``, is in *symbols* (if given), and is not in *exclude* (if given). No
    simplification is done to ``f`` other than a ``mul=True`` expansion, so the
    solution will correspond strictly to a unique solution.

    ``(n, d)`` where ``n`` and ``d`` are the numerator and denominator of ``f``
    when the numerator was not linear in any symbol of interest; ``n`` will
    never be a symbol unless a solution for that symbol was found (in which case
    the second element is the solution, not the denominator).

    If *lhs* is an ``Eq`` then *rhs* must be left at 0 or a ``ValueError``
    is raised; a ``ValueError`` is also raised if *symbols* contains
    anything that is not a Symbol.

    Examples
    ========

    >>> from sympy import cancel, Pow

    ``f`` is independent of the symbols in *symbols* that are not in
    *exclude*:

    >>> from sympy import cos, sin, solve_linear
    >>> from sympy.abc import x, y, z
    >>> eq = y*cos(x)**2 + y*sin(x)**2 - y  # = y*(1 - 1) = 0
    >>> solve_linear(eq)
    (0, 1)
    >>> eq = cos(x)**2 + sin(x)**2  # = 1
    >>> solve_linear(eq)
    (0, 1)
    >>> solve_linear(x, exclude=[x])
    (0, 1)

    The variable ``x`` appears as a linear variable in each of the
    following:

    >>> solve_linear(x + y**2)
    (x, -y**2)
    >>> solve_linear(1/x - y**2)
    (x, y**(-2))

    When not linear in ``x`` or ``y`` then the numerator and denominator are
    returned:

    >>> solve_linear(x**2/y**2 - 3)
    (x**2 - 3*y**2, y**2)

    If the numerator of the expression is a symbol, then ``(0, 0)`` is
    returned if the solution for that symbol would have set any
    denominator to 0:

    >>> eq = 1/(1/x - 2)
    >>> eq.as_numer_denom()
    (x, 1 - 2*x)
    >>> solve_linear(eq)
    (0, 0)

    But automatic rewriting may cause a symbol in the denominator to
    appear in the numerator so a solution will be returned:

    >>> (1/x)**-1
    x
    >>> solve_linear((1/x)**-1)
    (x, 0)

    Use an unevaluated expression to avoid this:

    >>> solve_linear(Pow(1/x, -1, evaluate=False))
    (0, 0)

    If ``x`` is allowed to cancel in the following expression, then it
    appears to be linear in ``x``, but this sort of cancellation is not
    done by ``solve_linear`` so the solution will always satisfy the
    original expression without causing a division by zero error.

    >>> eq = x**2*(1/x - z**2/x)
    >>> solve_linear(cancel(eq))
    (x, 0)
    >>> solve_linear(eq)
    (x**2*(1 - z**2), x)

    A list of symbols for which a solution is desired may be given:

    >>> solve_linear(x + y + z, symbols=[y])
    (y, -x - z)

    A list of symbols to ignore may also be given:

    >>> solve_linear(x + y + z, exclude=[x])
    (y, -x - z)

    (A solution for ``y`` is obtained because it is the first variable
    from the canonically sorted list of symbols that had a linear
    solution.)

    """
    # unpack an Equality; a separate nonzero rhs makes no sense then
    if isinstance(lhs, Eq):
        if rhs:
            raise ValueError(filldedent('''
            If lhs is an Equality, rhs must be 0 but was %s''' % rhs))
        lhs, rhs = lhs.lhs, lhs.rhs

    eq = lhs - rhs
    numer, denom = eq.as_numer_denom()
    if not numer:
        return S.Zero, S.One

    # settle on the set of symbols that may be solved for
    numer_free = numer.free_symbols
    if symbols:
        non_symbols = [s for s in symbols if not s.is_Symbol]
        if non_symbols:
            # report a lone offender by itself rather than as a list
            if len(non_symbols) == 1:
                shown = non_symbols[0]
            else:
                shown = non_symbols
            if len(symbols) == 1:
                eg = 'solve(%s, %s)' % (eq, symbols[0])
            else:
                eg = 'solve(%s, *%s)' % (eq, list(symbols))
            raise ValueError(filldedent('''
                solve_linear only handles symbols, not %s. To isolate
                non-symbols use solve, e.g. >>> %s <<<.
                             ''' % (shown, eg)))
        symbols = numer_free.intersection(symbols)
    else:
        symbols = numer_free
    symbols = symbols.difference(exclude)
    if not symbols:
        return S.Zero, S.One

    from sympy.integrals.integrals import Integral

    # derivatives are easy to do but tricky to analyze to see if they
    # are going to disallow a linear solution, so for simplicity we
    # just evaluate the ones that have the symbols of interest;
    # group them here by the symbol of interest that they contain
    ders_by_symbol = defaultdict(list)
    for der in numer.atoms(Derivative):
        for sym in der.free_symbols & symbols:
            ders_by_symbol[sym].append(der)

    dens = None  # denominators of eq, computed only when first needed
    found_dependence = False
    for xi in sorted(symbols, key=default_sort_key):  # canonical order
        # evaluate any derivatives that involve this symbol
        evaluated = {der: der.doit() for der in ders_by_symbol[xi]}
        newn = numer.subs(evaluated)
        slope = newn.diff(xi)
        if not slope:
            continue
        # the slope counts as nonzero if it survives differentiation by
        # any of its free symbols
        slope_free = slope.free_symbols
        if (slope_free and
                not any(slope.diff(s) for s in slope_free) and
                slope_free != symbols):
            continue
        found_dependence = True
        if slope is S.NaN:
            break
        if xi in slope_free:
            # the numerator is not linear in xi
            continue
        vi = -1/slope*(newn.subs(xi, 0))
        if dens is None:
            dens = _simple_dens(eq, symbols)
        if any(checksol(di, {xi: vi}, minimal=True) is True
                for di in dens):
            # this value would zero a denominator: try the next symbol
            continue
        # simplify any trivial integral
        irep = [(i, i.doit()) for i in vi.atoms(Integral) if
                i.function.is_number]
        # do a slight bit of simplification
        return xi, expand_mul(vi.subs(irep))

    if not found_dependence:
        return S.Zero, S.One
    if numer.is_Symbol:  # no solution for this symbol was found
        return S.Zero, S.Zero
    return numer, denom
def minsolve_linear_system(system, *symbols, **flags):
    r"""
    Return one particular solution of a linear system, preferring
    solutions in which few variables are non-zero.

    Explanation
    ===========

    ``system`` is an augmented matrix and ``symbols`` are its unknowns, as
    for ``solve_linear_system``. If the system has no solution, or if every
    value in its general solution is zero, the result of
    ``solve_linear_system`` is returned unchanged.

    If ``quick=True``, a heuristic is applied to the general solution: the
    remaining free symbols are fixed one at a time until every unknown has
    a value. Otherwise the heuristic answer is used as a starting point and
    ever smaller sets of columns are tried exhaustively, which is a naive
    algorithm with exponential complexity.
    """
    heuristic_only = flags.get('quick', False)

    # Bail out early when there is nothing non-trivial to minimize.
    general = solve_linear_system(system, *symbols, **flags)
    if not general or all(value == 0 for value in general.values()):
        return general

    if heuristic_only:
        # Start from the general solution and pin free symbols one by one.
        pending = solve_linear_system(system, *symbols)

        def settle(fixed, unsettled):
            # Substitute what is known so far; entries that no longer
            # contain free symbols are moved from ``unsettled`` to ``fixed``.
            finished = []
            for key, expr in unsettled.items():
                unsettled[key] = expr.subs(fixed)
                if not unsettled[key].free_symbols:
                    finished.append(key)
                    fixed[key] = unsettled[key]
            for key in finished:
                del unsettled[key]

        fixed = {}
        settle(fixed, pending)
        while pending:
            # NOTE sort by default_sort_key to get deterministic result
            widest = max(
                pending.values(),
                key=lambda e: (len(e.free_symbols), default_sort_key(e)))
            target = max(widest.free_symbols, key=default_sort_key)
            if len(widest.free_symbols) == 1:
                root = solve(widest)[0]
                if root == 0 and all(
                        e.subs(target, root) == 0 for e in pending.values()):
                    fixed[target] = S.One
                else:
                    fixed[target] = root
            else:
                fixed[target] = S.Zero
            settle(fixed, pending)
        return fixed

    # Exhaustive search: pick ``size`` variables that may be non-zero and
    # assume all others are zero. If the reduced system has a non-trivial
    # solution, its free variables are set to one. The search starts at one
    # less than the number of non-zero variables the heuristic produced and
    # keeps shrinking for as long as a solution is found.
    from itertools import combinations
    total = len(symbols)
    best = minsolve_linear_system(system, *symbols, quick=True)
    size = sum(1 for value in best.values() if value != 0) - 1
    while size > 1:
        debug('minsolve: %s' % size)
        found = None
        for chosen in combinations(list(range(total)), size):
            columns = [system.col(i).T for i in chosen]
            columns.append(system.col(-1).T)
            reduced = Matrix(columns).T
            candidate = solve_linear_system(
                reduced, *[symbols[i] for i in chosen])
            if not candidate or all(
                    value == 0 for value in candidate.values()):
                continue
            ones = [(symbols[i], S.One) for i in chosen]
            for key, expr in candidate.items():
                candidate[key] = expr.subs(ones)
            for sym in symbols:
                if sym not in candidate:
                    if symbols.index(sym) in chosen:
                        candidate[sym] = S.One
                    else:
                        candidate[sym] = S.Zero
            found = candidate
            break
        if found is None:
            break
        best = found
        size -= 1
    return best
def solve_linear_system(system, *symbols, **flags):
    r"""
    Solve a system of $N$ linear equations in $M$ unknowns; both under-
    and overdetermined systems are accepted.

    Explanation
    ===========

    A linear system has zero, one or infinitely many solutions. When there
    is no solution None is returned, otherwise a dictionary of solutions
    is returned. For underdetermined systems the arbitrary parameters are
    left out of the dictionary (entries mapping a symbol to itself are
    dropped), so the dictionary may be empty; in that case every symbol
    can take an arbitrary value.

    ``system`` must be an augmented $N\times M + 1$ matrix, i.e. it must
    have one more column than there are ``symbols``. If you would rather
    pass $N$ equations and $M$ unknowns, use ``solve(Neqs, *Msymbols)``
    instead. The matrix that is passed in is not modified.

    This routine is a thin wrapper: the rows are turned into expressions
    in ``symbols`` and handed to ``solve_lin_sys``.

    Examples
    ========

    >>> from sympy import Matrix, solve_linear_system
    >>> from sympy.abc import x, y

    Solve the following system::

           x + 4 y ==  2
        -2 x +   y == 14

    >>> system = Matrix(( (1, 4, 2), (-2, 1, 14)))
    >>> solve_linear_system(system, x, y)
    {x: -6, y: 2}

    A degenerate system returns an empty dictionary:

    >>> system = Matrix(( (0,0,0), (0,0,0) ))
    >>> solve_linear_system(system, x, y)
    {}

    """
    ncols = system.shape[1]
    assert ncols == len(symbols) + 1

    # Multiplying by (symbols..., -1) turns each augmented row
    # [a1, ..., aM, b] into the expression a1*s1 + ... + aM*sM - b.
    unknowns_and_rhs = Matrix(symbols + (-1,))
    exprs = list(system * unknowns_and_rhs)
    ring_eqs, ring = sympy_eqs_to_ring(exprs, symbols)
    result = solve_lin_sys(ring_eqs, ring, _raw=False)
    if result is None:
        return result
    # leave out free parameters, which are reported as ``sym: sym``
    return {sym: val for sym, val in result.items() if sym != val}
def solve_undetermined_coeffs(equ, coeffs, sym, **flags):
    r"""
    Solve an equation of the form $p(x; a_1, \ldots, a_k) = q(x)$, where
    $p$ and $q$ are univariate polynomials in ``sym`` and $p$ depends on
    the $k$ parameters given in ``coeffs``.

    Explanation
    ===========

    The equation is brought to the form ``expr == 0``, the numerator of the
    cancelled expression is expanded and collected in powers of ``sym``,
    and the collected coefficients are passed to ``solve`` with ``coeffs``
    as the unknowns (``flags`` are forwarded to ``solve``). If any of the
    collected coefficients still contains ``sym``, None is returned.

    ``equ`` may be an ``Eq`` instance or an ordinary SymPy expression.
    Both the parameters and the variable have to be given explicitly.

    Examples
    ========

    >>> from sympy import Eq, solve_undetermined_coeffs
    >>> from sympy.abc import a, b, c, x

    >>> solve_undetermined_coeffs(Eq(2*a*x + a+b, x), [a, b], x)
    {a: 1/2, b: -1/2}

    >>> solve_undetermined_coeffs(Eq(a*c*x + a+b, x), [a, b], x)
    {a: 1/c, b: -1/c}

    """
    # an equation is rewritten with all of its terms on the left-hand side
    expr = equ.lhs - equ.rhs if isinstance(equ, Eq) else equ

    numer = cancel(expr).as_numer_denom()[0]
    by_power = collect(numer.expand(), sym, evaluate=False)
    coeff_eqs = list(by_power.values())

    for coeff in coeff_eqs:
        if coeff.has(sym):
            # the powers of ``sym`` could not be separated out
            return None  # no solutions

    # every coefficient is free of ``sym``: solve for the parameters
    return solve(coeff_eqs, *coeffs, **flags)
def solve_linear_system_LU(matrix, syms):
    """
    Solve the system given as an augmented matrix with ``LUsolve`` and
    return a dictionary mapping the symbols of *syms*, in the order given,
    to their solution values.

    Explanation
    ===========

    The matrix must have exactly one more column than rows (otherwise
    ``ValueError`` is raised) and its square coefficient part must be
    invertible.

    Examples
    ========

    >>> from sympy import Matrix, solve_linear_system_LU
    >>> from sympy.abc import x, y, z

    >>> solve_linear_system_LU(Matrix([
    ... [1, 2, 0, 1],
    ... [3, 2, 2, 1],
    ... [2, 0, 0, 1]]), [x, y, z])
    {x: 1/2, y: 1/4, z: -1/2}

    See Also
    ========

    LUsolve

    """
    nrows = matrix.rows
    last_col = matrix.cols - 1
    if nrows != last_col:
        raise ValueError("Rows should be equal to columns - 1")
    # split the augmented matrix into the square part and the last column
    coefficients = matrix[:nrows, :nrows]
    constants = matrix[:, last_col:]
    values = coefficients.LUsolve(constants)
    return {syms[i]: values[i, 0] for i in range(values.rows)}
def det_perm(M):
    """
    Return the determinant of *M* as a sum of signed products, one product
    per permutation of the columns.

    Explanation
    ===========

    For sizes larger than 8 the number of permutations becomes
    prohibitively large, and when the matrix contains no symbols the
    standard determinant routines (e.g., ``M.det()``) are the better
    choice.

    The sign of the terms simply alternates, starting with ``+``, in the
    order in which ``generate_bell`` produces the permutations.

    See Also
    ========

    det_minor
    det_quick

    """
    size = M.rows
    entries = M.flat()
    terms = []
    for count, perm in enumerate(generate_bell(size)):
        # row ``r`` contributes the entry in column ``perm[r]``; the flat
        # list stores the rows one after another
        picked = [entries[row*size + col] for row, col in enumerate(perm)]
        # disaster with unevaluated Mul -- takes forever for n=7
        product = Mul(*picked)
        terms.append(-product if count % 2 else product)
    return Add(*terms)
def det_minor(M):
    """
    Return ``det(M)`` computed from minors (cofactor expansion along the
    first row) without introducing new nesting in products.

    Explanation
    ===========

    The recursion ends at explicit base cases: an empty matrix has
    determinant 1, a 1x1 matrix has its only entry as determinant and a
    2x2 matrix is evaluated directly. Zero entries of the first row are
    skipped, so the minors belonging to them are never computed.

    See Also
    ========

    det_perm
    det_quick

    """
    n = M.rows
    if n == 0:
        # empty product
        return S.One
    if n == 1:
        # without this case the expansion below would recurse into an empty
        # minor and give 0 instead of the single entry
        return M[0, 0]
    if n == 2:
        return M[0, 0]*M[1, 1] - M[1, 0]*M[0, 1]

    total = S.Zero
    for i in range(n):
        pivot = M[0, i]
        if not pivot:
            continue
        minor = det_minor(M.minor_submatrix(0, i))
        # distribute the pivot over the terms of the minor so that the
        # result stays a flat sum of products
        expanded = Add(*[pivot*term for term in Add.make_args(minor)])
        total += (1, -1)[i % 2]*expanded
    return total
def det_quick(M, method=None):
    """
    Return ``det(M)`` assuming that either there are lots of zeros or the
    size of the matrix is small.

    Explanation
    ===========

    If no entry of the matrix contains a ``Symbol`` the normal
    ``Matrix.det`` is used, with ``method=method`` when a method is given.
    Otherwise ``det_perm`` is used when the matrix has fewer than 8 rows
    and every entry contains a ``Symbol``, and ``det_minor`` is used in the
    remaining cases.

    See Also
    ========

    det_minor
    det_perm

    """
    if not any(entry.has(Symbol) for entry in M):
        # purely numeric entries: defer to the regular routine
        if method:
            return M.det(method=method)
        return M.det()
    if M.rows < 8 and all(entry.has(Symbol) for entry in M):
        return det_perm(M)
    return det_minor(M)
def inv_quick(M):
    """Return the inverse of ``M``, assuming that either
    there are lots of zeros or the size of the matrix
    is small.

    If every entry is a Number, ``M.inv()`` is returned. Otherwise the
    inverse is built entry by entry from the determinants of the minors
    divided by the determinant of ``M``; ``det_perm`` is used when no entry
    is a Number and ``det_minor`` when only some are.
    ``NonInvertibleMatrixError`` is raised if the determinant compares
    equal to zero.
    """
    if all(entry.is_Number for entry in M):
        return M.inv()
    if any(entry.is_Number for entry in M):
        det = det_minor
    else:
        det = det_perm

    size = M.rows
    whole = det(M)
    if whole == S.Zero:
        raise NonInvertibleMatrixError("Matrix det == 0; not invertible")

    inverse = zeros(size)
    for row in range(size):
        for col in range(size):
            # cofactor sign; note the transposed position of the result
            sign = -1 if (row + col) % 2 else 1
            inverse[col, row] = sign*det(M.minor_submatrix(row, col))/whole
    return inverse
# These are functions that have multiple inverse values per period. Each
# function is mapped to a callable that takes the right-hand side value and
# returns a tuple of the candidate inverse values; ``_tsolve`` looks up the
# function of the left-hand side here and solves for each candidate in turn.
multi_inverses = {
    sin: lambda x: (asin(x), S.Pi - asin(x)),
    cos: lambda x: (acos(x), 2*S.Pi - acos(x)),
}
def _tsolve(eq, sym, **flags):
    """
    Helper for ``_solve`` that solves a transcendental equation with respect
    to the given symbol. Various equations containing powers and logarithms
    can be solved.

    There is currently no guarantee that all solutions will be returned or
    that a real solution will be favored over a complex one.

    Either a list of potential solutions will be returned or None will be
    returned (in the case that no method was known to get a solution
    for the equation). All other errors (like the inability to cast an
    expression as a Poly) are unhandled.

    The following keys of ``flags`` are consulted here; ``flags`` is also
    forwarded to the recursive ``_solve`` and ``_tsolve`` calls:

    ``tsolve_saw``
        list of the equations that were already attempted; an equation
        that is found in the list is not attempted again and None is
        returned for it.
    ``force``
        passed on to ``logcombine``; when true (the default) a last
        attempt is made with the ``posify``-ed equation.
    ``bivariate``
        when true (the default) the Lambert and bivariate patterns are
        tried.

    Examples
    ========

    >>> from sympy import log
    >>> from sympy.solvers.solvers import _tsolve as tsolve
    >>> from sympy.abc import x

    >>> tsolve(3**(2*x + 5) - 4, x)
    [-5/2 + log(2)/log(3), (-5*log(3)/2 + log(2) + I*pi)/log(3)]

    >>> tsolve(log(x) + 2*x, x)
    [LambertW(2)/2]

    """
    # Recursion guard: remember every equation that has been tried. The
    # list object itself is handed down through ``**flags``, so nested
    # calls append to (and check) the same list.
    if 'tsolve_saw' not in flags:
        flags['tsolve_saw'] = []
    if eq in flags['tsolve_saw']:
        return None
    else:
        flags['tsolve_saw'].append(eq)

    # note the order: _invert returns (independent part, dependent part)
    rhs, lhs = _invert(eq, sym)

    if lhs == sym:
        return [rhs]
    try:
        if lhs.is_Add:
            # it's time to try factoring; powdenest is used
            # to try get powers in standard form for better factoring
            f = factor(powdenest(lhs - rhs))
            if f.is_Mul:
                return _solve(f, sym, **flags)
            if rhs:
                f = logcombine(lhs, force=flags.get('force', True))
                # only continue if logcombine changed the number of logs
                if f.count(log) != lhs.count(log):
                    if isinstance(f, log):
                        return _solve(f.args[0] - exp(rhs), sym, **flags)
                    return _tsolve(f - rhs, sym, **flags)

        elif lhs.is_Pow:
            if lhs.exp.is_Integer:
                # hand back to _solve only if inversion changed the equation
                if lhs - rhs != eq:
                    return _solve(lhs - rhs, sym, **flags)

            if sym not in lhs.exp.free_symbols:
                return _solve(lhs.base - rhs**(1/lhs.exp), sym, **flags)

            # _tsolve calls this with Dummy before passing the actual number in.
            if any(t.is_Dummy for t in rhs.free_symbols):
                raise NotImplementedError # _tsolve will call here again...

            # a ** g(x) == 0
            if not rhs:
                # f(x)**g(x) only has solutions where f(x) == 0 and g(x) != 0 at
                # the same place
                sol_base = _solve(lhs.base, sym, **flags)
                return [s for s in sol_base if lhs.exp.subs(sym, s) != 0]

            # a ** g(x) == b
            if not lhs.base.has(sym):
                if lhs.base == 0:
                    return _solve(lhs.exp, sym, **flags) if rhs != 0 else []

                # Gets most solutions...
                if lhs.base == rhs.as_base_exp()[0]:
                    # handles case when bases are equal
                    sol = _solve(lhs.exp - rhs.as_base_exp()[1], sym, **flags)
                else:
                    # handles cases when bases are not equal and exp
                    # may or may not be equal
                    sol = _solve(exp(log(lhs.base)*lhs.exp)-exp(log(rhs)), sym, **flags)

                # Check for duplicate solutions
                def equal(expr1, expr2):
                    # Return whether expr1 equals expr2: checksol is tried
                    # first, then an nsimplify comparison, and finally
                    # ``equals`` (whose result may be None).
                    _ = Dummy()
                    eq = checksol(expr1 - _, _, expr2)
                    if eq is None:
                        if nsimplify(expr1) != nsimplify(expr2):
                            return False
                        # they might be coincidentally the same
                        # so check more rigorously
                        eq = expr1.equals(expr2)
                    return eq

                # Guess a rational exponent
                e_rat = nsimplify(log(abs(rhs))/log(abs(lhs.base)))
                e_rat = simplify(posify(e_rat)[0])
                n, d = fraction(e_rat)
                if expand(lhs.base**n - rhs**d) == 0:
                    # the guess is exact: replace any solution that gives
                    # the same exponent with the solutions of g(x) == e_rat
                    sol = [s for s in sol if not equal(lhs.exp.subs(sym, s), e_rat)]
                    sol.extend(_solve(lhs.exp - e_rat, sym, **flags))

                return list(ordered(set(sol)))

            # f(x) ** g(x) == c
            else:
                sol = []
                logform = lhs.exp*log(lhs.base) - log(rhs)
                if logform != lhs - rhs:
                    try:
                        sol.extend(_solve(logform, sym, **flags))
                    except NotImplementedError:
                        pass

                # Collect possible solutions and check with substitution later.
                check = []
                if rhs == 1:
                    # f(x) ** g(x) = 1 -- g(x)=0 or f(x)=+-1
                    check.extend(_solve(lhs.exp, sym, **flags))
                    check.extend(_solve(lhs.base - 1, sym, **flags))
                    check.extend(_solve(lhs.base + 1, sym, **flags))
                elif rhs.is_Rational:
                    # look for r = d/s and an integer e such that
                    # d**e == rhs.p and s**e == rhs.q
                    for d in (i for i in divisors(abs(rhs.p)) if i != 1):
                        e, t = integer_log(rhs.p, d)
                        if not t:
                            continue # rhs.p != d**b
                        for s in divisors(abs(rhs.q)):
                            if s**e== rhs.q:
                                r = Rational(d, s)
                                check.extend(_solve(lhs.base - r, sym, **flags))
                                check.extend(_solve(lhs.base + r, sym, **flags))
                                check.extend(_solve(lhs.exp - e, sym, **flags))
                elif rhs.is_irrational:
                    b_l, e_l = lhs.base.as_base_exp()
                    n, d = (e_l*lhs.exp).as_numer_denom()
                    b, e = sqrtdenest(rhs).as_base_exp()
                    check = [sqrtdenest(i) for i in (_solve(lhs.base - b, sym, **flags))]
                    check.extend([sqrtdenest(i) for i in (_solve(lhs.exp - e, sym, **flags))])
                    if e_l*d != 1:
                        check.extend(_solve(b_l**n - rhs**(e_l*d), sym, **flags))
                # keep only the candidates that are confirmed to satisfy eq
                for s in check:
                    ok = checksol(eq, sym, s)
                    if ok is None:
                        ok = eq.subs(sym, s).equals(0)
                    if ok:
                        sol.append(s)
                return list(ordered(set(sol)))

        elif lhs.is_Function and len(lhs.args) == 1:
            if lhs.func in multi_inverses:
                # sin(x) = 1/3 -> x - asin(1/3) & x - (pi - asin(1/3))
                soln = []
                for i in multi_inverses[type(lhs)](rhs):
                    soln.extend(_solve(lhs.args[0] - i, sym, **flags))
                return list(ordered(soln))
            elif lhs.func == LambertW:
                return _solve(lhs.args[0] - rhs*exp(rhs), sym, **flags)

        # nothing above returned: try again in terms of exp
        rewrite = lhs.rewrite(exp)
        if rewrite != lhs:
            return _solve(rewrite - rhs, sym, **flags)
    except NotImplementedError:
        # fall through to the pattern-based attempts below
        pass

    # maybe it is a lambert pattern
    if flags.pop('bivariate', True):
        # lambert forms may need some help being recognized, e.g. changing
        # 2**(3*x) + x**3*log(2)**3 + 3*x**2*log(2)**2 + 3*x*log(2) + 1
        # to 2**(3*x) + (x*log(2) + 1)**3
        g = _filtered_gens(eq.as_poly(), sym)
        # generators that are exp or log, or powers with sym in the exponent
        up_or_log = set()
        for gi in g:
            if isinstance(gi, (exp, log)) or (gi.is_Pow and gi.base == S.Exp1):
                up_or_log.add(gi)
            elif gi.is_Pow:
                gisimp = powdenest(expand_power_exp(gi))
                if gisimp.is_Pow and sym in gisimp.exp.free_symbols:
                    up_or_log.add(gi)
        # eq_down is eq with those generators replaced by 0; only that part
        # is factored and the rest of eq is added back unchanged
        eq_down = expand_log(expand_power_exp(eq)).subs(
            dict(list(zip(up_or_log, [0]*len(up_or_log)))))
        eq = expand_power_exp(factor(eq_down, deep=True) + (eq - eq_down))
        rhs, lhs = _invert(eq, sym)
        if lhs.has(sym):
            try:
                poly = lhs.as_poly()
                g = _filtered_gens(poly, sym)
                _eq = lhs - rhs
                sols = _solve_lambert(_eq, sym, g)
                # use a simplified form if it satisfies eq
                # and has fewer operations
                for n, s in enumerate(sols):
                    ns = nsimplify(s)
                    if ns != s and ns.count_ops() <= s.count_ops():
                        ok = checksol(_eq, sym, ns)
                        if ok is None:
                            ok = _eq.subs(sym, ns).equals(0)
                        if ok:
                            sols[n] = ns
                return sols
            except NotImplementedError:
                # maybe it's a convoluted function
                if len(g) == 2:
                    try:
                        gpu = bivariate_type(lhs - rhs, *g)
                        if gpu is None:
                            raise NotImplementedError
                        g, p, u = gpu
                        # do not try the bivariate patterns again while
                        # solving g - u for sym
                        flags['bivariate'] = False
                        inversion = _tsolve(g - u, sym, **flags)
                        if inversion:
                            sol = _solve(p, u, **flags)
                            return list(ordered({i.subs(u, s)
                                for i in inversion for s in sol}))
                    except NotImplementedError:
                        pass
                else:
                    pass

    # last resort: solve the posify-ed equation, with 'force' switched off
    # for the nested call
    if flags.pop('force', True):
        flags['force'] = False
        pos, reps = posify(lhs - rhs)
        if rhs == S.ComplexInfinity:
            return []
        # find the symbol that posify substituted for sym; the for-else
        # keeps sym itself when there was no substitution
        for u, s in reps.items():
            if s == sym:
                break
        else:
            u = sym
        if pos.has(u):
            try:
                soln = _solve(pos, u, **flags)
                return list(ordered([s.subs(reps) for s in soln]))
            except NotImplementedError:
                pass
        else:
            pass # here for coverage

    return # here for coverage
# TODO: option for calculating J numerically

@conserve_mpmath_dps
def nsolve(*args, dict=False, **kwargs):
    r"""
    Solve a nonlinear equation system numerically: ``nsolve(f, [args,] x0,
    modules=['mpmath'], **kwargs)``.

    Explanation
    ===========

    ``f`` is a vector function of symbolic expressions representing the system.
    *args* are the variables. If there is only one variable, this argument can
    be omitted. ``x0`` is a starting vector close to a solution.

    Use the modules keyword to specify which modules should be used to
    evaluate the function and the Jacobian matrix. Make sure to use a module
    that supports matrices. For more information on the syntax, please see the
    docstring of ``lambdify``.

    If the keyword arguments contain ``dict=True`` (default is False) ``nsolve``
    will return a list (perhaps empty) of solution mappings. This might be
    especially useful if you want to use ``nsolve`` as a fallback to solve since
    using the dict argument for both methods produces return values of
    consistent type structure. Please note: to keep this consistent with
    ``solve``, the solution will be returned in a list even though ``nsolve``
    (currently at least) only finds one solution at a time.

    Overdetermined systems are supported.

    ``TypeError`` is raised when the number of positional arguments is not
    2 or 3 (or when the lengths of the variables and of the guess differ),
    ``ValueError`` is raised when the keyword ``method`` is used or when a
    single expression is not a function of exactly one symbol that matches
    the given variable, and ``NotImplementedError`` is raised when there
    are more variables than equations.

    Examples
    ========

    >>> from sympy import Symbol, nsolve
    >>> import mpmath
    >>> mpmath.mp.dps = 15
    >>> x1 = Symbol('x1')
    >>> x2 = Symbol('x2')
    >>> f1 = 3 * x1**2 - 2 * x2**2 - 1
    >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
    >>> print(nsolve((f1, f2), (x1, x2), (-1, 1)))
    Matrix([[-1.19287309935246], [1.27844411169911]])

    For one-dimensional functions the syntax is simplified:

    >>> from sympy import sin, nsolve
    >>> from sympy.abc import x
    >>> nsolve(sin(x), x, 2)
    3.14159265358979
    >>> nsolve(sin(x), 2)
    3.14159265358979

    To solve with higher precision than the default, use the prec argument:

    >>> from sympy import cos
    >>> nsolve(cos(x) - x, 1)
    0.739085133215161
    >>> nsolve(cos(x) - x, 1, prec=50)
    0.73908513321516064165531208767387340401341175890076
    >>> cos(_)
    0.73908513321516064165531208767387340401341175890076

    To solve for complex roots of real functions, a nonreal initial point
    must be specified:

    >>> from sympy import I
    >>> nsolve(x**2 + 2, I)
    1.4142135623731*I

    ``mpmath.findroot`` is used and you can find their more extensive
    documentation, especially concerning keyword parameters and
    available solvers. Note, however, that functions which are very
    steep near the root, the verification of the solution may fail. In
    this case you should use the flag ``verify=False`` and
    independently verify the solution.

    >>> from sympy import cos, cosh
    >>> f = cos(x)*cosh(x) - 1
    >>> nsolve(f, 3.14*100)
    Traceback (most recent call last):
    ...
    ValueError: Could not find root within given tolerance. (1.39267e+230 > 2.1684e-19)
    >>> ans = nsolve(f, 3.14*100, verify=False); ans
    312.588469032184
    >>> f.subs(x, ans).n(2)
    2.1e+121
    >>> (f/f.diff(x)).subs(x, ans).n(2)
    7.4e-15

    One might safely skip the verification if bounds of the root are known
    and a bisection method is used:

    >>> bounds = lambda i: (3.14*i, 3.14*(i + 1))
    >>> nsolve(f, bounds(100), solver='bisect', verify=False)
    315.730061685774

    Alternatively, a function may be better behaved when the
    denominator is ignored. Since this is not always the case, however,
    the decision of what function to use is left to the discretion of
    the user.

    >>> eq = x**2/(1 - x)/(1 - 2*x)**2 - 100
    >>> nsolve(eq, 0.46)
    Traceback (most recent call last):
    ...
    ValueError: Could not find root within given tolerance. (10000 > 2.1684e-19)
    Try another starting point or tweak arguments.
    >>> nsolve(eq.as_numer_denom()[0], 0.46)
    0.46792545969349058

    """
    # there are several other SymPy functions that use method= so
    # guard against that here
    if 'method' in kwargs:
        raise ValueError(filldedent('''
            Keyword "method" should not be used in this context. When using
            some mpmath solvers directly, the keyword "method" is
            used, but when using nsolve (and findroot) the keyword to use is
            "solver".'''))

    if 'prec' in kwargs:
        import mpmath
        mpmath.mp.dps = kwargs.pop('prec')

    # keyword argument to return result as a dictionary
    as_dict = dict
    from builtins import dict  # to unhide the builtin

    # interpret arguments
    if len(args) == 3:
        f = args[0]
        fargs = args[1]
        x0 = args[2]
        if iterable(fargs) and iterable(x0):
            if len(x0) != len(fargs):
                raise TypeError('nsolve expected exactly %i guess vectors, got %i'
                                % (len(fargs), len(x0)))
    elif len(args) == 2:
        f = args[0]
        fargs = None
        x0 = args[1]
        if iterable(f):
            raise TypeError('nsolve expected 3 arguments, got 2')
    elif len(args) < 2:
        raise TypeError('nsolve expected at least 2 arguments, got %i'
                        % len(args))
    else:
        raise TypeError('nsolve expected at most 3 arguments, got %i'
                        % len(args))
    modules = kwargs.get('modules', ['mpmath'])
    if iterable(f):
        f = list(f)
        for i, fi in enumerate(f):
            if isinstance(fi, Eq):
                f[i] = fi.lhs - fi.rhs
        f = Matrix(f).T
    if iterable(x0):
        x0 = list(x0)
    if not isinstance(f, Matrix):
        # assume it's a SymPy expression
        if isinstance(f, Eq):
            f = f.lhs - f.rhs
        syms = f.free_symbols
        if fargs is None:
            # use the only symbol of f; when f has no symbols at all this
            # gives None, which is rejected by the check below
            fargs = next(iter(syms), None)
        if iterable(fargs):
            # a sequence of variables was given for a scalar function: it
            # is kept as a (hashable) tuple and its first entry is checked
            fargs = tuple(fargs)
            variable = fargs[0] if fargs else None
        else:
            # a single variable; it must not be subscripted
            variable = fargs
        if not (len(syms) == 1 and variable in syms):
            raise ValueError(filldedent('''
                expected a one-dimensional and numerical function'''))

        # the function is much better behaved if there is no denominator
        # but sending the numerator is left to the user since sometimes
        # the function is better behaved when the denominator is present
        # e.g., issue 11768

        f = lambdify(fargs, f, modules)
        x = sympify(findroot(f, x0, **kwargs))
        if as_dict:
            return [{fargs: x}]
        return x

    if len(fargs) > f.cols:
        raise NotImplementedError(filldedent('''
            need at least as many equations as variables'''))
    verbose = kwargs.get('verbose', False)
    if verbose:
        print('f(x):')
        print(f)
    # derive Jacobian
    J = f.jacobian(fargs)
    if verbose:
        print('J(x):')
        print(J)
    # create functions
    f = lambdify(fargs, f.T, modules)
    J = lambdify(fargs, J, modules)
    # solve the system numerically
    x = findroot(f, x0, J=J, **kwargs)
    if as_dict:
        return [dict(zip(fargs, [sympify(xi) for xi in x]))]
    return Matrix(x)
def _invert(eq, *symbols, **kwargs):
    """
    Return tuple (i, d) where ``i`` is independent of *symbols* and ``d``
    contains symbols.

    Explanation
    ===========

    ``i`` and ``d`` are obtained after recursively using algebraic inversion
    until an uninvertible ``d`` remains. If there are no free symbols then
    ``d`` will be zero. Some (but not necessarily all) solutions to the
    expression ``i - d`` will be related to the solutions of the original
    expression.

    If no *symbols* are given, all free symbols of ``eq`` are used. The
    only keyword that is read is ``integer_power`` (default False).

    ``NotImplementedError`` is raised when two applications of the same
    function with an equal number (more than one) of arguments would have
    to be matched, and ``ValueError`` when their numbers of arguments
    differ.

    Examples
    ========

    >>> from sympy.solvers.solvers import _invert as invert
    >>> from sympy import sqrt, cos
    >>> from sympy.abc import x, y

    >>> invert(x - 3)
    (3, x)
    >>> invert(3)
    (3, 0)
    >>> invert(2*cos(x) - 1)
    (1/2, cos(x))
    >>> invert(sqrt(x) - 3)
    (3, sqrt(x))
    >>> invert(sqrt(x) + y, x)
    (-y, sqrt(x))
    >>> invert(sqrt(x) + y, y)
    (-sqrt(x), y)
    >>> invert(sqrt(x) + y, x, y)
    (0, sqrt(x) + y)

    If there is more than one symbol in a power's base and the exponent
    is not an Integer, then the principal root will be used for the
    inversion:

    >>> invert(sqrt(x + y) - 2)
    (4, x + y)

    If the exponent is an Integer, setting ``integer_power`` to True
    will force the principal root to be selected:

    >>> invert(x**2 - 4, integer_power=True)
    (2, x)

    """
    eq = sympify(eq)
    if eq.args:
        # make sure we are working with flat eq
        eq = eq.func(*eq.args)
    free = eq.free_symbols
    if not symbols:
        symbols = free
    if not free & set(symbols):
        # nothing depends on the symbols of interest
        return eq, S.Zero

    dointpow = bool(kwargs.get('integer_power', False))

    # invariant of the loop below: the equation being inverted is lhs == rhs
    lhs = eq
    rhs = S.Zero
    while True:
        # remember lhs so that the outer loop can stop once a full pass
        # leaves it unchanged
        was = lhs
        while True:
            indep, dep = lhs.as_independent(*symbols)

            # dep + indep == rhs
            if lhs.is_Add:
                # this indicates we have done it all
                if indep.is_zero:
                    break

                lhs = dep
                rhs -= indep

            # dep * indep == rhs
            else:
                # this indicates we have done it all
                if indep is S.One:
                    break

                lhs = dep
                rhs /= indep

        # collect like-terms in symbols
        if lhs.is_Add:
            # map each dependent factor to the list of its coefficients
            terms = {}
            for a in lhs.args:
                i, d = a.as_independent(*symbols)
                terms.setdefault(d, []).append(i)
            if any(len(v) > 1 for v in terms.values()):
                args = []
                for d, i in terms.items():
                    if len(i) > 1:
                        args.append(Add(*i)*d)
                    else:
                        args.append(i[0]*d)
                lhs = Add(*args)

        # if it's a two-term Add with rhs = 0 and two powers we can get the
        # dependent terms together, e.g. 3*f(x) + 2*g(x) -> f(x)/g(x) = -2/3
        if lhs.is_Add and not rhs and len(lhs.args) == 2 and \
                not lhs.is_polynomial(*symbols):
            a, b = ordered(lhs.args)
            ai, ad = a.as_independent(*symbols)
            bi, bd = b.as_independent(*symbols)
            if any(_ispow(i) for i in (ad, bd)):
                a_base, a_exp = ad.as_base_exp()
                b_base, b_exp = bd.as_base_exp()
                if a_base == b_base:
                    # a = -b
                    lhs = powsimp(powdenest(ad/bd))
                    rhs = -bi/ai
                else:
                    # only use the quotient if powsimp could simplify it
                    rat = ad/bd
                    _lhs = powsimp(ad/bd)
                    if _lhs != rat:
                        lhs = _lhs
                        rhs = -bi/ai
            elif ai == -bi:
                # c*f(u) - c*f(v) == 0 -> u - v == 0 for a one-argument f
                if isinstance(ad, Function) and ad.func == bd.func:
                    if len(ad.args) == len(bd.args) == 1:
                        lhs = ad.args[0] - bd.args[0]
                    elif len(ad.args) == len(bd.args):
                        # should be able to solve
                        # f(x, y) - f(2 - x, 0) == 0 -> x == 1
                        raise NotImplementedError(
                            'equal function with more than 1 argument')
                    else:
                        raise ValueError(
                            'function with different numbers of args')

        elif lhs.is_Mul and any(_ispow(a) for a in lhs.args):
            lhs = powsimp(powdenest(lhs))

        if lhs.is_Function:
            if hasattr(lhs, 'inverse') and lhs.inverse() is not None and len(lhs.args) == 1:
                #                    -1
                # f(x) = g  ->  x = f  (g)
                #
                # /!\ inverse should not be defined if there are multiple values
                # for the function -- these are handled in _tsolve
                #
                rhs = lhs.inverse()(rhs)
                lhs = lhs.args[0]
            elif isinstance(lhs, atan2):
                # rewrite atan2(y, x) in terms of atan
                y, x = lhs.args
                lhs = 2*atan(y/(sqrt(x**2 + y**2) + x))
            elif lhs.func == rhs.func:
                # f(u) == f(v) -> u == v for a one-argument f
                if len(lhs.args) == len(rhs.args) == 1:
                    lhs = lhs.args[0]
                    rhs = rhs.args[0]
                elif len(lhs.args) == len(rhs.args):
                    # should be able to solve
                    # f(x, y) == f(2, 3) -> x == 2
                    # f(x, x + y) == f(2, 3) -> x == 2
                    raise NotImplementedError(
                        'equal function with more than 1 argument')
                else:
                    raise ValueError(
                        'function with different numbers of args')

        # a negative integer power with a nonzero rhs: invert both sides
        if rhs and lhs.is_Pow and lhs.exp.is_Integer and lhs.exp < 0:
            lhs = 1/lhs
            rhs = 1/rhs

        # base**a = b -> base = b**(1/a) if
        #    a is an Integer and dointpow=True (this gives real branch of root)
        #    a is not an Integer and the equation is multivariate and the
        #      base has more than 1 symbol in it
        # The rationale for this is that right now the multi-system solvers
        # doesn't try to resolve generators to see, for example, if the whole
        # system is written in terms of sqrt(x + y) so it will just fail, so we
        # do that step here.
        if lhs.is_Pow and (
            lhs.exp.is_Integer and dointpow or not lhs.exp.is_Integer and
                len(symbols) > 1 and len(lhs.base.free_symbols & set(symbols)) > 1):
            rhs = rhs**(1/lhs.exp)
            lhs = lhs.base

        # stop when a complete pass made no further progress
        if lhs == was:
            break
    return rhs, lhs
- def unrad(eq, *syms, **flags):
- """
- Remove radicals with symbolic arguments and return (eq, cov),
- None, or raise an error.
- Explanation
- ===========
- None is returned if there are no radicals to remove.
- NotImplementedError is raised if there are radicals and they cannot be
- removed or if the relationship between the original symbols and the
- change of variable needed to rewrite the system as a polynomial cannot
- be solved.
- Otherwise the tuple, ``(eq, cov)``, is returned where:
- *eq*, ``cov``
- *eq* is an equation without radicals (in the symbol(s) of
- interest) whose solutions are a superset of the solutions to the
- original expression. *eq* might be rewritten in terms of a new
- variable; the relationship to the original variables is given by
- ``cov`` which is a list containing ``v`` and ``v**p - b`` where
- ``p`` is the power needed to clear the radical and ``b`` is the
- radical now expressed as a polynomial in the symbols of interest.
- For example, for sqrt(2 - x) the tuple would be
- ``(c, c**2 - 2 + x)``. The solutions of *eq* will contain
- solutions to the original equation (if there are any).
- *syms*
- An iterable of symbols which, if provided, will limit the focus of
- radical removal: only radicals with one or more of the symbols of
- interest will be cleared. All free symbols are used if *syms* is not
- set.
- *flags* are used internally for communication during recursive calls.
- Two options are also recognized:
- ``take``, when defined, is interpreted as a single-argument function
- that returns True if a given Pow should be handled.
- Radicals can be removed from an expression if:
- * All bases of the radicals are the same; a change of variables is
- done in this case.
- * If all radicals appear in one term of the expression.
- * There are only four terms with sqrt() factors or there are less than
- four terms having sqrt() factors.
- * There are only two terms with radicals.
- Examples
- ========
- >>> from sympy.solvers.solvers import unrad
- >>> from sympy.abc import x
- >>> from sympy import sqrt, Rational, root
- >>> unrad(sqrt(x)*x**Rational(1, 3) + 2)
- (x**5 - 64, [])
- >>> unrad(sqrt(x) + root(x + 1, 3))
- (-x**3 + x**2 + 2*x + 1, [])
- >>> eq = sqrt(x) + root(x, 3) - 2
- >>> unrad(eq)
- (_p**3 + _p**2 - 2, [_p, _p**6 - x])
- """
- uflags = dict(check=False, simplify=False)
- def _cov(p, e):
- if cov:
- # XXX - uncovered
- oldp, olde = cov
- if Poly(e, p).degree(p) in (1, 2):
- cov[:] = [p, olde.subs(oldp, _solve(e, p, **uflags)[0])]
- else:
- raise NotImplementedError
- else:
- cov[:] = [p, e]
- def _canonical(eq, cov):
- if cov:
- # change symbol to vanilla so no solutions are eliminated
- p, e = cov
- rep = {p: Dummy(p.name)}
- eq = eq.xreplace(rep)
- cov = [p.xreplace(rep), e.xreplace(rep)]
- # remove constants and powers of factors since these don't change
- # the location of the root; XXX should factor or factor_terms be used?
- eq = factor_terms(_mexpand(eq.as_numer_denom()[0], recursive=True), clear=True)
- if eq.is_Mul:
- args = []
- for f in eq.args:
- if f.is_number:
- continue
- if f.is_Pow:
- args.append(f.base)
- else:
- args.append(f)
- eq = Mul(*args) # leave as Mul for more efficient solving
- # make the sign canonical
- margs = list(Mul.make_args(eq))
- changed = False
- for i, m in enumerate(margs):
- if m.could_extract_minus_sign():
- margs[i] = -m
- changed = True
- if changed:
- eq = Mul(*margs, evaluate=False)
- return eq, cov
- def _Q(pow):
- # return leading Rational of denominator of Pow's exponent
- c = pow.as_base_exp()[1].as_coeff_Mul()[0]
- if not c.is_Rational:
- return S.One
- return c.q
- # define the _take method that will determine whether a term is of interest
- def _take(d):
- # return True if coefficient of any factor's exponent's den is not 1
- for pow in Mul.make_args(d):
- if not pow.is_Pow:
- continue
- if _Q(pow) == 1:
- continue
- if pow.free_symbols & syms:
- return True
- return False
- _take = flags.setdefault('_take', _take)
- if isinstance(eq, Eq):
- eq = eq.lhs - eq.rhs # XXX legacy Eq as Eqn support
- elif not isinstance(eq, Expr):
- return
- cov, nwas, rpt = [flags.setdefault(k, v) for k, v in
- sorted(dict(cov=[], n=None, rpt=0).items())]
- # preconditioning
- eq = powdenest(factor_terms(eq, radical=True, clear=True))
- eq = eq.as_numer_denom()[0]
- eq = _mexpand(eq, recursive=True)
- if eq.is_number:
- return
- # see if there are radicals in symbols of interest
- syms = set(syms) or eq.free_symbols # _take uses this
- poly = eq.as_poly()
- gens = [g for g in poly.gens if _take(g)]
- if not gens:
- return
- # recast poly in terms of eigen-gens
- poly = eq.as_poly(*gens)
- # - an exponent has a symbol of interest (don't handle)
- if any(g.exp.has(*syms) for g in gens):
- return
def _rads_bases_lcm(poly):
    """Return ``(rads, bases, lcm)`` for the generators of ``poly``.

    ``rads`` is the set of generators for which ``_Q`` is not 1,
    ``bases`` is the set of their ``base`` attributes and ``lcm`` is
    the ``ilcm`` of those ``_Q`` values (1 when there are none).

    If all the bases are the same or all the radicals are in one term,
    ``lcm`` will be the lcm of the denominators of the exponents of
    the radicals.
    """
    radicals, radical_bases = set(), set()
    common = 1
    for gen in poly.gens:
        den = _Q(gen)
        if den == 1:
            # not a radical generator
            continue
        radicals.add(gen)
        radical_bases.add(gen.base)
        common = ilcm(common, den)
    return radicals, radical_bases, common
- rads, bases, lcm = _rads_bases_lcm(poly)
- covsym = Dummy('p', nonnegative=True)
- # only keep in syms symbols that actually appear in radicals;
- # and update gens
- newsyms = set()
- for r in rads:
- newsyms.update(syms & r.free_symbols)
- if newsyms != syms:
- syms = newsyms
- # get terms together that have common generators
- drad = dict(list(zip(rads, list(range(len(rads))))))
- rterms = {(): []}
- args = Add.make_args(poly.as_expr())
- for t in args:
- if _take(t):
- common = set(t.as_poly().gens).intersection(rads)
- key = tuple(sorted([drad[i] for i in common]))
- else:
- key = ()
- rterms.setdefault(key, []).append(t)
- others = Add(*rterms.pop(()))
- rterms = [Add(*rterms[k]) for k in rterms.keys()]
- # the output will depend on the order terms are processed, so
- # make it canonical quickly
- rterms = list(reversed(list(ordered(rterms))))
- ok = False # we don't have a solution yet
- depth = sqrt_depth(eq)
- if len(rterms) == 1 and not (rterms[0].is_Add and lcm > 2):
- eq = rterms[0]**lcm - ((-others)**lcm)
- ok = True
- else:
- if len(rterms) == 1 and rterms[0].is_Add:
- rterms = list(rterms[0].args)
- if len(bases) == 1:
- b = bases.pop()
- if len(syms) > 1:
- x = b.free_symbols
- else:
- x = syms
- x = list(ordered(x))[0]
- try:
- inv = _solve(covsym**lcm - b, x, **uflags)
- if not inv:
- raise NotImplementedError
- eq = poly.as_expr().subs(b, covsym**lcm).subs(x, inv[0])
- _cov(covsym, covsym**lcm - b)
- return _canonical(eq, cov)
- except NotImplementedError:
- pass
- if len(rterms) == 2:
- if not others:
- eq = rterms[0]**lcm - (-rterms[1])**lcm
- ok = True
- elif not log(lcm, 2).is_Integer:
- # the lcm-is-power-of-two case is handled below
- r0, r1 = rterms
- if flags.get('_reverse', False):
- r1, r0 = r0, r1
- i0 = _rads0, _bases0, lcm0 = _rads_bases_lcm(r0.as_poly())
- i1 = _rads1, _bases1, lcm1 = _rads_bases_lcm(r1.as_poly())
- for reverse in range(2):
- if reverse:
- i0, i1 = i1, i0
- r0, r1 = r1, r0
- _rads1, _, lcm1 = i1
- _rads1 = Mul(*_rads1)
- t1 = _rads1**lcm1
- c = covsym**lcm1 - t1
- for x in syms:
- try:
- sol = _solve(c, x, **uflags)
- if not sol:
- raise NotImplementedError
- neweq = r0.subs(x, sol[0]) + covsym*r1/_rads1 + \
- others
- tmp = unrad(neweq, covsym)
- if tmp:
- eq, newcov = tmp
- if newcov:
- newp, newc = newcov
- _cov(newp, c.subs(covsym,
- _solve(newc, covsym, **uflags)[0]))
- else:
- _cov(covsym, c)
- else:
- eq = neweq
- _cov(covsym, c)
- ok = True
- break
- except NotImplementedError:
- if reverse:
- raise NotImplementedError(
- 'no successful change of variable found')
- else:
- pass
- if ok:
- break
- elif len(rterms) == 3:
- # two cube roots and another with order less than 5
- # (so an analytical solution can be found) or a base
- # that matches one of the cube root bases
- info = [_rads_bases_lcm(i.as_poly()) for i in rterms]
- RAD = 0
- BASES = 1
- LCM = 2
- if info[0][LCM] != 3:
- info.append(info.pop(0))
- rterms.append(rterms.pop(0))
- elif info[1][LCM] != 3:
- info.append(info.pop(1))
- rterms.append(rterms.pop(1))
- if info[0][LCM] == info[1][LCM] == 3:
- if info[1][BASES] != info[2][BASES]:
- info[0], info[1] = info[1], info[0]
- rterms[0], rterms[1] = rterms[1], rterms[0]
- if info[1][BASES] == info[2][BASES]:
- eq = rterms[0]**3 + (rterms[1] + rterms[2] + others)**3
- ok = True
- elif info[2][LCM] < 5:
- # a*root(A, 3) + b*root(B, 3) + others = c
- a, b, c, d, A, B = [Dummy(i) for i in 'abcdAB']
- # zz represents the unraded expression into which the
- # specifics for this case are substituted
- zz = (c - d)*(A**3*a**9 + 3*A**2*B*a**6*b**3 -
- 3*A**2*a**6*c**3 + 9*A**2*a**6*c**2*d - 9*A**2*a**6*c*d**2 +
- 3*A**2*a**6*d**3 + 3*A*B**2*a**3*b**6 + 21*A*B*a**3*b**3*c**3 -
- 63*A*B*a**3*b**3*c**2*d + 63*A*B*a**3*b**3*c*d**2 -
- 21*A*B*a**3*b**3*d**3 + 3*A*a**3*c**6 - 18*A*a**3*c**5*d +
- 45*A*a**3*c**4*d**2 - 60*A*a**3*c**3*d**3 + 45*A*a**3*c**2*d**4 -
- 18*A*a**3*c*d**5 + 3*A*a**3*d**6 + B**3*b**9 - 3*B**2*b**6*c**3 +
- 9*B**2*b**6*c**2*d - 9*B**2*b**6*c*d**2 + 3*B**2*b**6*d**3 +
- 3*B*b**3*c**6 - 18*B*b**3*c**5*d + 45*B*b**3*c**4*d**2 -
- 60*B*b**3*c**3*d**3 + 45*B*b**3*c**2*d**4 - 18*B*b**3*c*d**5 +
- 3*B*b**3*d**6 - c**9 + 9*c**8*d - 36*c**7*d**2 + 84*c**6*d**3 -
- 126*c**5*d**4 + 126*c**4*d**5 - 84*c**3*d**6 + 36*c**2*d**7 -
- 9*c*d**8 + d**9)
def _t(i):
    """Return ``(coefficient, base)`` for the ``i``-th radical term.

    The coefficient is ``rterms[i]`` divided by the product of the
    entries in ``info[i][RAD]``, passed through ``cancel``; the base is
    the product of the entries in ``info[i][BASES]``.
    """
    radical = Mul(*info[i][RAD])
    coefficient = cancel(rterms[i]/radical)
    base = Mul(*info[i][BASES])
    return coefficient, base
- aa, AA = _t(0)
- bb, BB = _t(1)
- cc = -rterms[2]
- dd = others
- eq = zz.xreplace(dict(zip(
- (a, A, b, B, c, d),
- (aa, AA, bb, BB, cc, dd))))
- ok = True
- # handle power-of-2 cases
- if not ok:
- if log(lcm, 2).is_Integer and (not others and
- len(rterms) == 4 or len(rterms) < 4):
def _norm2(a, b):
    """Return ``a**2 + b**2 + 2*a*b``, the written-out square of ``a + b``."""
    squares = a**2 + b**2
    cross_term = 2*a*b
    return squares + cross_term
- if len(rterms) == 4:
- # (r0+r1)**2 - (r2+r3)**2
- r0, r1, r2, r3 = rterms
- eq = _norm2(r0, r1) - _norm2(r2, r3)
- ok = True
- elif len(rterms) == 3:
- # (r1+r2)**2 - (r0+others)**2
- r0, r1, r2 = rterms
- eq = _norm2(r1, r2) - _norm2(r0, others)
- ok = True
- elif len(rterms) == 2:
- # r0**2 - (r1+others)**2
- r0, r1 = rterms
- eq = r0**2 - _norm2(r1, others)
- ok = True
- new_depth = sqrt_depth(eq) if ok else depth
- rpt += 1 # XXX how many repeats with others unchanging is enough?
- if not ok or (
- nwas is not None and len(rterms) == nwas and
- new_depth is not None and new_depth == depth and
- rpt > 3):
- raise NotImplementedError('Cannot remove all radicals')
- flags.update(dict(cov=cov, n=len(rterms), rpt=rpt))
- neq = unrad(eq, *syms, **flags)
- if neq:
- eq, cov = neq
- eq, cov = _canonical(eq, cov)
- return eq, cov
- # Delayed imports
- from sympy.solvers.bivariate import (
- bivariate_type, _solve_lambert, _filtered_gens)
|