mpelements.py 4.5 KB

Lines 1–172
  1. """Real and complex elements. """
  2. from sympy.polys.domains.domainelement import DomainElement
  3. from sympy.utilities import public
  4. from mpmath.ctx_mp_python import PythonMPContext, _mpf, _mpc, _constant
  5. from mpmath.libmp import (MPZ_ONE, fzero, fone, finf, fninf, fnan,
  6. round_nearest, mpf_mul, repr_dps, int_types,
  7. from_int, from_float, from_str, to_rational)
  8. from mpmath.rational import mpq
  9. @public
  10. class RealElement(_mpf, DomainElement):
  11. """An element of a real domain. """
  12. __slots__ = ('__mpf__',)
  13. def _set_mpf(self, val):
  14. self.__mpf__ = val
  15. _mpf_ = property(lambda self: self.__mpf__, _set_mpf)
  16. def parent(self):
  17. return self.context._parent
  18. @public
  19. class ComplexElement(_mpc, DomainElement):
  20. """An element of a complex domain. """
  21. __slots__ = ('__mpc__',)
  22. def _set_mpc(self, val):
  23. self.__mpc__ = val
  24. _mpc_ = property(lambda self: self.__mpc__, _set_mpc)
  25. def parent(self):
  26. return self.context._parent
  27. new = object.__new__
  28. @public
  29. class MPContext(PythonMPContext):
  30. def __init__(ctx, prec=53, dps=None, tol=None, real=False):
  31. ctx._prec_rounding = [prec, round_nearest]
  32. if dps is None:
  33. ctx._set_prec(prec)
  34. else:
  35. ctx._set_dps(dps)
  36. ctx.mpf = RealElement
  37. ctx.mpc = ComplexElement
  38. ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
  39. ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding]
  40. if real:
  41. ctx.mpf.context = ctx
  42. else:
  43. ctx.mpc.context = ctx
  44. ctx.constant = _constant
  45. ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding]
  46. ctx.constant.context = ctx
  47. ctx.types = [ctx.mpf, ctx.mpc, ctx.constant]
  48. ctx.trap_complex = True
  49. ctx.pretty = True
  50. if tol is None:
  51. ctx.tol = ctx._make_tol()
  52. elif tol is False:
  53. ctx.tol = fzero
  54. else:
  55. ctx.tol = ctx._convert_tol(tol)
  56. ctx.tolerance = ctx.make_mpf(ctx.tol)
  57. if not ctx.tolerance:
  58. ctx.max_denom = 1000000
  59. else:
  60. ctx.max_denom = int(1/ctx.tolerance)
  61. ctx.zero = ctx.make_mpf(fzero)
  62. ctx.one = ctx.make_mpf(fone)
  63. ctx.j = ctx.make_mpc((fzero, fone))
  64. ctx.inf = ctx.make_mpf(finf)
  65. ctx.ninf = ctx.make_mpf(fninf)
  66. ctx.nan = ctx.make_mpf(fnan)
  67. def _make_tol(ctx):
  68. hundred = (0, 25, 2, 5)
  69. eps = (0, MPZ_ONE, 1-ctx.prec, 1)
  70. return mpf_mul(hundred, eps)
  71. def make_tol(ctx):
  72. return ctx.make_mpf(ctx._make_tol())
  73. def _convert_tol(ctx, tol):
  74. if isinstance(tol, int_types):
  75. return from_int(tol)
  76. if isinstance(tol, float):
  77. return from_float(tol)
  78. if hasattr(tol, "_mpf_"):
  79. return tol._mpf_
  80. prec, rounding = ctx._prec_rounding
  81. if isinstance(tol, str):
  82. return from_str(tol, prec, rounding)
  83. raise ValueError("expected a real number, got %s" % tol)
  84. def _convert_fallback(ctx, x, strings):
  85. raise TypeError("cannot create mpf from " + repr(x))
  86. @property
  87. def _repr_digits(ctx):
  88. return repr_dps(ctx._prec)
  89. @property
  90. def _str_digits(ctx):
  91. return ctx._dps
  92. def to_rational(ctx, s, limit=True):
  93. p, q = to_rational(s._mpf_)
  94. if not limit or q <= ctx.max_denom:
  95. return p, q
  96. p0, q0, p1, q1 = 0, 1, 1, 0
  97. n, d = p, q
  98. while True:
  99. a = n//d
  100. q2 = q0 + a*q1
  101. if q2 > ctx.max_denom:
  102. break
  103. p0, q0, p1, q1 = p1, q1, p0 + a*p1, q2
  104. n, d = d, n - a*d
  105. k = (ctx.max_denom - q0)//q1
  106. number = mpq(p, q)
  107. bound1 = mpq(p0 + k*p1, q0 + k*q1)
  108. bound2 = mpq(p1, q1)
  109. if not bound2 or not bound1:
  110. return p, q
  111. elif abs(bound2 - number) <= abs(bound1 - number):
  112. return bound2._mpq_
  113. else:
  114. return bound1._mpq_
  115. def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
  116. t = ctx.convert(t)
  117. if abs_eps is None and rel_eps is None:
  118. rel_eps = abs_eps = ctx.tolerance or ctx.make_tol()
  119. if abs_eps is None:
  120. abs_eps = ctx.convert(rel_eps)
  121. elif rel_eps is None:
  122. rel_eps = ctx.convert(abs_eps)
  123. diff = abs(s-t)
  124. if diff <= abs_eps:
  125. return True
  126. abss = abs(s)
  127. abst = abs(t)
  128. if abss < abst:
  129. err = diff/abst
  130. else:
  131. err = diff/abss
  132. return err <= rel_eps