libintmath.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. """
  2. Utility functions for integer math.
  3. TODO: rename, cleanup, perhaps move the gmpy wrapper code
  4. here from settings.py
  5. """
  6. import math
  7. from bisect import bisect
  8. from .backend import xrange
  9. from .backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO
  10. small_trailing = [0] * 256
  11. for j in range(1,8):
  12. small_trailing[1<<j::1<<(j+1)] = [j] * (1<<(7-j))
  13. def giant_steps(start, target, n=2):
  14. """
  15. Return a list of integers ~=
  16. [start, n*start, ..., target/n^2, target/n, target]
  17. but conservatively rounded so that the quotient between two
  18. successive elements is actually slightly less than n.
  19. With n = 2, this describes suitable precision steps for a
  20. quadratically convergent algorithm such as Newton's method;
  21. with n = 3 steps for cubic convergence (Halley's method), etc.
  22. >>> giant_steps(50,1000)
  23. [66, 128, 253, 502, 1000]
  24. >>> giant_steps(50,1000,4)
  25. [65, 252, 1000]
  26. """
  27. L = [target]
  28. while L[-1] > start*n:
  29. L = L + [L[-1]//n + 2]
  30. return L[::-1]
  31. def rshift(x, n):
  32. """For an integer x, calculate x >> n with the fastest (floor)
  33. rounding. Unlike the plain Python expression (x >> n), n is
  34. allowed to be negative, in which case a left shift is performed."""
  35. if n >= 0: return x >> n
  36. else: return x << (-n)
  37. def lshift(x, n):
  38. """For an integer x, calculate x << n. Unlike the plain Python
  39. expression (x << n), n is allowed to be negative, in which case a
  40. right shift with default (floor) rounding is performed."""
  41. if n >= 0: return x << n
  42. else: return x >> (-n)
  43. if BACKEND == 'sage':
  44. import operator
  45. rshift = operator.rshift
  46. lshift = operator.lshift
  47. def python_trailing(n):
  48. """Count the number of trailing zero bits in abs(n)."""
  49. if not n:
  50. return 0
  51. low_byte = n & 0xff
  52. if low_byte:
  53. return small_trailing[low_byte]
  54. t = 8
  55. n >>= 8
  56. while not n & 0xff:
  57. n >>= 8
  58. t += 8
  59. return t + small_trailing[n & 0xff]
  60. if BACKEND == 'gmpy':
  61. if gmpy.version() >= '2':
  62. def gmpy_trailing(n):
  63. """Count the number of trailing zero bits in abs(n) using gmpy."""
  64. if n: return MPZ(n).bit_scan1()
  65. else: return 0
  66. else:
  67. def gmpy_trailing(n):
  68. """Count the number of trailing zero bits in abs(n) using gmpy."""
  69. if n: return MPZ(n).scan1()
  70. else: return 0
  71. # Small powers of 2
  72. powers = [1<<_ for _ in range(300)]
  73. def python_bitcount(n):
  74. """Calculate bit size of the nonnegative integer n."""
  75. bc = bisect(powers, n)
  76. if bc != 300:
  77. return bc
  78. bc = int(math.log(n, 2)) - 4
  79. return bc + bctable[n>>bc]
  80. def gmpy_bitcount(n):
  81. """Calculate bit size of the nonnegative integer n."""
  82. if n: return MPZ(n).numdigits(2)
  83. else: return 0
  84. #def sage_bitcount(n):
  85. # if n: return MPZ(n).nbits()
  86. # else: return 0
  87. def sage_trailing(n):
  88. return MPZ(n).trailing_zero_bits()
  89. if BACKEND == 'gmpy':
  90. bitcount = gmpy_bitcount
  91. trailing = gmpy_trailing
  92. elif BACKEND == 'sage':
  93. sage_bitcount = sage_utils.bitcount
  94. bitcount = sage_bitcount
  95. trailing = sage_trailing
  96. else:
  97. bitcount = python_bitcount
  98. trailing = python_trailing
  99. if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy):
  100. bitcount = gmpy.bit_length
  101. # Used to avoid slow function calls as far as possible
  102. trailtable = [trailing(n) for n in range(256)]
  103. bctable = [bitcount(n) for n in range(1024)]
  104. # TODO: speed up for bases 2, 4, 8, 16, ...
  105. def bin_to_radix(x, xbits, base, bdigits):
  106. """Changes radix of a fixed-point number; i.e., converts
  107. x * 2**xbits to floor(x * 10**bdigits)."""
  108. return x * (MPZ(base)**bdigits) >> xbits
  109. stddigits = '0123456789abcdefghijklmnopqrstuvwxyz'
  110. def small_numeral(n, base=10, digits=stddigits):
  111. """Return the string numeral of a positive integer in an arbitrary
  112. base. Most efficient for small input."""
  113. if base == 10:
  114. return str(n)
  115. digs = []
  116. while n:
  117. n, digit = divmod(n, base)
  118. digs.append(digits[digit])
  119. return "".join(digs[::-1])
  120. def numeral_python(n, base=10, size=0, digits=stddigits):
  121. """Represent the integer n as a string of digits in the given base.
  122. Recursive division is used to make this function about 3x faster
  123. than Python's str() for converting integers to decimal strings.
  124. The 'size' parameters specifies the number of digits in n; this
  125. number is only used to determine splitting points and need not be
  126. exact."""
  127. if n <= 0:
  128. if not n:
  129. return "0"
  130. return "-" + numeral(-n, base, size, digits)
  131. # Fast enough to do directly
  132. if size < 250:
  133. return small_numeral(n, base, digits)
  134. # Divide in half
  135. half = (size // 2) + (size & 1)
  136. A, B = divmod(n, base**half)
  137. ad = numeral(A, base, half, digits)
  138. bd = numeral(B, base, half, digits).rjust(half, "0")
  139. return ad + bd
  140. def numeral_gmpy(n, base=10, size=0, digits=stddigits):
  141. """Represent the integer n as a string of digits in the given base.
  142. Recursive division is used to make this function about 3x faster
  143. than Python's str() for converting integers to decimal strings.
  144. The 'size' parameters specifies the number of digits in n; this
  145. number is only used to determine splitting points and need not be
  146. exact."""
  147. if n < 0:
  148. return "-" + numeral(-n, base, size, digits)
  149. # gmpy.digits() may cause a segmentation fault when trying to convert
  150. # extremely large values to a string. The size limit may need to be
  151. # adjusted on some platforms, but 1500000 works on Windows and Linux.
  152. if size < 1500000:
  153. return gmpy.digits(n, base)
  154. # Divide in half
  155. half = (size // 2) + (size & 1)
  156. A, B = divmod(n, MPZ(base)**half)
  157. ad = numeral(A, base, half, digits)
  158. bd = numeral(B, base, half, digits).rjust(half, "0")
  159. return ad + bd
  160. if BACKEND == "gmpy":
  161. numeral = numeral_gmpy
  162. else:
  163. numeral = numeral_python
  164. _1_800 = 1<<800
  165. _1_600 = 1<<600
  166. _1_400 = 1<<400
  167. _1_200 = 1<<200
  168. _1_100 = 1<<100
  169. _1_50 = 1<<50
  170. def isqrt_small_python(x):
  171. """
  172. Correctly (floor) rounded integer square root, using
  173. division. Fast up to ~200 digits.
  174. """
  175. if not x:
  176. return x
  177. if x < _1_800:
  178. # Exact with IEEE double precision arithmetic
  179. if x < _1_50:
  180. return int(x**0.5)
  181. # Initial estimate can be any integer >= the true root; round up
  182. r = int(x**0.5 * 1.00000000000001) + 1
  183. else:
  184. bc = bitcount(x)
  185. n = bc//2
  186. r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up
  187. # The following iteration now precisely computes floor(sqrt(x))
  188. # See e.g. Crandall & Pomerance, "Prime Numbers: A Computational
  189. # Perspective"
  190. while 1:
  191. y = (r+x//r)>>1
  192. if y >= r:
  193. return r
  194. r = y
def isqrt_fast_python(x):
    """
    Fast approximate integer square root, computed using division-free
    Newton iteration for large x. For random integers the result is almost
    always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly
    0.1% probability. If x is very close to an exact square, the answer is
    1 ulp wrong with high probability.

    With 0 guard bits, the largest error over a set of 10^5 random
    inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits
    almost certainly guarantees a max 1 ulp error.

    Strategy:

    * x < 2**800: a float square root, followed by one division-based
      Newton step y <- (y + x//y) >> 1 for each threshold (2**100,
      2**200, 2**400) that x reaches (0 to 3 steps in total).
    * otherwise: fixed-point Newton iteration for r ~ 1/sqrt(x),
      r <- r*(3 - x*r**2)/2, whose working precision p grows along
      giant_steps(startprec, hbc). The result is then sqrt(x) = x * r.

    x is assumed to be a nonnegative integer. x == 0 returns 0 through
    the float branch.
    """
    # Use direct division-based iteration if sqrt(x) < 2^400
    # Assume floating-point square root accurate to within 1 ulp, then:
    # 0 Newton iterations good to 52 bits
    # 1 Newton iterations good to 104 bits
    # 2 Newton iterations good to 208 bits
    # 3 Newton iterations good to 416 bits
    if x < _1_800:
        y = int(x**0.5)
        if x >= _1_100:
            y = (y + x//y) >> 1
            if x >= _1_200:
                y = (y + x//y) >> 1
                if x >= _1_400:
                    y = (y + x//y) >> 1
        return y
    bc = bitcount(x)
    guard_bits = 10
    # Multiplying x by 4**guard_bits multiplies sqrt(x) by 2**guard_bits;
    # the extra bits are dropped again by the final shift.
    x <<= 2*guard_bits
    bc += 2*guard_bits
    # Round bc up to even so that hbc = bc//2 is exactly half of it.
    bc += (bc&1)
    hbc = bc//2
    # The float starting value carries at most ~50 significant bits.
    startprec = min(50, hbc)
    # Newton iteration for 1/sqrt(x), with floating-point starting value
    r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5)
    # pp is the precision r currently carries; p is the target precision.
    pp = startprec
    for p in giant_steps(startprec, hbc):
        # r**2, scaled from real size 2**(-bc) to 2**p
        r2 = (r*r) >> (2*pp - p)
        # x*r**2, scaled from real size ~1.0 to 2**p
        xr2 = ((x >> (bc-p)) * r2) >> p
        # New value of r, scaled from real size 2**(-bc/2) to 2**p
        # (computes r*(3 - x*r**2)/2; the "+1" in the shift is the /2)
        r = (r * ((3<<p) - xr2)) >> (pp+1)
        pp = p
    # (1/sqrt(x))*x = sqrt(x)
    # p is the last giant step (== hbc); also strip the guard bits here.
    return (r*(x>>hbc)) >> (p+guard_bits)
  241. def sqrtrem_python(x):
  242. """Correctly rounded integer (floor) square root with remainder."""
  243. # to check cutoff:
  244. # plot(lambda x: timing(isqrt, 2**int(x)), [0,2000])
  245. if x < _1_600:
  246. y = isqrt_small_python(x)
  247. return y, x - y*y
  248. y = isqrt_fast_python(x) + 1
  249. rem = x - y*y
  250. # Correct remainder
  251. while rem < 0:
  252. y -= 1
  253. rem += (1+2*y)
  254. else:
  255. if rem:
  256. while rem > 2*(1+y):
  257. y += 1
  258. rem -= (1+2*y)
  259. return y, rem
  260. def isqrt_python(x):
  261. """Integer square root with correct (floor) rounding."""
  262. return sqrtrem_python(x)[0]
  263. def sqrt_fixed(x, prec):
  264. return isqrt_fast(x<<prec)
  265. sqrt_fixed2 = sqrt_fixed
  266. if BACKEND == 'gmpy':
  267. if gmpy.version() >= '2':
  268. isqrt_small = isqrt_fast = isqrt = gmpy.isqrt
  269. sqrtrem = gmpy.isqrt_rem
  270. else:
  271. isqrt_small = isqrt_fast = isqrt = gmpy.sqrt
  272. sqrtrem = gmpy.sqrtrem
  273. elif BACKEND == 'sage':
  274. isqrt_small = isqrt_fast = isqrt = \
  275. getattr(sage_utils, "isqrt", lambda n: MPZ(n).isqrt())
  276. sqrtrem = lambda n: MPZ(n).sqrtrem()
  277. else:
  278. isqrt_small = isqrt_small_python
  279. isqrt_fast = isqrt_fast_python
  280. isqrt = isqrt_python
  281. sqrtrem = sqrtrem_python
  282. def ifib(n, _cache={}):
  283. """Computes the nth Fibonacci number as an integer, for
  284. integer n."""
  285. if n < 0:
  286. return (-1)**(-n+1) * ifib(-n)
  287. if n in _cache:
  288. return _cache[n]
  289. m = n
  290. # Use Dijkstra's logarithmic algorithm
  291. # The following implementation is basically equivalent to
  292. # http://en.literateprograms.org/Fibonacci_numbers_(Scheme)
  293. a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE
  294. while n:
  295. if n & 1:
  296. aq = a*q
  297. a, b = b*q+aq+a*p, b*p+aq
  298. n -= 1
  299. else:
  300. qq = q*q
  301. p, q = p*p+qq, qq+2*p*q
  302. n >>= 1
  303. if m < 250:
  304. _cache[m] = b
  305. return b
  306. MAX_FACTORIAL_CACHE = 1000
  307. def ifac(n, memo={0:1, 1:1}):
  308. """Return n factorial (for integers n >= 0 only)."""
  309. f = memo.get(n)
  310. if f:
  311. return f
  312. k = len(memo)
  313. p = memo[k-1]
  314. MAX = MAX_FACTORIAL_CACHE
  315. while k <= n:
  316. p *= k
  317. if k <= MAX:
  318. memo[k] = p
  319. k += 1
  320. return p
  321. def ifac2(n, memo_pair=[{0:1}, {1:1}]):
  322. """Return n!! (double factorial), integers n >= 0 only."""
  323. memo = memo_pair[n&1]
  324. f = memo.get(n)
  325. if f:
  326. return f
  327. k = max(memo)
  328. p = memo[k]
  329. MAX = MAX_FACTORIAL_CACHE
  330. while k < n:
  331. k += 2
  332. p *= k
  333. if k <= MAX:
  334. memo[k] = p
  335. return p
  336. if BACKEND == 'gmpy':
  337. ifac = gmpy.fac
  338. elif BACKEND == 'sage':
  339. ifac = lambda n: int(sage.factorial(n))
  340. ifib = sage.fibonacci
  341. def list_primes(n):
  342. n = n + 1
  343. sieve = list(xrange(n))
  344. sieve[:2] = [0, 0]
  345. for i in xrange(2, int(n**0.5)+1):
  346. if sieve[i]:
  347. for j in xrange(i**2, n, i):
  348. sieve[j] = 0
  349. return [p for p in sieve if p]
  350. if BACKEND == 'sage':
  351. # Note: it is *VERY* important for performance that we convert
  352. # the list to Python ints.
  353. def list_primes(n):
  354. return [int(_) for _ in sage.primes(n+1)]
  355. small_odd_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47)
  356. small_odd_primes_set = set(small_odd_primes)
  357. def isprime(n):
  358. """
  359. Determines whether n is a prime number. A probabilistic test is
  360. performed if n is very large. No special trick is used for detecting
  361. perfect powers.
  362. >>> sum(list_primes(100000))
  363. 454396537
  364. >>> sum(n*isprime(n) for n in range(100000))
  365. 454396537
  366. """
  367. n = int(n)
  368. if not n & 1:
  369. return n == 2
  370. if n < 50:
  371. return n in small_odd_primes_set
  372. for p in small_odd_primes:
  373. if not n % p:
  374. return False
  375. m = n-1
  376. s = trailing(m)
  377. d = m >> s
  378. def test(a):
  379. x = pow(a,d,n)
  380. if x == 1 or x == m:
  381. return True
  382. for r in xrange(1,s):
  383. x = x**2 % n
  384. if x == m:
  385. return True
  386. return False
  387. # See http://primes.utm.edu/prove/prove2_3.html
  388. if n < 1373653:
  389. witnesses = [2,3]
  390. elif n < 341550071728321:
  391. witnesses = [2,3,5,7,11,13,17]
  392. else:
  393. witnesses = small_odd_primes
  394. for a in witnesses:
  395. if not test(a):
  396. return False
  397. return True
  398. def moebius(n):
  399. """
  400. Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n`
  401. is a product of `k` distinct primes and `mu(n) = 0` otherwise.
  402. TODO: speed up using factorization
  403. """
  404. n = abs(int(n))
  405. if n < 2:
  406. return n
  407. factors = []
  408. for p in xrange(2, n+1):
  409. if not (n % p):
  410. if not (n % p**2):
  411. return 0
  412. if not sum(p % f for f in factors):
  413. factors.append(p)
  414. return (-1)**len(factors)
  415. def gcd(*args):
  416. a = 0
  417. for b in args:
  418. if a:
  419. while b:
  420. a, b = b, a % b
  421. else:
  422. a = b
  423. return a
# Comment by Juan Arias de Reyna:
#
# I learned this method of computing EulerE[2n] from van de Lune.
#
# We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1)
#
# where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfy
#
# a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0
#
# a(n,j) = a(n-1,j) when n+j is even
# a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd
#
#
# But we can use a single one-dimensional array a(j): to compute
# a(n,j) we only need a(n-1,k) where k and j have different parity,
# and we do not need to keep the values once they have been used.
#
# We cache the values of the Euler numbers up to a sufficiently high order.
#
# Important observation: if we intend to use the numbers
# EulerE[1], EulerE[2], ... , EulerE[n]
# it is convenient to compute EulerE[n] first, since the algorithm
# computes all the previous ones first and keeps them in the cache.
# Largest index n for which eulernum stores E(n) in its cache.
MAX_EULER_CACHE = 500

def eulernum(m, _cache={0:MPZ_ONE}):
    r"""
    Computes the Euler numbers `E(n)`, which can be defined as
    coefficients of the Taylor expansion of `1/cosh x`:

    .. math ::

        \frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n

    Example::

        >>> [int(eulernum(n)) for n in range(11)]
        [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]
        >>> [int(eulernum(n)) for n in range(11)]   # test cache
        [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521]

    m is expected to be a nonnegative integer. For a negative even m
    the loop below does not execute, so the function returns None.

    On a cache miss, every index 1..m is computed in a single pass, and
    each index n <= MAX_EULER_CACHE is stored in the shared _cache.
    """
    # for odd m (including m = 1), the Euler numbers are zero
    if m & 1:
        return MPZ_ZERO
    f = _cache.get(m)
    if f:
        return f
    MAX = MAX_EULER_CACHE
    n = m
    # a[j+1] holds a(n, j) of the van de Lune recurrence (index offset by
    # one so that j = -1 maps to a[0]). The initial row is a(0, 1) = 1.
    a = [MPZ(_) for _ in [0,0,1,0,0,0]]
    for n in range(1, m+1):
        # Update only the j with n+j odd, in place; the entries read have
        # the other parity and still hold row n-1.
        for j in range(n+1, -1, -2):
            a[j+1] = (j-1)*a[j] + (j+1)*a[j+2]
        # Grow the array so index j+2 stays valid at the next row; the
        # newly appended tail entries are zero.
        a.append(0)
        suma = 0
        for k in range(n+1, -1, -2):
            suma += a[k+1]
        # NOTE: odd n are cached too; those entries are never read back,
        # because odd m returns early above.
        if n <= MAX:
            _cache[n] = ((-1)**(n//2))*(suma // 2**n)
        if n == m:
            return ((-1)**(n//2))*suma // 2**n
  482. def stirling1(n, k):
  483. """
  484. Stirling number of the first kind.
  485. """
  486. if n < 0 or k < 0:
  487. raise ValueError
  488. if k >= n:
  489. return MPZ(n == k)
  490. if k < 1:
  491. return MPZ_ZERO
  492. L = [MPZ_ZERO] * (k+1)
  493. L[1] = MPZ_ONE
  494. for m in xrange(2, n+1):
  495. for j in xrange(min(k, m), 0, -1):
  496. L[j] = (m-1) * L[j] + L[j-1]
  497. return (-1)**(n+k) * L[k]
  498. def stirling2(n, k):
  499. """
  500. Stirling number of the second kind.
  501. """
  502. if n < 0 or k < 0:
  503. raise ValueError
  504. if k >= n:
  505. return MPZ(n == k)
  506. if k <= 1:
  507. return MPZ(k == 1)
  508. s = MPZ_ZERO
  509. t = MPZ_ONE
  510. for j in xrange(k+1):
  511. if (k + j) & 1:
  512. s -= t * MPZ(j)**n
  513. else:
  514. s += t * MPZ(j)**n
  515. t = t * (k - j) // (j + 1)
  516. return s // ifac(k)