rational.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import operator
  2. import sys
  3. from .libmp import int_types, mpf_hash, bitcount, from_man_exp, HASH_MODULUS
  4. new = object.__new__
  5. def create_reduced(p, q, _cache={}):
  6. key = p, q
  7. if key in _cache:
  8. return _cache[key]
  9. x, y = p, q
  10. while y:
  11. x, y = y, x % y
  12. if x != 1:
  13. p //= x
  14. q //= x
  15. v = new(mpq)
  16. v._mpq_ = p, q
  17. # Speedup integers, half-integers and other small fractions
  18. if q <= 4 and abs(key[0]) < 100:
  19. _cache[key] = v
  20. return v
  21. class mpq(object):
  22. """
  23. Exact rational type, currently only intended for internal use.
  24. """
  25. __slots__ = ["_mpq_"]
  26. def __new__(cls, p, q=1):
  27. if type(p) is tuple:
  28. p, q = p
  29. elif hasattr(p, '_mpq_'):
  30. p, q = p._mpq_
  31. return create_reduced(p, q)
  32. def __repr__(s):
  33. return "mpq(%s,%s)" % s._mpq_
  34. def __str__(s):
  35. return "(%s/%s)" % s._mpq_
  36. def __int__(s):
  37. a, b = s._mpq_
  38. return a // b
  39. def __nonzero__(s):
  40. return bool(s._mpq_[0])
  41. __bool__ = __nonzero__
  42. def __hash__(s):
  43. a, b = s._mpq_
  44. if sys.version_info >= (3, 2):
  45. inverse = pow(b, HASH_MODULUS-2, HASH_MODULUS)
  46. if not inverse:
  47. h = sys.hash_info.inf
  48. else:
  49. h = (abs(a) * inverse) % HASH_MODULUS
  50. if a < 0: h = -h
  51. if h == -1: h = -2
  52. return h
  53. else:
  54. if b == 1:
  55. return hash(a)
  56. # Power of two: mpf compatible hash
  57. if not (b & (b-1)):
  58. return mpf_hash(from_man_exp(a, 1-bitcount(b)))
  59. return hash((a,b))
  60. def __eq__(s, t):
  61. ttype = type(t)
  62. if ttype is mpq:
  63. return s._mpq_ == t._mpq_
  64. if ttype in int_types:
  65. a, b = s._mpq_
  66. if b != 1:
  67. return False
  68. return a == t
  69. return NotImplemented
  70. def __ne__(s, t):
  71. ttype = type(t)
  72. if ttype is mpq:
  73. return s._mpq_ != t._mpq_
  74. if ttype in int_types:
  75. a, b = s._mpq_
  76. if b != 1:
  77. return True
  78. return a != t
  79. return NotImplemented
  80. def _cmp(s, t, op):
  81. ttype = type(t)
  82. if ttype in int_types:
  83. a, b = s._mpq_
  84. return op(a, t*b)
  85. if ttype is mpq:
  86. a, b = s._mpq_
  87. c, d = t._mpq_
  88. return op(a*d, b*c)
  89. return NotImplementedError
  90. def __lt__(s, t): return s._cmp(t, operator.lt)
  91. def __le__(s, t): return s._cmp(t, operator.le)
  92. def __gt__(s, t): return s._cmp(t, operator.gt)
  93. def __ge__(s, t): return s._cmp(t, operator.ge)
  94. def __abs__(s):
  95. a, b = s._mpq_
  96. if a >= 0:
  97. return s
  98. v = new(mpq)
  99. v._mpq_ = -a, b
  100. return v
  101. def __neg__(s):
  102. a, b = s._mpq_
  103. v = new(mpq)
  104. v._mpq_ = -a, b
  105. return v
  106. def __pos__(s):
  107. return s
  108. def __add__(s, t):
  109. ttype = type(t)
  110. if ttype is mpq:
  111. a, b = s._mpq_
  112. c, d = t._mpq_
  113. return create_reduced(a*d+b*c, b*d)
  114. if ttype in int_types:
  115. a, b = s._mpq_
  116. v = new(mpq)
  117. v._mpq_ = a+b*t, b
  118. return v
  119. return NotImplemented
  120. __radd__ = __add__
  121. def __sub__(s, t):
  122. ttype = type(t)
  123. if ttype is mpq:
  124. a, b = s._mpq_
  125. c, d = t._mpq_
  126. return create_reduced(a*d-b*c, b*d)
  127. if ttype in int_types:
  128. a, b = s._mpq_
  129. v = new(mpq)
  130. v._mpq_ = a-b*t, b
  131. return v
  132. return NotImplemented
  133. def __rsub__(s, t):
  134. ttype = type(t)
  135. if ttype is mpq:
  136. a, b = s._mpq_
  137. c, d = t._mpq_
  138. return create_reduced(b*c-a*d, b*d)
  139. if ttype in int_types:
  140. a, b = s._mpq_
  141. v = new(mpq)
  142. v._mpq_ = b*t-a, b
  143. return v
  144. return NotImplemented
  145. def __mul__(s, t):
  146. ttype = type(t)
  147. if ttype is mpq:
  148. a, b = s._mpq_
  149. c, d = t._mpq_
  150. return create_reduced(a*c, b*d)
  151. if ttype in int_types:
  152. a, b = s._mpq_
  153. return create_reduced(a*t, b)
  154. return NotImplemented
  155. __rmul__ = __mul__
  156. def __div__(s, t):
  157. ttype = type(t)
  158. if ttype is mpq:
  159. a, b = s._mpq_
  160. c, d = t._mpq_
  161. return create_reduced(a*d, b*c)
  162. if ttype in int_types:
  163. a, b = s._mpq_
  164. return create_reduced(a, b*t)
  165. return NotImplemented
  166. def __rdiv__(s, t):
  167. ttype = type(t)
  168. if ttype is mpq:
  169. a, b = s._mpq_
  170. c, d = t._mpq_
  171. return create_reduced(b*c, a*d)
  172. if ttype in int_types:
  173. a, b = s._mpq_
  174. return create_reduced(b*t, a)
  175. return NotImplemented
  176. def __pow__(s, t):
  177. ttype = type(t)
  178. if ttype in int_types:
  179. a, b = s._mpq_
  180. if t:
  181. if t < 0:
  182. a, b, t = b, a, -t
  183. v = new(mpq)
  184. v._mpq_ = a**t, b**t
  185. return v
  186. raise ZeroDivisionError
  187. return NotImplemented
  188. mpq_1 = mpq((1,1))
  189. mpq_0 = mpq((0,1))
  190. mpq_1_2 = mpq((1,2))
  191. mpq_3_2 = mpq((3,2))
  192. mpq_1_4 = mpq((1,4))
  193. mpq_1_16 = mpq((1,16))
  194. mpq_3_16 = mpq((3,16))
  195. mpq_5_2 = mpq((5,2))
  196. mpq_3_4 = mpq((3,4))
  197. mpq_7_4 = mpq((7,4))
  198. mpq_5_4 = mpq((5,4))
  199. # Register with "numbers" ABC
  200. # We do not subclass, hence we do not use the @abstractmethod checks. While
  201. # this is less invasive it may turn out that we do not actually support
  202. # parts of the expected interfaces. See
  203. # http://docs.python.org/2/library/numbers.html for list of abstract
  204. # methods.
  205. try:
  206. import numbers
  207. numbers.Rational.register(mpq)
  208. except ImportError:
  209. pass