pythonmpq.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
"""
PythonMPQ: Rational number type based on Python integers.

This class is intended as a pure Python fallback for when gmpy2 is not
installed. If gmpy2 is installed then its mpq type will be used instead. The
mpq type is around 20x faster. We could just use the stdlib Fraction class
here but that is slower:

    from fractions import Fraction
    from sympy.external.pythonmpq import PythonMPQ
    nums = range(1000)
    dens = range(5, 1005)
    rats = [Fraction(n, d) for n, d in zip(nums, dens)]
    sum(rats) # <--- 24 milliseconds
    rats = [PythonMPQ(n, d) for n, d in zip(nums, dens)]
    sum(rats) # <--- 7 milliseconds

Both mpq and Fraction have some awkward features like the behaviour of
division with // and %:

    >>> from fractions import Fraction
    >>> Fraction(2, 3) % Fraction(1, 4)
    1/6

For the QQ domain we do not want this behaviour because there should be no
remainder when dividing rational numbers. SymPy does not make use of this
aspect of mpq when gmpy2 is installed. Since this class is a fallback for that
case we do not bother implementing e.g. __mod__ so that we can be sure we
are not using it when gmpy2 is installed either.
"""
  26. import operator
  27. from math import gcd
  28. from decimal import Decimal
  29. from fractions import Fraction
  30. import sys
  31. from typing import Tuple as tTuple, Type
  32. # Used for __hash__
  33. _PyHASH_MODULUS = sys.hash_info.modulus
  34. _PyHASH_INF = sys.hash_info.inf
  35. class PythonMPQ:
  36. """Rational number implementation that is intended to be compatible with
  37. gmpy2's mpq.
  38. Also slightly faster than fractions.Fraction.
  39. PythonMPQ should be treated as immutable although no effort is made to
  40. prevent mutation (since that might slow down calculations).
  41. """
  42. __slots__ = ('numerator', 'denominator')
  43. def __new__(cls, numerator, denominator=None):
  44. """Construct PythonMPQ with gcd computation and checks"""
  45. if denominator is not None:
  46. #
  47. # PythonMPQ(n, d): require n and d to be int and d != 0
  48. #
  49. if isinstance(numerator, int) and isinstance(denominator, int):
  50. # This is the slow part:
  51. divisor = gcd(numerator, denominator)
  52. numerator //= divisor
  53. denominator //= divisor
  54. return cls._new_check(numerator, denominator)
  55. else:
  56. #
  57. # PythonMPQ(q)
  58. #
  59. # Here q can be PythonMPQ, int, Decimal, float, Fraction or str
  60. #
  61. if isinstance(numerator, int):
  62. return cls._new(numerator, 1)
  63. elif isinstance(numerator, PythonMPQ):
  64. return cls._new(numerator.numerator, numerator.denominator)
  65. # Let Fraction handle Decimal/float conversion and str parsing
  66. if isinstance(numerator, (Decimal, float, str)):
  67. numerator = Fraction(numerator)
  68. if isinstance(numerator, Fraction):
  69. return cls._new(numerator.numerator, numerator.denominator)
  70. #
  71. # Reject everything else. This is more strict than mpq which allows
  72. # things like mpq(Fraction, Fraction) or mpq(Decimal, any). The mpq
  73. # behaviour is somewhat inconsistent so we choose to accept only a
  74. # more strict subset of what mpq allows.
  75. #
  76. raise TypeError("PythonMPQ() requires numeric or string argument")
  77. @classmethod
  78. def _new_check(cls, numerator, denominator):
  79. """Construct PythonMPQ, check divide by zero and canonicalize signs"""
  80. if not denominator:
  81. raise ZeroDivisionError(f'Zero divisor {numerator}/{denominator}')
  82. elif denominator < 0:
  83. numerator = -numerator
  84. denominator = -denominator
  85. return cls._new(numerator, denominator)
  86. @classmethod
  87. def _new(cls, numerator, denominator):
  88. """Construct PythonMPQ efficiently (no checks)"""
  89. obj = super().__new__(cls)
  90. obj.numerator = numerator
  91. obj.denominator = denominator
  92. return obj
  93. def __int__(self):
  94. """Convert to int (truncates towards zero)"""
  95. p, q = self.numerator, self.denominator
  96. if p < 0:
  97. return -(-p//q)
  98. return p//q
  99. def __float__(self):
  100. """Convert to float (approximately)"""
  101. return self.numerator / self.denominator
  102. def __bool__(self):
  103. """True/False if nonzero/zero"""
  104. return bool(self.numerator)
  105. def __eq__(self, other):
  106. """Compare equal with PythonMPQ, int, float, Decimal or Fraction"""
  107. if isinstance(other, PythonMPQ):
  108. return (self.numerator == other.numerator
  109. and self.denominator == other.denominator)
  110. elif isinstance(other, self._compatible_types):
  111. return self.__eq__(PythonMPQ(other))
  112. else:
  113. return NotImplemented
  114. # The hashing algorithm for Fraction changed in Python 3.8
  115. if sys.version_info >= (3, 8):
  116. #
  117. # Hash for Python 3.8 onwards
  118. #
  119. def __hash__(self):
  120. """hash - same as mpq/Fraction"""
  121. try:
  122. dinv = pow(self.denominator, -1, _PyHASH_MODULUS)
  123. except ValueError:
  124. hash_ = _PyHASH_INF
  125. else:
  126. hash_ = hash(hash(abs(self.numerator)) * dinv)
  127. result = hash_ if self.numerator >= 0 else -hash_
  128. return -2 if result == -1 else result
  129. else:
  130. #
  131. # Hash for Python < 3.7
  132. #
  133. def __hash__(self):
  134. """hash - same as mpq/Fraction"""
  135. # This is from fractions.py in the stdlib.
  136. dinv = pow(self.denominator, _PyHASH_MODULUS - 2, _PyHASH_MODULUS)
  137. if not dinv:
  138. hash_ = _PyHASH_INF
  139. else:
  140. hash_ = abs(self.numerator) * dinv % _PyHASH_MODULUS
  141. result = hash_ if self >= 0 else -hash_
  142. return -2 if result == -1 else result
  143. def __reduce__(self):
  144. """Deconstruct for pickling"""
  145. return type(self), (self.numerator, self.denominator)
  146. def __str__(self):
  147. """Convert to string"""
  148. if self.denominator != 1:
  149. return f"{self.numerator}/{self.denominator}"
  150. else:
  151. return f"{self.numerator}"
  152. def __repr__(self):
  153. """Convert to string"""
  154. return f"MPQ({self.numerator},{self.denominator})"
  155. def _cmp(self, other, op):
  156. """Helper for lt/le/gt/ge"""
  157. if not isinstance(other, self._compatible_types):
  158. return NotImplemented
  159. lhs = self.numerator * other.denominator
  160. rhs = other.numerator * self.denominator
  161. return op(lhs, rhs)
  162. def __lt__(self, other):
  163. """self < other"""
  164. return self._cmp(other, operator.lt)
  165. def __le__(self, other):
  166. """self <= other"""
  167. return self._cmp(other, operator.le)
  168. def __gt__(self, other):
  169. """self > other"""
  170. return self._cmp(other, operator.gt)
  171. def __ge__(self, other):
  172. """self >= other"""
  173. return self._cmp(other, operator.ge)
  174. def __abs__(self):
  175. """abs(q)"""
  176. return self._new(abs(self.numerator), self.denominator)
  177. def __pos__(self):
  178. """+q"""
  179. return self
  180. def __neg__(self):
  181. """-q"""
  182. return self._new(-self.numerator, self.denominator)
  183. def __add__(self, other):
  184. """q1 + q2"""
  185. if isinstance(other, PythonMPQ):
  186. #
  187. # This is much faster than the naive method used in the stdlib
  188. # fractions module. Not sure where this method comes from
  189. # though...
  190. #
  191. # Compare timings for something like:
  192. # nums = range(1000)
  193. # rats = [PythonMPQ(n, d) for n, d in zip(nums[:-5], nums[5:])]
  194. # sum(rats) # <-- time this
  195. #
  196. ap, aq = self.numerator, self.denominator
  197. bp, bq = other.numerator, other.denominator
  198. g = gcd(aq, bq)
  199. if g == 1:
  200. p = ap*bq + aq*bp
  201. q = bq*aq
  202. else:
  203. q1, q2 = aq//g, bq//g
  204. p, q = ap*q2 + bp*q1, q1*q2
  205. g2 = gcd(p, g)
  206. p, q = (p // g2), q * (g // g2)
  207. elif isinstance(other, int):
  208. p = self.numerator + self.denominator * other
  209. q = self.denominator
  210. else:
  211. return NotImplemented
  212. return self._new(p, q)
  213. def __radd__(self, other):
  214. """z1 + q2"""
  215. if isinstance(other, int):
  216. p = self.numerator + self.denominator * other
  217. q = self.denominator
  218. return self._new(p, q)
  219. else:
  220. return NotImplemented
  221. def __sub__(self ,other):
  222. """q1 - q2"""
  223. if isinstance(other, PythonMPQ):
  224. ap, aq = self.numerator, self.denominator
  225. bp, bq = other.numerator, other.denominator
  226. g = gcd(aq, bq)
  227. if g == 1:
  228. p = ap*bq - aq*bp
  229. q = bq*aq
  230. else:
  231. q1, q2 = aq//g, bq//g
  232. p, q = ap*q2 - bp*q1, q1*q2
  233. g2 = gcd(p, g)
  234. p, q = (p // g2), q * (g // g2)
  235. elif isinstance(other, int):
  236. p = self.numerator - self.denominator*other
  237. q = self.denominator
  238. else:
  239. return NotImplemented
  240. return self._new(p, q)
  241. def __rsub__(self, other):
  242. """z1 - q2"""
  243. if isinstance(other, int):
  244. p = self.denominator * other - self.numerator
  245. q = self.denominator
  246. return self._new(p, q)
  247. else:
  248. return NotImplemented
  249. def __mul__(self, other):
  250. """q1 * q2"""
  251. if isinstance(other, PythonMPQ):
  252. ap, aq = self.numerator, self.denominator
  253. bp, bq = other.numerator, other.denominator
  254. x1 = gcd(ap, bq)
  255. x2 = gcd(bp, aq)
  256. p, q = ((ap//x1)*(bp//x2), (aq//x2)*(bq//x1))
  257. elif isinstance(other, int):
  258. x = gcd(other, self.denominator)
  259. p = self.numerator*(other//x)
  260. q = self.denominator//x
  261. else:
  262. return NotImplemented
  263. return self._new(p, q)
  264. def __rmul__(self, other):
  265. """z1 * q2"""
  266. if isinstance(other, int):
  267. x = gcd(self.denominator, other)
  268. p = self.numerator*(other//x)
  269. q = self.denominator//x
  270. return self._new(p, q)
  271. else:
  272. return NotImplemented
  273. def __pow__(self, exp):
  274. """q ** z"""
  275. p, q = self.numerator, self.denominator
  276. if exp < 0:
  277. p, q, exp = q, p, -exp
  278. return self._new_check(p**exp, q**exp)
  279. def __truediv__(self, other):
  280. """q1 / q2"""
  281. if isinstance(other, PythonMPQ):
  282. ap, aq = self.numerator, self.denominator
  283. bp, bq = other.numerator, other.denominator
  284. x1 = gcd(ap, bp)
  285. x2 = gcd(bq, aq)
  286. p, q = ((ap//x1)*(bq//x2), (aq//x2)*(bp//x1))
  287. elif isinstance(other, int):
  288. x = gcd(other, self.numerator)
  289. p = self.numerator//x
  290. q = self.denominator*(other//x)
  291. else:
  292. return NotImplemented
  293. return self._new_check(p, q)
  294. def __rtruediv__(self, other):
  295. """z / q"""
  296. if isinstance(other, int):
  297. x = gcd(self.numerator, other)
  298. p = self.denominator*(other//x)
  299. q = self.numerator//x
  300. return self._new_check(p, q)
  301. else:
  302. return NotImplemented
  303. _compatible_types: tTuple[Type, ...] = ()
  304. #
  305. # These are the types that PythonMPQ will interoperate with for operations
  306. # and comparisons such as ==, + etc. We define this down here so that we can
  307. # include PythonMPQ in the list as well.
  308. #
  309. PythonMPQ._compatible_types = (PythonMPQ, int, Decimal, Fraction)