123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218 |
- """Implementation of RootOf class and related tools. """
- from sympy.core.basic import Basic
- from sympy.core import (S, Expr, Integer, Float, I, oo, Add, Lambda,
- symbols, sympify, Rational, Dummy)
- from sympy.core.cache import cacheit
- from sympy.core.relational import is_le
- from sympy.core.sorting import ordered
- from sympy.polys.domains import QQ
- from sympy.polys.polyerrors import (
- MultivariatePolynomialError,
- GeneratorsNeeded,
- PolynomialError,
- DomainError)
- from sympy.polys.polyfuncs import symmetrize, viete
- from sympy.polys.polyroots import (
- roots_linear, roots_quadratic, roots_binomial,
- preprocess_roots, roots)
- from sympy.polys.polytools import Poly, PurePoly, factor
- from sympy.polys.rationaltools import together
- from sympy.polys.rootisolation import (
- dup_isolate_complex_roots_sqf,
- dup_isolate_real_roots_sqf)
- from sympy.utilities import lambdify, public, sift, numbered_symbols
- from mpmath import mpf, mpc, findroot, workprec
- from mpmath.libmp.libmpf import dps_to_prec, prec_to_dps
- from sympy.multipledispatch import dispatch
- from itertools import chain
- __all__ = ['CRootOf']
- class _pure_key_dict:
- """A minimal dictionary that makes sure that the key is a
- univariate PurePoly instance.
- Examples
- ========
- Only the following actions are guaranteed:
- >>> from sympy.polys.rootoftools import _pure_key_dict
- >>> from sympy import PurePoly
- >>> from sympy.abc import x, y
- 1) creation
- >>> P = _pure_key_dict()
- 2) assignment for a PurePoly or univariate polynomial
- >>> P[x] = 1
- >>> P[PurePoly(x - y, x)] = 2
- 3) retrieval based on PurePoly key comparison (use this
- instead of the get method)
- >>> P[y]
- 1
- 4) KeyError when trying to retrieve a nonexisting key
- >>> P[y + 1]
- Traceback (most recent call last):
- ...
- KeyError: PurePoly(y + 1, y, domain='ZZ')
- 5) ability to query with ``in``
- >>> x + 1 in P
- False
- NOTE: this is a *not* a dictionary. It is a very basic object
- for internal use that makes sure to always address its cache
- via PurePoly instances. It does not, for example, implement
- ``get`` or ``setdefault``.
- """
- def __init__(self):
- self._dict = {}
- def __getitem__(self, k):
- if not isinstance(k, PurePoly):
- if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
- raise KeyError
- k = PurePoly(k, expand=False)
- return self._dict[k]
- def __setitem__(self, k, v):
- if not isinstance(k, PurePoly):
- if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
- raise ValueError('expecting univariate expression')
- k = PurePoly(k, expand=False)
- self._dict[k] = v
- def __contains__(self, k):
- try:
- self[k]
- return True
- except KeyError:
- return False
- _reals_cache = _pure_key_dict()
- _complexes_cache = _pure_key_dict()
- def _pure_factors(poly):
- _, factors = poly.factor_list()
- return [(PurePoly(f, expand=False), m) for f, m in factors]
- def _imag_count_of_factor(f):
- """Return the number of imaginary roots for irreducible
- univariate polynomial ``f``.
- """
- terms = [(i, j) for (i,), j in f.terms()]
- if any(i % 2 for i, j in terms):
- return 0
- # update signs
- even = [(i, I**i*j) for i, j in terms]
- even = Poly.from_dict(dict(even), Dummy('x'))
- return int(even.count_roots(-oo, oo))
- @public
- def rootof(f, x, index=None, radicals=True, expand=True):
- """An indexed root of a univariate polynomial.
- Returns either a :obj:`ComplexRootOf` object or an explicit
- expression involving radicals.
- Parameters
- ==========
- f : Expr
- Univariate polynomial.
- x : Symbol, optional
- Generator for ``f``.
- index : int or Integer
- radicals : bool
- Return a radical expression if possible.
- expand : bool
- Expand ``f``.
- """
- return CRootOf(f, x, index=index, radicals=radicals, expand=expand)
- @public
- class RootOf(Expr):
- """Represents a root of a univariate polynomial.
- Base class for roots of different kinds of polynomials.
- Only complex roots are currently supported.
- """
- __slots__ = ('poly',)
- def __new__(cls, f, x, index=None, radicals=True, expand=True):
- """Construct a new ``CRootOf`` object for ``k``-th root of ``f``."""
- return rootof(f, x, index=index, radicals=radicals, expand=expand)
- @public
- class ComplexRootOf(RootOf):
- """Represents an indexed complex root of a polynomial.
- Roots of a univariate polynomial separated into disjoint
- real or complex intervals and indexed in a fixed order:
- * real roots come first and are sorted in increasing order;
- * complex roots come next and are sorted primarily by increasing
- real part, secondarily by increasing imaginary part.
- Currently only rational coefficients are allowed.
- Can be imported as ``CRootOf``. To avoid confusion, the
- generator must be a Symbol.
- Examples
- ========
- >>> from sympy import CRootOf, rootof
- >>> from sympy.abc import x
- CRootOf is a way to reference a particular root of a
- polynomial. If there is a rational root, it will be returned:
- >>> CRootOf.clear_cache() # for doctest reproducibility
- >>> CRootOf(x**2 - 4, 0)
- -2
- Whether roots involving radicals are returned or not
- depends on whether the ``radicals`` flag is true (which is
- set to True with rootof):
- >>> CRootOf(x**2 - 3, 0)
- CRootOf(x**2 - 3, 0)
- >>> CRootOf(x**2 - 3, 0, radicals=True)
- -sqrt(3)
- >>> rootof(x**2 - 3, 0)
- -sqrt(3)
- The following cannot be expressed in terms of radicals:
- >>> r = rootof(4*x**5 + 16*x**3 + 12*x**2 + 7, 0); r
- CRootOf(4*x**5 + 16*x**3 + 12*x**2 + 7, 0)
- The root bounds can be seen, however, and they are used by the
- evaluation methods to get numerical approximations for the root.
- >>> interval = r._get_interval(); interval
- (-1, 0)
- >>> r.evalf(2)
- -0.98
- The evalf method refines the width of the root bounds until it
- guarantees that any decimal approximation within those bounds
- will satisfy the desired precision. It then stores the refined
- interval so subsequent requests at or below the requested
- precision will not have to recompute the root bounds and will
- return very quickly.
- Before evaluation above, the interval was
- >>> interval
- (-1, 0)
- After evaluation it is now
- >>> r._get_interval() # doctest: +SKIP
- (-165/169, -206/211)
- To reset all intervals for a given polynomial, the :meth:`_reset` method
- can be called from any CRootOf instance of the polynomial:
- >>> r._reset()
- >>> r._get_interval()
- (-1, 0)
- The :meth:`eval_approx` method will also find the root to a given
- precision but the interval is not modified unless the search
- for the root fails to converge within the root bounds. And
- the secant method is used to find the root. (The ``evalf``
- method uses bisection and will always update the interval.)
- >>> r.eval_approx(2)
- -0.98
- The interval needed to be slightly updated to find that root:
- >>> r._get_interval()
- (-1, -1/2)
- The ``evalf_rational`` will compute a rational approximation
- of the root to the desired accuracy or precision.
- >>> r.eval_rational(n=2)
- -69629/71318
- >>> t = CRootOf(x**3 + 10*x + 1, 1)
- >>> t.eval_rational(1e-1)
- 15/256 - 805*I/256
- >>> t.eval_rational(1e-1, 1e-4)
- 3275/65536 - 414645*I/131072
- >>> t.eval_rational(1e-4, 1e-4)
- 6545/131072 - 414645*I/131072
- >>> t.eval_rational(n=2)
- 104755/2097152 - 6634255*I/2097152
- Notes
- =====
- Although a PurePoly can be constructed from a non-symbol generator
- RootOf instances of non-symbols are disallowed to avoid confusion
- over what root is being represented.
- >>> from sympy import exp, PurePoly
- >>> PurePoly(x) == PurePoly(exp(x))
- True
- >>> CRootOf(x - 1, 0)
- 1
- >>> CRootOf(exp(x) - 1, 0) # would correspond to x == 0
- Traceback (most recent call last):
- ...
- sympy.polys.polyerrors.PolynomialError: generator must be a Symbol
- See Also
- ========
- eval_approx
- eval_rational
- """
- __slots__ = ('index',)
- is_complex = True
- is_number = True
- is_finite = True
- def __new__(cls, f, x, index=None, radicals=False, expand=True):
- """ Construct an indexed complex root of a polynomial.
- See ``rootof`` for the parameters.
- The default value of ``radicals`` is ``False`` to satisfy
- ``eval(srepr(expr) == expr``.
- """
- x = sympify(x)
- if index is None and x.is_Integer:
- x, index = None, x
- else:
- index = sympify(index)
- if index is not None and index.is_Integer:
- index = int(index)
- else:
- raise ValueError("expected an integer root index, got %s" % index)
- poly = PurePoly(f, x, greedy=False, expand=expand)
- if not poly.is_univariate:
- raise PolynomialError("only univariate polynomials are allowed")
- if not poly.gen.is_Symbol:
- # PurePoly(sin(x) + 1) == PurePoly(x + 1) but the roots of
- # x for each are not the same: issue 8617
- raise PolynomialError("generator must be a Symbol")
- degree = poly.degree()
- if degree <= 0:
- raise PolynomialError("Cannot construct CRootOf object for %s" % f)
- if index < -degree or index >= degree:
- raise IndexError("root index out of [%d, %d] range, got %d" %
- (-degree, degree - 1, index))
- elif index < 0:
- index += degree
- dom = poly.get_domain()
- if not dom.is_Exact:
- poly = poly.to_exact()
- roots = cls._roots_trivial(poly, radicals)
- if roots is not None:
- return roots[index]
- coeff, poly = preprocess_roots(poly)
- dom = poly.get_domain()
- if not dom.is_ZZ:
- raise NotImplementedError("CRootOf is not supported over %s" % dom)
- root = cls._indexed_root(poly, index)
- return coeff * cls._postprocess_root(root, radicals)
- @classmethod
- def _new(cls, poly, index):
- """Construct new ``CRootOf`` object from raw data. """
- obj = Expr.__new__(cls)
- obj.poly = PurePoly(poly)
- obj.index = index
- try:
- _reals_cache[obj.poly] = _reals_cache[poly]
- _complexes_cache[obj.poly] = _complexes_cache[poly]
- except KeyError:
- pass
- return obj
- def _hashable_content(self):
- return (self.poly, self.index)
- @property
- def expr(self):
- return self.poly.as_expr()
- @property
- def args(self):
- return (self.expr, Integer(self.index))
- @property
- def free_symbols(self):
- # CRootOf currently only works with univariate expressions
- # whose poly attribute should be a PurePoly with no free
- # symbols
- return set()
- def _eval_is_real(self):
- """Return ``True`` if the root is real. """
- return self.index < len(_reals_cache[self.poly])
- def _eval_is_imaginary(self):
- """Return ``True`` if the root is imaginary. """
- if self.index >= len(_reals_cache[self.poly]):
- ivl = self._get_interval()
- return ivl.ax*ivl.bx <= 0 # all others are on one side or the other
- return False # XXX is this necessary?
- @classmethod
- def real_roots(cls, poly, radicals=True):
- """Get real roots of a polynomial. """
- return cls._get_roots("_real_roots", poly, radicals)
- @classmethod
- def all_roots(cls, poly, radicals=True):
- """Get real and complex roots of a polynomial. """
- return cls._get_roots("_all_roots", poly, radicals)
- @classmethod
- def _get_reals_sqf(cls, currentfactor, use_cache=True):
- """Get real root isolating intervals for a square-free factor."""
- if use_cache and currentfactor in _reals_cache:
- real_part = _reals_cache[currentfactor]
- else:
- _reals_cache[currentfactor] = real_part = \
- dup_isolate_real_roots_sqf(
- currentfactor.rep.rep, currentfactor.rep.dom, blackbox=True)
- return real_part
- @classmethod
- def _get_complexes_sqf(cls, currentfactor, use_cache=True):
- """Get complex root isolating intervals for a square-free factor."""
- if use_cache and currentfactor in _complexes_cache:
- complex_part = _complexes_cache[currentfactor]
- else:
- _complexes_cache[currentfactor] = complex_part = \
- dup_isolate_complex_roots_sqf(
- currentfactor.rep.rep, currentfactor.rep.dom, blackbox=True)
- return complex_part
- @classmethod
- def _get_reals(cls, factors, use_cache=True):
- """Compute real root isolating intervals for a list of factors. """
- reals = []
- for currentfactor, k in factors:
- try:
- if not use_cache:
- raise KeyError
- r = _reals_cache[currentfactor]
- reals.extend([(i, currentfactor, k) for i in r])
- except KeyError:
- real_part = cls._get_reals_sqf(currentfactor, use_cache)
- new = [(root, currentfactor, k) for root in real_part]
- reals.extend(new)
- reals = cls._reals_sorted(reals)
- return reals
- @classmethod
- def _get_complexes(cls, factors, use_cache=True):
- """Compute complex root isolating intervals for a list of factors. """
- complexes = []
- for currentfactor, k in ordered(factors):
- try:
- if not use_cache:
- raise KeyError
- c = _complexes_cache[currentfactor]
- complexes.extend([(i, currentfactor, k) for i in c])
- except KeyError:
- complex_part = cls._get_complexes_sqf(currentfactor, use_cache)
- new = [(root, currentfactor, k) for root in complex_part]
- complexes.extend(new)
- complexes = cls._complexes_sorted(complexes)
- return complexes
- @classmethod
- def _reals_sorted(cls, reals):
- """Make real isolating intervals disjoint and sort roots. """
- cache = {}
- for i, (u, f, k) in enumerate(reals):
- for j, (v, g, m) in enumerate(reals[i + 1:]):
- u, v = u.refine_disjoint(v)
- reals[i + j + 1] = (v, g, m)
- reals[i] = (u, f, k)
- reals = sorted(reals, key=lambda r: r[0].a)
- for root, currentfactor, _ in reals:
- if currentfactor in cache:
- cache[currentfactor].append(root)
- else:
- cache[currentfactor] = [root]
- for currentfactor, root in cache.items():
- _reals_cache[currentfactor] = root
- return reals
- @classmethod
- def _refine_imaginary(cls, complexes):
- sifted = sift(complexes, lambda c: c[1])
- complexes = []
- for f in ordered(sifted):
- nimag = _imag_count_of_factor(f)
- if nimag == 0:
- # refine until xbounds are neg or pos
- for u, f, k in sifted[f]:
- while u.ax*u.bx <= 0:
- u = u._inner_refine()
- complexes.append((u, f, k))
- else:
- # refine until all but nimag xbounds are neg or pos
- potential_imag = list(range(len(sifted[f])))
- while True:
- assert len(potential_imag) > 1
- for i in list(potential_imag):
- u, f, k = sifted[f][i]
- if u.ax*u.bx > 0:
- potential_imag.remove(i)
- elif u.ax != u.bx:
- u = u._inner_refine()
- sifted[f][i] = u, f, k
- if len(potential_imag) == nimag:
- break
- complexes.extend(sifted[f])
- return complexes
- @classmethod
- def _refine_complexes(cls, complexes):
- """return complexes such that no bounding rectangles of non-conjugate
- roots would intersect. In addition, assure that neither ay nor by is
- 0 to guarantee that non-real roots are distinct from real roots in
- terms of the y-bounds.
- """
- # get the intervals pairwise-disjoint.
- # If rectangles were drawn around the coordinates of the bounding
- # rectangles, no rectangles would intersect after this procedure.
- for i, (u, f, k) in enumerate(complexes):
- for j, (v, g, m) in enumerate(complexes[i + 1:]):
- u, v = u.refine_disjoint(v)
- complexes[i + j + 1] = (v, g, m)
- complexes[i] = (u, f, k)
- # refine until the x-bounds are unambiguously positive or negative
- # for non-imaginary roots
- complexes = cls._refine_imaginary(complexes)
- # make sure that all y bounds are off the real axis
- # and on the same side of the axis
- for i, (u, f, k) in enumerate(complexes):
- while u.ay*u.by <= 0:
- u = u.refine()
- complexes[i] = u, f, k
- return complexes
- @classmethod
- def _complexes_sorted(cls, complexes):
- """Make complex isolating intervals disjoint and sort roots. """
- complexes = cls._refine_complexes(complexes)
- # XXX don't sort until you are sure that it is compatible
- # with the indexing method but assert that the desired state
- # is not broken
- C, F = 0, 1 # location of ComplexInterval and factor
- fs = {i[F] for i in complexes}
- for i in range(1, len(complexes)):
- if complexes[i][F] != complexes[i - 1][F]:
- # if this fails the factors of a root were not
- # contiguous because a discontinuity should only
- # happen once
- fs.remove(complexes[i - 1][F])
- for i in range(len(complexes)):
- # negative im part (conj=True) comes before
- # positive im part (conj=False)
- assert complexes[i][C].conj is (i % 2 == 0)
- # update cache
- cache = {}
- # -- collate
- for root, currentfactor, _ in complexes:
- cache.setdefault(currentfactor, []).append(root)
- # -- store
- for currentfactor, root in cache.items():
- _complexes_cache[currentfactor] = root
- return complexes
- @classmethod
- def _reals_index(cls, reals, index):
- """
- Map initial real root index to an index in a factor where
- the root belongs.
- """
- i = 0
- for j, (_, currentfactor, k) in enumerate(reals):
- if index < i + k:
- poly, index = currentfactor, 0
- for _, currentfactor, _ in reals[:j]:
- if currentfactor == poly:
- index += 1
- return poly, index
- else:
- i += k
- @classmethod
- def _complexes_index(cls, complexes, index):
- """
- Map initial complex root index to an index in a factor where
- the root belongs.
- """
- i = 0
- for j, (_, currentfactor, k) in enumerate(complexes):
- if index < i + k:
- poly, index = currentfactor, 0
- for _, currentfactor, _ in complexes[:j]:
- if currentfactor == poly:
- index += 1
- index += len(_reals_cache[poly])
- return poly, index
- else:
- i += k
- @classmethod
- def _count_roots(cls, roots):
- """Count the number of real or complex roots with multiplicities."""
- return sum([k for _, _, k in roots])
- @classmethod
- def _indexed_root(cls, poly, index):
- """Get a root of a composite polynomial by index. """
- factors = _pure_factors(poly)
- reals = cls._get_reals(factors)
- reals_count = cls._count_roots(reals)
- if index < reals_count:
- return cls._reals_index(reals, index)
- else:
- complexes = cls._get_complexes(factors)
- return cls._complexes_index(complexes, index - reals_count)
- @classmethod
- def _real_roots(cls, poly):
- """Get real roots of a composite polynomial. """
- factors = _pure_factors(poly)
- reals = cls._get_reals(factors)
- reals_count = cls._count_roots(reals)
- roots = []
- for index in range(0, reals_count):
- roots.append(cls._reals_index(reals, index))
- return roots
- def _reset(self):
- """
- Reset all intervals
- """
- self._all_roots(self.poly, use_cache=False)
- @classmethod
- def _all_roots(cls, poly, use_cache=True):
- """Get real and complex roots of a composite polynomial. """
- factors = _pure_factors(poly)
- reals = cls._get_reals(factors, use_cache=use_cache)
- reals_count = cls._count_roots(reals)
- roots = []
- for index in range(0, reals_count):
- roots.append(cls._reals_index(reals, index))
- complexes = cls._get_complexes(factors, use_cache=use_cache)
- complexes_count = cls._count_roots(complexes)
- for index in range(0, complexes_count):
- roots.append(cls._complexes_index(complexes, index))
- return roots
- @classmethod
- @cacheit
- def _roots_trivial(cls, poly, radicals):
- """Compute roots in linear, quadratic and binomial cases. """
- if poly.degree() == 1:
- return roots_linear(poly)
- if not radicals:
- return None
- if poly.degree() == 2:
- return roots_quadratic(poly)
- elif poly.length() == 2 and poly.TC():
- return roots_binomial(poly)
- else:
- return None
- @classmethod
- def _preprocess_roots(cls, poly):
- """Take heroic measures to make ``poly`` compatible with ``CRootOf``."""
- dom = poly.get_domain()
- if not dom.is_Exact:
- poly = poly.to_exact()
- coeff, poly = preprocess_roots(poly)
- dom = poly.get_domain()
- if not dom.is_ZZ:
- raise NotImplementedError(
- "sorted roots not supported over %s" % dom)
- return coeff, poly
- @classmethod
- def _postprocess_root(cls, root, radicals):
- """Return the root if it is trivial or a ``CRootOf`` object. """
- poly, index = root
- roots = cls._roots_trivial(poly, radicals)
- if roots is not None:
- return roots[index]
- else:
- return cls._new(poly, index)
- @classmethod
- def _get_roots(cls, method, poly, radicals):
- """Return postprocessed roots of specified kind. """
- if not poly.is_univariate:
- raise PolynomialError("only univariate polynomials are allowed")
- # get rid of gen and it's free symbol
- d = Dummy()
- poly = poly.subs(poly.gen, d)
- x = symbols('x')
- # see what others are left and select x or a numbered x
- # that doesn't clash
- free_names = {str(i) for i in poly.free_symbols}
- for x in chain((symbols('x'),), numbered_symbols('x')):
- if x.name not in free_names:
- poly = poly.xreplace({d: x})
- break
- coeff, poly = cls._preprocess_roots(poly)
- roots = []
- for root in getattr(cls, method)(poly):
- roots.append(coeff*cls._postprocess_root(root, radicals))
- return roots
- @classmethod
- def clear_cache(cls):
- """Reset cache for reals and complexes.
- The intervals used to approximate a root instance are updated
- as needed. When a request is made to see the intervals, the
- most current values are shown. `clear_cache` will reset all
- CRootOf instances back to their original state.
- See Also
- ========
- _reset
- """
- global _reals_cache, _complexes_cache
- _reals_cache = _pure_key_dict()
- _complexes_cache = _pure_key_dict()
- def _get_interval(self):
- """Internal function for retrieving isolation interval from cache. """
- if self.is_real:
- return _reals_cache[self.poly][self.index]
- else:
- reals_count = len(_reals_cache[self.poly])
- return _complexes_cache[self.poly][self.index - reals_count]
- def _set_interval(self, interval):
- """Internal function for updating isolation interval in cache. """
- if self.is_real:
- _reals_cache[self.poly][self.index] = interval
- else:
- reals_count = len(_reals_cache[self.poly])
- _complexes_cache[self.poly][self.index - reals_count] = interval
- def _eval_subs(self, old, new):
- # don't allow subs to change anything
- return self
- def _eval_conjugate(self):
- if self.is_real:
- return self
- expr, i = self.args
- return self.func(expr, i + (1 if self._get_interval().conj else -1))
- def eval_approx(self, n):
- """Evaluate this complex root to the given precision.
- This uses secant method and root bounds are used to both
- generate an initial guess and to check that the root
- returned is valid. If ever the method converges outside the
- root bounds, the bounds will be made smaller and updated.
- """
- prec = dps_to_prec(n)
- with workprec(prec):
- g = self.poly.gen
- if not g.is_Symbol:
- d = Dummy('x')
- if self.is_imaginary:
- d *= I
- func = lambdify(d, self.expr.subs(g, d))
- else:
- expr = self.expr
- if self.is_imaginary:
- expr = self.expr.subs(g, I*g)
- func = lambdify(g, expr)
- interval = self._get_interval()
- while True:
- if self.is_real:
- a = mpf(str(interval.a))
- b = mpf(str(interval.b))
- if a == b:
- root = a
- break
- x0 = mpf(str(interval.center))
- x1 = x0 + mpf(str(interval.dx))/4
- elif self.is_imaginary:
- a = mpf(str(interval.ay))
- b = mpf(str(interval.by))
- if a == b:
- root = mpc(mpf('0'), a)
- break
- x0 = mpf(str(interval.center[1]))
- x1 = x0 + mpf(str(interval.dy))/4
- else:
- ax = mpf(str(interval.ax))
- bx = mpf(str(interval.bx))
- ay = mpf(str(interval.ay))
- by = mpf(str(interval.by))
- if ax == bx and ay == by:
- root = mpc(ax, ay)
- break
- x0 = mpc(*map(str, interval.center))
- x1 = x0 + mpc(*map(str, (interval.dx, interval.dy)))/4
- try:
- # without a tolerance, this will return when (to within
- # the given precision) x_i == x_{i-1}
- root = findroot(func, (x0, x1))
- # If the (real or complex) root is not in the 'interval',
- # then keep refining the interval. This happens if findroot
- # accidentally finds a different root outside of this
- # interval because our initial estimate 'x0' was not close
- # enough. It is also possible that the secant method will
- # get trapped by a max/min in the interval; the root
- # verification by findroot will raise a ValueError in this
- # case and the interval will then be tightened -- and
- # eventually the root will be found.
- #
- # It is also possible that findroot will not have any
- # successful iterations to process (in which case it
- # will fail to initialize a variable that is tested
- # after the iterations and raise an UnboundLocalError).
- if self.is_real or self.is_imaginary:
- if not bool(root.imag) == self.is_real and (
- a <= root <= b):
- if self.is_imaginary:
- root = mpc(mpf('0'), root.real)
- break
- elif (ax <= root.real <= bx and ay <= root.imag <= by):
- break
- except (UnboundLocalError, ValueError):
- pass
- interval = interval.refine()
- # update the interval so we at least (for this precision or
- # less) don't have much work to do to recompute the root
- self._set_interval(interval)
- return (Float._new(root.real._mpf_, prec) +
- I*Float._new(root.imag._mpf_, prec))
- def _eval_evalf(self, prec, **kwargs):
- """Evaluate this complex root to the given precision."""
- # all kwargs are ignored
- return self.eval_rational(n=prec_to_dps(prec))._evalf(prec)
- def eval_rational(self, dx=None, dy=None, n=15):
- """
- Return a Rational approximation of ``self`` that has real
- and imaginary component approximations that are within ``dx``
- and ``dy`` of the true values, respectively. Alternatively,
- ``n`` digits of precision can be specified.
- The interval is refined with bisection and is sure to
- converge. The root bounds are updated when the refinement
- is complete so recalculation at the same or lesser precision
- will not have to repeat the refinement and should be much
- faster.
- The following example first obtains Rational approximation to
- 1e-8 accuracy for all roots of the 4-th order Legendre
- polynomial. Since the roots are all less than 1, this will
- ensure the decimal representation of the approximation will be
- correct (including rounding) to 6 digits:
- >>> from sympy import legendre_poly, Symbol
- >>> x = Symbol("x")
- >>> p = legendre_poly(4, x, polys=True)
- >>> r = p.real_roots()[-1]
- >>> r.eval_rational(10**-8).n(6)
- 0.861136
- It is not necessary to a two-step calculation, however: the
- decimal representation can be computed directly:
- >>> r.evalf(17)
- 0.86113631159405258
- """
- dy = dy or dx
- if dx:
- rtol = None
- dx = dx if isinstance(dx, Rational) else Rational(str(dx))
- dy = dy if isinstance(dy, Rational) else Rational(str(dy))
- else:
- # 5 binary (or 2 decimal) digits are needed to ensure that
- # a given digit is correctly rounded
- # prec_to_dps(dps_to_prec(n) + 5) - n <= 2 (tested for
- # n in range(1000000)
- rtol = S(10)**-(n + 2) # +2 for guard digits
- interval = self._get_interval()
- while True:
- if self.is_real:
- if rtol:
- dx = abs(interval.center*rtol)
- interval = interval.refine_size(dx=dx)
- c = interval.center
- real = Rational(c)
- imag = S.Zero
- if not rtol or interval.dx < abs(c*rtol):
- break
- elif self.is_imaginary:
- if rtol:
- dy = abs(interval.center[1]*rtol)
- dx = 1
- interval = interval.refine_size(dx=dx, dy=dy)
- c = interval.center[1]
- imag = Rational(c)
- real = S.Zero
- if not rtol or interval.dy < abs(c*rtol):
- break
- else:
- if rtol:
- dx = abs(interval.center[0]*rtol)
- dy = abs(interval.center[1]*rtol)
- interval = interval.refine_size(dx, dy)
- c = interval.center
- real, imag = map(Rational, c)
- if not rtol or (
- interval.dx < abs(c[0]*rtol) and
- interval.dy < abs(c[1]*rtol)):
- break
- # update the interval so we at least (for this precision or
- # less) don't have much work to do to recompute the root
- self._set_interval(interval)
- return real + I*imag
- CRootOf = ComplexRootOf
- @dispatch(ComplexRootOf, ComplexRootOf)
- def _eval_is_eq(lhs, rhs): # noqa:F811
- # if we use is_eq to check here, we get infinite recurion
- return lhs == rhs
- @dispatch(ComplexRootOf, Basic) # type:ignore
- def _eval_is_eq(lhs, rhs): # noqa:F811
- # CRootOf represents a Root, so if rhs is that root, it should set
- # the expression to zero *and* it should be in the interval of the
- # CRootOf instance. It must also be a number that agrees with the
- # is_real value of the CRootOf instance.
- if not rhs.is_number:
- return None
- if not rhs.is_finite:
- return False
- z = lhs.expr.subs(lhs.expr.free_symbols.pop(), rhs).is_zero
- if z is False: # all roots will make z True but we don't know
- # whether this is the right root if z is True
- return False
- o = rhs.is_real, rhs.is_imaginary
- s = lhs.is_real, lhs.is_imaginary
- assert None not in s # this is part of initial refinement
- if o != s and None not in o:
- return False
- re, im = rhs.as_real_imag()
- if lhs.is_real:
- if im:
- return False
- i = lhs._get_interval()
- a, b = [Rational(str(_)) for _ in (i.a, i.b)]
- return sympify(a <= rhs and rhs <= b)
- i = lhs._get_interval()
- r1, r2, i1, i2 = [Rational(str(j)) for j in (
- i.ax, i.bx, i.ay, i.by)]
- return is_le(r1, re) and is_le(re,r2) and is_le(i1,im) and is_le(im,i2)
- @public
- class RootSum(Expr):
- """Represents a sum of all roots of a univariate polynomial. """
- __slots__ = ('poly', 'fun', 'auto')
- def __new__(cls, expr, func=None, x=None, auto=True, quadratic=False):
- """Construct a new ``RootSum`` instance of roots of a polynomial."""
- coeff, poly = cls._transform(expr, x)
- if not poly.is_univariate:
- raise MultivariatePolynomialError(
- "only univariate polynomials are allowed")
- if func is None:
- func = Lambda(poly.gen, poly.gen)
- else:
- is_func = getattr(func, 'is_Function', False)
- if is_func and 1 in func.nargs:
- if not isinstance(func, Lambda):
- func = Lambda(poly.gen, func(poly.gen))
- else:
- raise ValueError(
- "expected a univariate function, got %s" % func)
- var, expr = func.variables[0], func.expr
- if coeff is not S.One:
- expr = expr.subs(var, coeff*var)
- deg = poly.degree()
- if not expr.has(var):
- return deg*expr
- if expr.is_Add:
- add_const, expr = expr.as_independent(var)
- else:
- add_const = S.Zero
- if expr.is_Mul:
- mul_const, expr = expr.as_independent(var)
- else:
- mul_const = S.One
- func = Lambda(var, expr)
- rational = cls._is_func_rational(poly, func)
- factors, terms = _pure_factors(poly), []
- for poly, k in factors:
- if poly.is_linear:
- term = func(roots_linear(poly)[0])
- elif quadratic and poly.is_quadratic:
- term = sum(map(func, roots_quadratic(poly)))
- else:
- if not rational or not auto:
- term = cls._new(poly, func, auto)
- else:
- term = cls._rational_case(poly, func)
- terms.append(k*term)
- return mul_const*Add(*terms) + deg*add_const
- @classmethod
- def _new(cls, poly, func, auto=True):
- """Construct new raw ``RootSum`` instance. """
- obj = Expr.__new__(cls)
- obj.poly = poly
- obj.fun = func
- obj.auto = auto
- return obj
- @classmethod
- def new(cls, poly, func, auto=True):
- """Construct new ``RootSum`` instance. """
- if not func.expr.has(*func.variables):
- return func.expr
- rational = cls._is_func_rational(poly, func)
- if not rational or not auto:
- return cls._new(poly, func, auto)
- else:
- return cls._rational_case(poly, func)
- @classmethod
- def _transform(cls, expr, x):
- """Transform an expression to a polynomial. """
- poly = PurePoly(expr, x, greedy=False)
- return preprocess_roots(poly)
- @classmethod
- def _is_func_rational(cls, poly, func):
- """Check if a lambda is a rational function. """
- var, expr = func.variables[0], func.expr
- return expr.is_rational_function(var)
- @classmethod
- def _rational_case(cls, poly, func):
- """Handle the rational function case. """
- roots = symbols('r:%d' % poly.degree())
- var, expr = func.variables[0], func.expr
- f = sum(expr.subs(var, r) for r in roots)
- p, q = together(f).as_numer_denom()
- domain = QQ[roots]
- p = p.expand()
- q = q.expand()
- try:
- p = Poly(p, domain=domain, expand=False)
- except GeneratorsNeeded:
- p, p_coeff = None, (p,)
- else:
- p_monom, p_coeff = zip(*p.terms())
- try:
- q = Poly(q, domain=domain, expand=False)
- except GeneratorsNeeded:
- q, q_coeff = None, (q,)
- else:
- q_monom, q_coeff = zip(*q.terms())
- coeffs, mapping = symmetrize(p_coeff + q_coeff, formal=True)
- formulas, values = viete(poly, roots), []
- for (sym, _), (_, val) in zip(mapping, formulas):
- values.append((sym, val))
- for i, (coeff, _) in enumerate(coeffs):
- coeffs[i] = coeff.subs(values)
- n = len(p_coeff)
- p_coeff = coeffs[:n]
- q_coeff = coeffs[n:]
- if p is not None:
- p = Poly(dict(zip(p_monom, p_coeff)), *p.gens).as_expr()
- else:
- (p,) = p_coeff
- if q is not None:
- q = Poly(dict(zip(q_monom, q_coeff)), *q.gens).as_expr()
- else:
- (q,) = q_coeff
- return factor(p/q)
- def _hashable_content(self):
- return (self.poly, self.fun)
- @property
- def expr(self):
- return self.poly.as_expr()
- @property
- def args(self):
- return (self.expr, self.fun, self.poly.gen)
- @property
- def free_symbols(self):
- return self.poly.free_symbols | self.fun.free_symbols
- @property
- def is_commutative(self):
- return True
- def doit(self, **hints):
- if not hints.get('roots', True):
- return self
- _roots = roots(self.poly, multiple=True)
- if len(_roots) < self.poly.degree():
- return self
- else:
- return Add(*[self.fun(r) for r in _roots])
- def _eval_evalf(self, prec):
- try:
- _roots = self.poly.nroots(n=prec_to_dps(prec))
- except (DomainError, PolynomialError):
- return self
- else:
- return Add(*[self.fun(r) for r in _roots])
- def _eval_derivative(self, x):
- var, expr = self.fun.args
- func = Lambda(var, expr.diff(x))
- return self.new(self.poly, func, self.auto)
|