rootoftools.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218
  1. """Implementation of RootOf class and related tools. """
  2. from sympy.core.basic import Basic
  3. from sympy.core import (S, Expr, Integer, Float, I, oo, Add, Lambda,
  4. symbols, sympify, Rational, Dummy)
  5. from sympy.core.cache import cacheit
  6. from sympy.core.relational import is_le
  7. from sympy.core.sorting import ordered
  8. from sympy.polys.domains import QQ
  9. from sympy.polys.polyerrors import (
  10. MultivariatePolynomialError,
  11. GeneratorsNeeded,
  12. PolynomialError,
  13. DomainError)
  14. from sympy.polys.polyfuncs import symmetrize, viete
  15. from sympy.polys.polyroots import (
  16. roots_linear, roots_quadratic, roots_binomial,
  17. preprocess_roots, roots)
  18. from sympy.polys.polytools import Poly, PurePoly, factor
  19. from sympy.polys.rationaltools import together
  20. from sympy.polys.rootisolation import (
  21. dup_isolate_complex_roots_sqf,
  22. dup_isolate_real_roots_sqf)
  23. from sympy.utilities import lambdify, public, sift, numbered_symbols
  24. from mpmath import mpf, mpc, findroot, workprec
  25. from mpmath.libmp.libmpf import dps_to_prec, prec_to_dps
  26. from sympy.multipledispatch import dispatch
  27. from itertools import chain
  28. __all__ = ['CRootOf']
  29. class _pure_key_dict:
  30. """A minimal dictionary that makes sure that the key is a
  31. univariate PurePoly instance.
  32. Examples
  33. ========
  34. Only the following actions are guaranteed:
  35. >>> from sympy.polys.rootoftools import _pure_key_dict
  36. >>> from sympy import PurePoly
  37. >>> from sympy.abc import x, y
  38. 1) creation
  39. >>> P = _pure_key_dict()
  40. 2) assignment for a PurePoly or univariate polynomial
  41. >>> P[x] = 1
  42. >>> P[PurePoly(x - y, x)] = 2
  43. 3) retrieval based on PurePoly key comparison (use this
  44. instead of the get method)
  45. >>> P[y]
  46. 1
  47. 4) KeyError when trying to retrieve a nonexisting key
  48. >>> P[y + 1]
  49. Traceback (most recent call last):
  50. ...
  51. KeyError: PurePoly(y + 1, y, domain='ZZ')
  52. 5) ability to query with ``in``
  53. >>> x + 1 in P
  54. False
  55. NOTE: this is a *not* a dictionary. It is a very basic object
  56. for internal use that makes sure to always address its cache
  57. via PurePoly instances. It does not, for example, implement
  58. ``get`` or ``setdefault``.
  59. """
  60. def __init__(self):
  61. self._dict = {}
  62. def __getitem__(self, k):
  63. if not isinstance(k, PurePoly):
  64. if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
  65. raise KeyError
  66. k = PurePoly(k, expand=False)
  67. return self._dict[k]
  68. def __setitem__(self, k, v):
  69. if not isinstance(k, PurePoly):
  70. if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
  71. raise ValueError('expecting univariate expression')
  72. k = PurePoly(k, expand=False)
  73. self._dict[k] = v
  74. def __contains__(self, k):
  75. try:
  76. self[k]
  77. return True
  78. except KeyError:
  79. return False
  80. _reals_cache = _pure_key_dict()
  81. _complexes_cache = _pure_key_dict()
  82. def _pure_factors(poly):
  83. _, factors = poly.factor_list()
  84. return [(PurePoly(f, expand=False), m) for f, m in factors]
  85. def _imag_count_of_factor(f):
  86. """Return the number of imaginary roots for irreducible
  87. univariate polynomial ``f``.
  88. """
  89. terms = [(i, j) for (i,), j in f.terms()]
  90. if any(i % 2 for i, j in terms):
  91. return 0
  92. # update signs
  93. even = [(i, I**i*j) for i, j in terms]
  94. even = Poly.from_dict(dict(even), Dummy('x'))
  95. return int(even.count_roots(-oo, oo))
  96. @public
  97. def rootof(f, x, index=None, radicals=True, expand=True):
  98. """An indexed root of a univariate polynomial.
  99. Returns either a :obj:`ComplexRootOf` object or an explicit
  100. expression involving radicals.
  101. Parameters
  102. ==========
  103. f : Expr
  104. Univariate polynomial.
  105. x : Symbol, optional
  106. Generator for ``f``.
  107. index : int or Integer
  108. radicals : bool
  109. Return a radical expression if possible.
  110. expand : bool
  111. Expand ``f``.
  112. """
  113. return CRootOf(f, x, index=index, radicals=radicals, expand=expand)
  114. @public
  115. class RootOf(Expr):
  116. """Represents a root of a univariate polynomial.
  117. Base class for roots of different kinds of polynomials.
  118. Only complex roots are currently supported.
  119. """
  120. __slots__ = ('poly',)
  121. def __new__(cls, f, x, index=None, radicals=True, expand=True):
  122. """Construct a new ``CRootOf`` object for ``k``-th root of ``f``."""
  123. return rootof(f, x, index=index, radicals=radicals, expand=expand)
  124. @public
  125. class ComplexRootOf(RootOf):
  126. """Represents an indexed complex root of a polynomial.
  127. Roots of a univariate polynomial separated into disjoint
  128. real or complex intervals and indexed in a fixed order:
  129. * real roots come first and are sorted in increasing order;
  130. * complex roots come next and are sorted primarily by increasing
  131. real part, secondarily by increasing imaginary part.
  132. Currently only rational coefficients are allowed.
  133. Can be imported as ``CRootOf``. To avoid confusion, the
  134. generator must be a Symbol.
  135. Examples
  136. ========
  137. >>> from sympy import CRootOf, rootof
  138. >>> from sympy.abc import x
  139. CRootOf is a way to reference a particular root of a
  140. polynomial. If there is a rational root, it will be returned:
  141. >>> CRootOf.clear_cache() # for doctest reproducibility
  142. >>> CRootOf(x**2 - 4, 0)
  143. -2
  144. Whether roots involving radicals are returned or not
  145. depends on whether the ``radicals`` flag is true (which is
  146. set to True with rootof):
  147. >>> CRootOf(x**2 - 3, 0)
  148. CRootOf(x**2 - 3, 0)
  149. >>> CRootOf(x**2 - 3, 0, radicals=True)
  150. -sqrt(3)
  151. >>> rootof(x**2 - 3, 0)
  152. -sqrt(3)
  153. The following cannot be expressed in terms of radicals:
  154. >>> r = rootof(4*x**5 + 16*x**3 + 12*x**2 + 7, 0); r
  155. CRootOf(4*x**5 + 16*x**3 + 12*x**2 + 7, 0)
  156. The root bounds can be seen, however, and they are used by the
  157. evaluation methods to get numerical approximations for the root.
  158. >>> interval = r._get_interval(); interval
  159. (-1, 0)
  160. >>> r.evalf(2)
  161. -0.98
  162. The evalf method refines the width of the root bounds until it
  163. guarantees that any decimal approximation within those bounds
  164. will satisfy the desired precision. It then stores the refined
  165. interval so subsequent requests at or below the requested
  166. precision will not have to recompute the root bounds and will
  167. return very quickly.
  168. Before evaluation above, the interval was
  169. >>> interval
  170. (-1, 0)
  171. After evaluation it is now
  172. >>> r._get_interval() # doctest: +SKIP
  173. (-165/169, -206/211)
  174. To reset all intervals for a given polynomial, the :meth:`_reset` method
  175. can be called from any CRootOf instance of the polynomial:
  176. >>> r._reset()
  177. >>> r._get_interval()
  178. (-1, 0)
  179. The :meth:`eval_approx` method will also find the root to a given
  180. precision but the interval is not modified unless the search
  181. for the root fails to converge within the root bounds. And
  182. the secant method is used to find the root. (The ``evalf``
  183. method uses bisection and will always update the interval.)
  184. >>> r.eval_approx(2)
  185. -0.98
  186. The interval needed to be slightly updated to find that root:
  187. >>> r._get_interval()
  188. (-1, -1/2)
  189. The ``evalf_rational`` will compute a rational approximation
  190. of the root to the desired accuracy or precision.
  191. >>> r.eval_rational(n=2)
  192. -69629/71318
  193. >>> t = CRootOf(x**3 + 10*x + 1, 1)
  194. >>> t.eval_rational(1e-1)
  195. 15/256 - 805*I/256
  196. >>> t.eval_rational(1e-1, 1e-4)
  197. 3275/65536 - 414645*I/131072
  198. >>> t.eval_rational(1e-4, 1e-4)
  199. 6545/131072 - 414645*I/131072
  200. >>> t.eval_rational(n=2)
  201. 104755/2097152 - 6634255*I/2097152
  202. Notes
  203. =====
  204. Although a PurePoly can be constructed from a non-symbol generator
  205. RootOf instances of non-symbols are disallowed to avoid confusion
  206. over what root is being represented.
  207. >>> from sympy import exp, PurePoly
  208. >>> PurePoly(x) == PurePoly(exp(x))
  209. True
  210. >>> CRootOf(x - 1, 0)
  211. 1
  212. >>> CRootOf(exp(x) - 1, 0) # would correspond to x == 0
  213. Traceback (most recent call last):
  214. ...
  215. sympy.polys.polyerrors.PolynomialError: generator must be a Symbol
  216. See Also
  217. ========
  218. eval_approx
  219. eval_rational
  220. """
  221. __slots__ = ('index',)
  222. is_complex = True
  223. is_number = True
  224. is_finite = True
  225. def __new__(cls, f, x, index=None, radicals=False, expand=True):
  226. """ Construct an indexed complex root of a polynomial.
  227. See ``rootof`` for the parameters.
  228. The default value of ``radicals`` is ``False`` to satisfy
  229. ``eval(srepr(expr) == expr``.
  230. """
  231. x = sympify(x)
  232. if index is None and x.is_Integer:
  233. x, index = None, x
  234. else:
  235. index = sympify(index)
  236. if index is not None and index.is_Integer:
  237. index = int(index)
  238. else:
  239. raise ValueError("expected an integer root index, got %s" % index)
  240. poly = PurePoly(f, x, greedy=False, expand=expand)
  241. if not poly.is_univariate:
  242. raise PolynomialError("only univariate polynomials are allowed")
  243. if not poly.gen.is_Symbol:
  244. # PurePoly(sin(x) + 1) == PurePoly(x + 1) but the roots of
  245. # x for each are not the same: issue 8617
  246. raise PolynomialError("generator must be a Symbol")
  247. degree = poly.degree()
  248. if degree <= 0:
  249. raise PolynomialError("Cannot construct CRootOf object for %s" % f)
  250. if index < -degree or index >= degree:
  251. raise IndexError("root index out of [%d, %d] range, got %d" %
  252. (-degree, degree - 1, index))
  253. elif index < 0:
  254. index += degree
  255. dom = poly.get_domain()
  256. if not dom.is_Exact:
  257. poly = poly.to_exact()
  258. roots = cls._roots_trivial(poly, radicals)
  259. if roots is not None:
  260. return roots[index]
  261. coeff, poly = preprocess_roots(poly)
  262. dom = poly.get_domain()
  263. if not dom.is_ZZ:
  264. raise NotImplementedError("CRootOf is not supported over %s" % dom)
  265. root = cls._indexed_root(poly, index)
  266. return coeff * cls._postprocess_root(root, radicals)
  267. @classmethod
  268. def _new(cls, poly, index):
  269. """Construct new ``CRootOf`` object from raw data. """
  270. obj = Expr.__new__(cls)
  271. obj.poly = PurePoly(poly)
  272. obj.index = index
  273. try:
  274. _reals_cache[obj.poly] = _reals_cache[poly]
  275. _complexes_cache[obj.poly] = _complexes_cache[poly]
  276. except KeyError:
  277. pass
  278. return obj
  279. def _hashable_content(self):
  280. return (self.poly, self.index)
  281. @property
  282. def expr(self):
  283. return self.poly.as_expr()
  284. @property
  285. def args(self):
  286. return (self.expr, Integer(self.index))
  287. @property
  288. def free_symbols(self):
  289. # CRootOf currently only works with univariate expressions
  290. # whose poly attribute should be a PurePoly with no free
  291. # symbols
  292. return set()
  293. def _eval_is_real(self):
  294. """Return ``True`` if the root is real. """
  295. return self.index < len(_reals_cache[self.poly])
  296. def _eval_is_imaginary(self):
  297. """Return ``True`` if the root is imaginary. """
  298. if self.index >= len(_reals_cache[self.poly]):
  299. ivl = self._get_interval()
  300. return ivl.ax*ivl.bx <= 0 # all others are on one side or the other
  301. return False # XXX is this necessary?
  302. @classmethod
  303. def real_roots(cls, poly, radicals=True):
  304. """Get real roots of a polynomial. """
  305. return cls._get_roots("_real_roots", poly, radicals)
  306. @classmethod
  307. def all_roots(cls, poly, radicals=True):
  308. """Get real and complex roots of a polynomial. """
  309. return cls._get_roots("_all_roots", poly, radicals)
  310. @classmethod
  311. def _get_reals_sqf(cls, currentfactor, use_cache=True):
  312. """Get real root isolating intervals for a square-free factor."""
  313. if use_cache and currentfactor in _reals_cache:
  314. real_part = _reals_cache[currentfactor]
  315. else:
  316. _reals_cache[currentfactor] = real_part = \
  317. dup_isolate_real_roots_sqf(
  318. currentfactor.rep.rep, currentfactor.rep.dom, blackbox=True)
  319. return real_part
  320. @classmethod
  321. def _get_complexes_sqf(cls, currentfactor, use_cache=True):
  322. """Get complex root isolating intervals for a square-free factor."""
  323. if use_cache and currentfactor in _complexes_cache:
  324. complex_part = _complexes_cache[currentfactor]
  325. else:
  326. _complexes_cache[currentfactor] = complex_part = \
  327. dup_isolate_complex_roots_sqf(
  328. currentfactor.rep.rep, currentfactor.rep.dom, blackbox=True)
  329. return complex_part
  330. @classmethod
  331. def _get_reals(cls, factors, use_cache=True):
  332. """Compute real root isolating intervals for a list of factors. """
  333. reals = []
  334. for currentfactor, k in factors:
  335. try:
  336. if not use_cache:
  337. raise KeyError
  338. r = _reals_cache[currentfactor]
  339. reals.extend([(i, currentfactor, k) for i in r])
  340. except KeyError:
  341. real_part = cls._get_reals_sqf(currentfactor, use_cache)
  342. new = [(root, currentfactor, k) for root in real_part]
  343. reals.extend(new)
  344. reals = cls._reals_sorted(reals)
  345. return reals
  346. @classmethod
  347. def _get_complexes(cls, factors, use_cache=True):
  348. """Compute complex root isolating intervals for a list of factors. """
  349. complexes = []
  350. for currentfactor, k in ordered(factors):
  351. try:
  352. if not use_cache:
  353. raise KeyError
  354. c = _complexes_cache[currentfactor]
  355. complexes.extend([(i, currentfactor, k) for i in c])
  356. except KeyError:
  357. complex_part = cls._get_complexes_sqf(currentfactor, use_cache)
  358. new = [(root, currentfactor, k) for root in complex_part]
  359. complexes.extend(new)
  360. complexes = cls._complexes_sorted(complexes)
  361. return complexes
  362. @classmethod
  363. def _reals_sorted(cls, reals):
  364. """Make real isolating intervals disjoint and sort roots. """
  365. cache = {}
  366. for i, (u, f, k) in enumerate(reals):
  367. for j, (v, g, m) in enumerate(reals[i + 1:]):
  368. u, v = u.refine_disjoint(v)
  369. reals[i + j + 1] = (v, g, m)
  370. reals[i] = (u, f, k)
  371. reals = sorted(reals, key=lambda r: r[0].a)
  372. for root, currentfactor, _ in reals:
  373. if currentfactor in cache:
  374. cache[currentfactor].append(root)
  375. else:
  376. cache[currentfactor] = [root]
  377. for currentfactor, root in cache.items():
  378. _reals_cache[currentfactor] = root
  379. return reals
  380. @classmethod
  381. def _refine_imaginary(cls, complexes):
  382. sifted = sift(complexes, lambda c: c[1])
  383. complexes = []
  384. for f in ordered(sifted):
  385. nimag = _imag_count_of_factor(f)
  386. if nimag == 0:
  387. # refine until xbounds are neg or pos
  388. for u, f, k in sifted[f]:
  389. while u.ax*u.bx <= 0:
  390. u = u._inner_refine()
  391. complexes.append((u, f, k))
  392. else:
  393. # refine until all but nimag xbounds are neg or pos
  394. potential_imag = list(range(len(sifted[f])))
  395. while True:
  396. assert len(potential_imag) > 1
  397. for i in list(potential_imag):
  398. u, f, k = sifted[f][i]
  399. if u.ax*u.bx > 0:
  400. potential_imag.remove(i)
  401. elif u.ax != u.bx:
  402. u = u._inner_refine()
  403. sifted[f][i] = u, f, k
  404. if len(potential_imag) == nimag:
  405. break
  406. complexes.extend(sifted[f])
  407. return complexes
  408. @classmethod
  409. def _refine_complexes(cls, complexes):
  410. """return complexes such that no bounding rectangles of non-conjugate
  411. roots would intersect. In addition, assure that neither ay nor by is
  412. 0 to guarantee that non-real roots are distinct from real roots in
  413. terms of the y-bounds.
  414. """
  415. # get the intervals pairwise-disjoint.
  416. # If rectangles were drawn around the coordinates of the bounding
  417. # rectangles, no rectangles would intersect after this procedure.
  418. for i, (u, f, k) in enumerate(complexes):
  419. for j, (v, g, m) in enumerate(complexes[i + 1:]):
  420. u, v = u.refine_disjoint(v)
  421. complexes[i + j + 1] = (v, g, m)
  422. complexes[i] = (u, f, k)
  423. # refine until the x-bounds are unambiguously positive or negative
  424. # for non-imaginary roots
  425. complexes = cls._refine_imaginary(complexes)
  426. # make sure that all y bounds are off the real axis
  427. # and on the same side of the axis
  428. for i, (u, f, k) in enumerate(complexes):
  429. while u.ay*u.by <= 0:
  430. u = u.refine()
  431. complexes[i] = u, f, k
  432. return complexes
  433. @classmethod
  434. def _complexes_sorted(cls, complexes):
  435. """Make complex isolating intervals disjoint and sort roots. """
  436. complexes = cls._refine_complexes(complexes)
  437. # XXX don't sort until you are sure that it is compatible
  438. # with the indexing method but assert that the desired state
  439. # is not broken
  440. C, F = 0, 1 # location of ComplexInterval and factor
  441. fs = {i[F] for i in complexes}
  442. for i in range(1, len(complexes)):
  443. if complexes[i][F] != complexes[i - 1][F]:
  444. # if this fails the factors of a root were not
  445. # contiguous because a discontinuity should only
  446. # happen once
  447. fs.remove(complexes[i - 1][F])
  448. for i in range(len(complexes)):
  449. # negative im part (conj=True) comes before
  450. # positive im part (conj=False)
  451. assert complexes[i][C].conj is (i % 2 == 0)
  452. # update cache
  453. cache = {}
  454. # -- collate
  455. for root, currentfactor, _ in complexes:
  456. cache.setdefault(currentfactor, []).append(root)
  457. # -- store
  458. for currentfactor, root in cache.items():
  459. _complexes_cache[currentfactor] = root
  460. return complexes
  461. @classmethod
  462. def _reals_index(cls, reals, index):
  463. """
  464. Map initial real root index to an index in a factor where
  465. the root belongs.
  466. """
  467. i = 0
  468. for j, (_, currentfactor, k) in enumerate(reals):
  469. if index < i + k:
  470. poly, index = currentfactor, 0
  471. for _, currentfactor, _ in reals[:j]:
  472. if currentfactor == poly:
  473. index += 1
  474. return poly, index
  475. else:
  476. i += k
  477. @classmethod
  478. def _complexes_index(cls, complexes, index):
  479. """
  480. Map initial complex root index to an index in a factor where
  481. the root belongs.
  482. """
  483. i = 0
  484. for j, (_, currentfactor, k) in enumerate(complexes):
  485. if index < i + k:
  486. poly, index = currentfactor, 0
  487. for _, currentfactor, _ in complexes[:j]:
  488. if currentfactor == poly:
  489. index += 1
  490. index += len(_reals_cache[poly])
  491. return poly, index
  492. else:
  493. i += k
  494. @classmethod
  495. def _count_roots(cls, roots):
  496. """Count the number of real or complex roots with multiplicities."""
  497. return sum([k for _, _, k in roots])
  498. @classmethod
  499. def _indexed_root(cls, poly, index):
  500. """Get a root of a composite polynomial by index. """
  501. factors = _pure_factors(poly)
  502. reals = cls._get_reals(factors)
  503. reals_count = cls._count_roots(reals)
  504. if index < reals_count:
  505. return cls._reals_index(reals, index)
  506. else:
  507. complexes = cls._get_complexes(factors)
  508. return cls._complexes_index(complexes, index - reals_count)
  509. @classmethod
  510. def _real_roots(cls, poly):
  511. """Get real roots of a composite polynomial. """
  512. factors = _pure_factors(poly)
  513. reals = cls._get_reals(factors)
  514. reals_count = cls._count_roots(reals)
  515. roots = []
  516. for index in range(0, reals_count):
  517. roots.append(cls._reals_index(reals, index))
  518. return roots
  519. def _reset(self):
  520. """
  521. Reset all intervals
  522. """
  523. self._all_roots(self.poly, use_cache=False)
  524. @classmethod
  525. def _all_roots(cls, poly, use_cache=True):
  526. """Get real and complex roots of a composite polynomial. """
  527. factors = _pure_factors(poly)
  528. reals = cls._get_reals(factors, use_cache=use_cache)
  529. reals_count = cls._count_roots(reals)
  530. roots = []
  531. for index in range(0, reals_count):
  532. roots.append(cls._reals_index(reals, index))
  533. complexes = cls._get_complexes(factors, use_cache=use_cache)
  534. complexes_count = cls._count_roots(complexes)
  535. for index in range(0, complexes_count):
  536. roots.append(cls._complexes_index(complexes, index))
  537. return roots
  538. @classmethod
  539. @cacheit
  540. def _roots_trivial(cls, poly, radicals):
  541. """Compute roots in linear, quadratic and binomial cases. """
  542. if poly.degree() == 1:
  543. return roots_linear(poly)
  544. if not radicals:
  545. return None
  546. if poly.degree() == 2:
  547. return roots_quadratic(poly)
  548. elif poly.length() == 2 and poly.TC():
  549. return roots_binomial(poly)
  550. else:
  551. return None
  552. @classmethod
  553. def _preprocess_roots(cls, poly):
  554. """Take heroic measures to make ``poly`` compatible with ``CRootOf``."""
  555. dom = poly.get_domain()
  556. if not dom.is_Exact:
  557. poly = poly.to_exact()
  558. coeff, poly = preprocess_roots(poly)
  559. dom = poly.get_domain()
  560. if not dom.is_ZZ:
  561. raise NotImplementedError(
  562. "sorted roots not supported over %s" % dom)
  563. return coeff, poly
  564. @classmethod
  565. def _postprocess_root(cls, root, radicals):
  566. """Return the root if it is trivial or a ``CRootOf`` object. """
  567. poly, index = root
  568. roots = cls._roots_trivial(poly, radicals)
  569. if roots is not None:
  570. return roots[index]
  571. else:
  572. return cls._new(poly, index)
  573. @classmethod
  574. def _get_roots(cls, method, poly, radicals):
  575. """Return postprocessed roots of specified kind. """
  576. if not poly.is_univariate:
  577. raise PolynomialError("only univariate polynomials are allowed")
  578. # get rid of gen and it's free symbol
  579. d = Dummy()
  580. poly = poly.subs(poly.gen, d)
  581. x = symbols('x')
  582. # see what others are left and select x or a numbered x
  583. # that doesn't clash
  584. free_names = {str(i) for i in poly.free_symbols}
  585. for x in chain((symbols('x'),), numbered_symbols('x')):
  586. if x.name not in free_names:
  587. poly = poly.xreplace({d: x})
  588. break
  589. coeff, poly = cls._preprocess_roots(poly)
  590. roots = []
  591. for root in getattr(cls, method)(poly):
  592. roots.append(coeff*cls._postprocess_root(root, radicals))
  593. return roots
  594. @classmethod
  595. def clear_cache(cls):
  596. """Reset cache for reals and complexes.
  597. The intervals used to approximate a root instance are updated
  598. as needed. When a request is made to see the intervals, the
  599. most current values are shown. `clear_cache` will reset all
  600. CRootOf instances back to their original state.
  601. See Also
  602. ========
  603. _reset
  604. """
  605. global _reals_cache, _complexes_cache
  606. _reals_cache = _pure_key_dict()
  607. _complexes_cache = _pure_key_dict()
  608. def _get_interval(self):
  609. """Internal function for retrieving isolation interval from cache. """
  610. if self.is_real:
  611. return _reals_cache[self.poly][self.index]
  612. else:
  613. reals_count = len(_reals_cache[self.poly])
  614. return _complexes_cache[self.poly][self.index - reals_count]
  615. def _set_interval(self, interval):
  616. """Internal function for updating isolation interval in cache. """
  617. if self.is_real:
  618. _reals_cache[self.poly][self.index] = interval
  619. else:
  620. reals_count = len(_reals_cache[self.poly])
  621. _complexes_cache[self.poly][self.index - reals_count] = interval
  622. def _eval_subs(self, old, new):
  623. # don't allow subs to change anything
  624. return self
  625. def _eval_conjugate(self):
  626. if self.is_real:
  627. return self
  628. expr, i = self.args
  629. return self.func(expr, i + (1 if self._get_interval().conj else -1))
    def eval_approx(self, n):
        """Evaluate this complex root to the given precision.

        This uses secant method and root bounds are used to both
        generate an initial guess and to check that the root
        returned is valid. If ever the method converges outside the
        root bounds, the bounds will be made smaller and updated.

        Parameters
        ==========

        n : int
            Number of decimal digits; converted to binary precision with
            ``dps_to_prec``.

        Returns
        =======

        ``Float(real) + I*Float(imag)`` at that precision. The interval
        reached when the loop stopped is written back to the cache.
        """
        prec = dps_to_prec(n)
        with workprec(prec):
            g = self.poly.gen
            # NOTE(review): ``__new__`` rejects non-Symbol generators, so this
            # branch looks unreachable. If it were reached for an imaginary
            # root, ``d`` becomes ``I*Dummy`` and lambdify would get a
            # non-symbol argument -- confirm before relying on it.
            if not g.is_Symbol:
                d = Dummy('x')
                if self.is_imaginary:
                    d *= I
                func = lambdify(d, self.expr.subs(g, d))
            else:
                expr = self.expr
                # For imaginary roots solve f(I*y) = 0 for y, the imaginary
                # part, so the search runs along a single axis.
                if self.is_imaginary:
                    expr = self.expr.subs(g, I*g)
                func = lambdify(g, expr)

            interval = self._get_interval()
            while True:
                # Build two starting points (x0, x1) inside the current
                # interval. A degenerate interval already pins the root.
                if self.is_real:
                    a = mpf(str(interval.a))
                    b = mpf(str(interval.b))
                    if a == b:
                        root = a
                        break
                    x0 = mpf(str(interval.center))
                    x1 = x0 + mpf(str(interval.dx))/4
                elif self.is_imaginary:
                    a = mpf(str(interval.ay))
                    b = mpf(str(interval.by))
                    if a == b:
                        root = mpc(mpf('0'), a)
                        break
                    x0 = mpf(str(interval.center[1]))
                    x1 = x0 + mpf(str(interval.dy))/4
                else:
                    ax = mpf(str(interval.ax))
                    bx = mpf(str(interval.bx))
                    ay = mpf(str(interval.ay))
                    by = mpf(str(interval.by))
                    if ax == bx and ay == by:
                        root = mpc(ax, ay)
                        break
                    x0 = mpc(*map(str, interval.center))
                    x1 = x0 + mpc(*map(str, (interval.dx, interval.dy)))/4
                try:
                    # without a tolerance, this will return when (to within
                    # the given precision) x_i == x_{i-1}
                    root = findroot(func, (x0, x1))
                    # If the (real or complex) root is not in the 'interval',
                    # then keep refining the interval. This happens if findroot
                    # accidentally finds a different root outside of this
                    # interval because our initial estimate 'x0' was not close
                    # enough. It is also possible that the secant method will
                    # get trapped by a max/min in the interval; the root
                    # verification by findroot will raise a ValueError in this
                    # case and the interval will then be tightened -- and
                    # eventually the root will be found.
                    #
                    # It is also possible that findroot will not have any
                    # successful iterations to process (in which case it
                    # will fail to initialize a variable that is tested
                    # after the iterations and raise an UnboundLocalError).
                    if self.is_real or self.is_imaginary:
                        # Precedence: this reads
                        # ``(not (bool(root.imag) == self.is_real)) and (a <= root <= b)``.
                        # NOTE(review): for imaginary roots this accepts only
                        # a result with a *nonzero* imaginary part and then
                        # orders ``root`` against mpf bounds; whether mpmath
                        # allows ordering an mpc here is unverified -- confirm.
                        if not bool(root.imag) == self.is_real and (
                                a <= root <= b):
                            if self.is_imaginary:
                                # y was found; map back to the point I*y
                                root = mpc(mpf('0'), root.real)
                            break
                    elif (ax <= root.real <= bx and ay <= root.imag <= by):
                        break
                except (UnboundLocalError, ValueError):
                    pass
                # shrink the interval and try again
                interval = interval.refine()

        # update the interval so we at least (for this precision or
        # less) don't have much work to do to recompute the root
        self._set_interval(interval)
        return (Float._new(root.real._mpf_, prec) +
            I*Float._new(root.imag._mpf_, prec))
  712. def _eval_evalf(self, prec, **kwargs):
  713. """Evaluate this complex root to the given precision."""
  714. # all kwargs are ignored
  715. return self.eval_rational(n=prec_to_dps(prec))._evalf(prec)
  def eval_rational(self, dx=None, dy=None, n=15):
      """
      Return a Rational approximation of ``self`` that has real
      and imaginary component approximations that are within ``dx``
      and ``dy`` of the true values, respectively. Alternatively,
      ``n`` digits of precision can be specified.

      If only one of ``dx`` and ``dy`` is given, that tolerance is used
      for both components.

      The interval is refined with bisection and is sure to
      converge. The root bounds are updated when the refinement
      is complete so recalculation at the same or lesser precision
      will not have to repeat the refinement and should be much
      faster.

      The following example first obtains a Rational approximation,
      accurate to 1e-8, of the last real root of the 4-th order
      Legendre polynomial. Since the roots are all less than 1, this
      will ensure the decimal representation of the approximation will
      be correct (including rounding) to 6 digits:

      >>> from sympy import legendre_poly, Symbol
      >>> x = Symbol("x")
      >>> p = legendre_poly(4, x, polys=True)
      >>> r = p.real_roots()[-1]
      >>> r.eval_rational(10**-8).n(6)
      0.861136

      It is not necessary to do a two-step calculation, however: the
      decimal representation can be computed directly:

      >>> r.evalf(17)
      0.86113631159405258

      """
      # A tolerance supplied for only one component applies to both.
      # Without the first line, a lone ``dy`` would be silently ignored
      # and the relative (``n`` digits) criterion used instead.
      dx = dx or dy
      dy = dy or dx
      if dx:
          # Absolute tolerances: a single refinement pass suffices and
          # no relative stopping test is applied (rtol is None).
          rtol = None
          dx = dx if isinstance(dx, Rational) else Rational(str(dx))
          dy = dy if isinstance(dy, Rational) else Rational(str(dy))
      else:
          # 5 binary (or 2 decimal) digits are needed to ensure that
          # a given digit is correctly rounded
          # prec_to_dps(dps_to_prec(n) + 5) - n <= 2 (tested for
          # n in range(1000000))
          rtol = S(10)**-(n + 2)  # +2 for guard digits
      interval = self._get_interval()
      while True:
          if self.is_real:
              if rtol:
                  dx = abs(interval.center*rtol)
              interval = interval.refine_size(dx=dx)
              c = interval.center
              real = Rational(c)
              imag = S.Zero
              if not rtol or interval.dx < abs(c*rtol):
                  break
          elif self.is_imaginary:
              if rtol:
                  dy = abs(interval.center[1]*rtol)
                  # the real part is reported as zero below, so the
                  # real width is only refined coarsely
                  dx = 1
              interval = interval.refine_size(dx=dx, dy=dy)
              c = interval.center[1]
              imag = Rational(c)
              real = S.Zero
              if not rtol or interval.dy < abs(c*rtol):
                  break
          else:
              if rtol:
                  dx = abs(interval.center[0]*rtol)
                  dy = abs(interval.center[1]*rtol)
              interval = interval.refine_size(dx, dy)
              c = interval.center
              real, imag = map(Rational, c)
              if not rtol or (
                      interval.dx < abs(c[0]*rtol) and
                      interval.dy < abs(c[1]*rtol)):
                  break
      # update the interval so we at least (for this precision or
      # less) don't have much work to do to recompute the root
      self._set_interval(interval)
      return real + I*imag
  # Shorter public alias for ``ComplexRootOf``.
  CRootOf = ComplexRootOf
  @dispatch(ComplexRootOf, ComplexRootOf)
  def _eval_is_eq(lhs, rhs):  # noqa:F811
      """Compare two ``ComplexRootOf`` objects with structural ``==``."""
      # Plain ``==`` is used on purpose: delegating to ``is_eq`` here
      # would dispatch straight back to this handler (infinite recursion).
      return lhs == rhs
  @dispatch(ComplexRootOf, Basic)  # type:ignore
  def _eval_is_eq(lhs, rhs):  # noqa:F811
      """Decide whether the number ``rhs`` is the root represented by ``lhs``.

      ``rhs`` must make ``lhs.expr`` vanish, agree with the real/imaginary
      classification of ``lhs``, and lie inside the isolating interval of
      ``lhs``. Returns ``None`` when ``rhs`` is not a number and ``False``
      when it is not finite or a check definitely fails.
      """
      if not rhs.is_number:
          return None
      if not rhs.is_finite:
          return False

      gen = lhs.expr.free_symbols.pop()
      vanishes = lhs.expr.subs(gen, rhs).is_zero
      # Every root of the polynomial makes it vanish, so only a definite
      # non-zero result settles the question at this point.
      if vanishes is False:
          return False

      rhs_kind = rhs.is_real, rhs.is_imaginary
      lhs_kind = lhs.is_real, lhs.is_imaginary
      assert None not in lhs_kind  # settled by the initial refinement
      if None not in rhs_kind and rhs_kind != lhs_kind:
          return False

      re_part, im_part = rhs.as_real_imag()
      if lhs.is_real:
          if im_part:
              return False
          bounds = lhs._get_interval()
          lo, hi = (Rational(str(end)) for end in (bounds.a, bounds.b))
          return sympify(lo <= rhs and rhs <= hi)

      box = lhs._get_interval()
      x_lo, x_hi, y_lo, y_hi = (
          Rational(str(end)) for end in (box.ax, box.bx, box.ay, box.by))
      return (is_le(x_lo, re_part) and is_le(re_part, x_hi)
              and is_le(y_lo, im_part) and is_le(im_part, y_hi))
  @public
  class RootSum(Expr):
      """Represents a sum of all roots of a univariate polynomial.

      ``RootSum(expr, func)`` stands for the sum of ``func(r)`` taken over
      the roots ``r`` of the polynomial ``expr``.
      """

      __slots__ = ('poly', 'fun', 'auto')

      def __new__(cls, expr, func=None, x=None, auto=True, quadratic=False):
          """Construct a new ``RootSum`` instance of roots of a polynomial."""
          coeff, poly = cls._transform(expr, x)
          if not poly.is_univariate:
              raise MultivariatePolynomialError(
                  "only univariate polynomials are allowed")

          if func is None:
              func = Lambda(poly.gen, poly.gen)
          elif getattr(func, 'is_Function', False) and 1 in func.nargs:
              if not isinstance(func, Lambda):
                  func = Lambda(poly.gen, func(poly.gen))
          else:
              raise ValueError(
                  "expected a univariate function, got %s" % func)

          var, body = func.variables[0], func.expr
          # The summand is evaluated at ``coeff*var`` to undo the scaling
          # introduced by ``_transform`` (presumably: roots of the input =
          # coeff * roots of ``poly`` -- confirm against preprocess_roots).
          if coeff is not S.One:
              body = body.subs(var, coeff*var)

          deg = poly.degree()
          # A summand independent of the root contributes ``deg`` copies.
          if not body.has(var):
              return deg*body

          # Pull root-independent additive / multiplicative parts outside.
          add_const = S.Zero
          if body.is_Add:
              add_const, body = body.as_independent(var)
          mul_const = S.One
          if body.is_Mul:
              mul_const, body = body.as_independent(var)

          func = Lambda(var, body)
          rational = cls._is_func_rational(poly, func)

          terms = []
          for factor_poly, multiplicity in _pure_factors(poly):
              if factor_poly.is_linear:
                  term = func(roots_linear(factor_poly)[0])
              elif quadratic and factor_poly.is_quadratic:
                  term = sum(map(func, roots_quadratic(factor_poly)))
              elif rational and auto:
                  term = cls._rational_case(factor_poly, func)
              else:
                  term = cls._new(factor_poly, func, auto)
              terms.append(multiplicity*term)

          return mul_const*Add(*terms) + deg*add_const

      @classmethod
      def _new(cls, poly, func, auto=True):
          """Construct new raw ``RootSum`` instance. """
          obj = Expr.__new__(cls)
          obj.poly, obj.fun, obj.auto = poly, func, auto
          return obj

      @classmethod
      def new(cls, poly, func, auto=True):
          """Construct new ``RootSum`` instance. """
          if not func.expr.has(*func.variables):
              return func.expr
          if cls._is_func_rational(poly, func) and auto:
              return cls._rational_case(poly, func)
          return cls._new(poly, func, auto)

      @classmethod
      def _transform(cls, expr, x):
          """Transform an expression to a polynomial. """
          return preprocess_roots(PurePoly(expr, x, greedy=False))

      @classmethod
      def _is_func_rational(cls, poly, func):
          """Check if a lambda is a rational function. """
          return func.expr.is_rational_function(func.variables[0])

      @classmethod
      def _rational_case(cls, poly, func):
          """Handle the rational function case.

          The summand is added up over generic roots ``r0, r1, ...``; the
          numerator and denominator coefficients are rewritten with
          ``symmetrize`` and the resulting symmetric functions are replaced
          by the expressions returned by ``viete(poly, ...)``.
          """
          generic = symbols('r:%d' % poly.degree())
          var, summand = func.variables[0], func.expr
          total = sum(summand.subs(var, r) for r in generic)
          num, den = together(total).as_numer_denom()
          domain = QQ[generic]
          num = num.expand()
          den = den.expand()

          def split(e):
              # Return (poly or None, monomials, coefficients) for ``e``.
              try:
                  e_poly = Poly(e, domain=domain, expand=False)
              except GeneratorsNeeded:
                  # no generators outside ``domain``: keep ``e`` whole
                  return None, None, (e,)
              monoms, coeffs = zip(*e_poly.terms())
              return e_poly, monoms, coeffs

          num_poly, num_monoms, num_coeffs = split(num)
          den_poly, den_monoms, den_coeffs = split(den)

          sym_coeffs, mapping = symmetrize(num_coeffs + den_coeffs, formal=True)
          formulas = viete(poly, generic)
          values = [(sym, val) for (sym, _), (_, val) in zip(mapping, formulas)]
          reduced = [c.subs(values) for c, _ in sym_coeffs]

          split_at = len(num_coeffs)
          num_coeffs, den_coeffs = reduced[:split_at], reduced[split_at:]

          if num_poly is None:
              (num,) = num_coeffs
          else:
              num = Poly(dict(zip(num_monoms, num_coeffs)),
                         *num_poly.gens).as_expr()
          if den_poly is None:
              (den,) = den_coeffs
          else:
              den = Poly(dict(zip(den_monoms, den_coeffs)),
                         *den_poly.gens).as_expr()

          return factor(num/den)

      def _hashable_content(self):
          """Identify the instance by its polynomial and summand."""
          return (self.poly, self.fun)

      @property
      def expr(self):
          """The polynomial as a plain expression."""
          return self.poly.as_expr()

      @property
      def args(self):
          """``(expr, fun, gen)`` -- the polynomial, summand and generator."""
          return (self.expr, self.fun, self.poly.gen)

      @property
      def free_symbols(self):
          """Free symbols of the polynomial and of the summand."""
          return self.poly.free_symbols | self.fun.free_symbols

      @property
      def is_commutative(self):
          return True

      def doit(self, **hints):
          """Sum ``fun`` over the explicitly computed roots of ``poly``.

          ``self`` is returned unchanged when the ``roots`` hint is false
          or when ``roots`` finds fewer roots than the degree.
          """
          if not hints.get('roots', True):
              return self
          found = roots(self.poly, multiple=True)
          if len(found) < self.poly.degree():
              return self
          return Add(*map(self.fun, found))

      def _eval_evalf(self, prec):
          """Sum ``fun`` over numerical roots at the matching precision.

          ``self`` is returned if ``nroots`` raises ``DomainError`` or
          ``PolynomialError``.
          """
          try:
              numeric_roots = self.poly.nroots(n=prec_to_dps(prec))
          except (DomainError, PolynomialError):
              return self
          return Add(*map(self.fun, numeric_roots))

      def _eval_derivative(self, x):
          """Differentiate the summand; the polynomial is left unchanged."""
          var, expr = self.fun.args
          return self.new(self.poly, Lambda(var, expr.diff(x)), self.auto)