modular.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. from functools import reduce
  2. from sympy.core.mul import prod
  3. from sympy.core.numbers import igcdex, igcd
  4. from sympy.ntheory.primetest import isprime
  5. from sympy.polys.domains import ZZ
  6. from sympy.polys.galoistools import gf_crt, gf_crt1, gf_crt2
  7. from sympy.utilities.misc import as_int
  8. def symmetric_residue(a, m):
  9. """Return the residual mod m such that it is within half of the modulus.
  10. >>> from sympy.ntheory.modular import symmetric_residue
  11. >>> symmetric_residue(1, 6)
  12. 1
  13. >>> symmetric_residue(4, 6)
  14. -2
  15. """
  16. if a <= m // 2:
  17. return a
  18. return a - m
  19. def crt(m, v, symmetric=False, check=True):
  20. r"""Chinese Remainder Theorem.
  21. The moduli in m are assumed to be pairwise coprime. The output
  22. is then an integer f, such that f = v_i mod m_i for each pair out
  23. of v and m. If ``symmetric`` is False a positive integer will be
  24. returned, else \|f\| will be less than or equal to the LCM of the
  25. moduli, and thus f may be negative.
  26. If the moduli are not co-prime the correct result will be returned
  27. if/when the test of the result is found to be incorrect. This result
  28. will be None if there is no solution.
  29. The keyword ``check`` can be set to False if it is known that the moduli
  30. are coprime.
  31. Examples
  32. ========
  33. As an example consider a set of residues ``U = [49, 76, 65]``
  34. and a set of moduli ``M = [99, 97, 95]``. Then we have::
  35. >>> from sympy.ntheory.modular import crt
  36. >>> crt([99, 97, 95], [49, 76, 65])
  37. (639985, 912285)
  38. This is the correct result because::
  39. >>> [639985 % m for m in [99, 97, 95]]
  40. [49, 76, 65]
  41. If the moduli are not co-prime, you may receive an incorrect result
  42. if you use ``check=False``:
  43. >>> crt([12, 6, 17], [3, 4, 2], check=False)
  44. (954, 1224)
  45. >>> [954 % m for m in [12, 6, 17]]
  46. [6, 0, 2]
  47. >>> crt([12, 6, 17], [3, 4, 2]) is None
  48. True
  49. >>> crt([3, 6], [2, 5])
  50. (5, 6)
  51. Note: the order of gf_crt's arguments is reversed relative to crt,
  52. and that solve_congruence takes residue, modulus pairs.
  53. Programmer's note: rather than checking that all pairs of moduli share
  54. no GCD (an O(n**2) test) and rather than factoring all moduli and seeing
  55. that there is no factor in common, a check that the result gives the
  56. indicated residuals is performed -- an O(n) operation.
  57. See Also
  58. ========
  59. solve_congruence
  60. sympy.polys.galoistools.gf_crt : low level crt routine used by this routine
  61. """
  62. if check:
  63. m = list(map(as_int, m))
  64. v = list(map(as_int, v))
  65. result = gf_crt(v, m, ZZ)
  66. mm = prod(m)
  67. if check:
  68. if not all(v % m == result % m for v, m in zip(v, m)):
  69. result = solve_congruence(*list(zip(v, m)),
  70. check=False, symmetric=symmetric)
  71. if result is None:
  72. return result
  73. result, mm = result
  74. if symmetric:
  75. return symmetric_residue(result, mm), mm
  76. return result, mm
  77. def crt1(m):
  78. """First part of Chinese Remainder Theorem, for multiple application.
  79. Examples
  80. ========
  81. >>> from sympy.ntheory.modular import crt1
  82. >>> crt1([18, 42, 6])
  83. (4536, [252, 108, 756], [0, 2, 0])
  84. """
  85. return gf_crt1(m, ZZ)
  86. def crt2(m, v, mm, e, s, symmetric=False):
  87. """Second part of Chinese Remainder Theorem, for multiple application.
  88. Examples
  89. ========
  90. >>> from sympy.ntheory.modular import crt1, crt2
  91. >>> mm, e, s = crt1([18, 42, 6])
  92. >>> crt2([18, 42, 6], [0, 0, 0], mm, e, s)
  93. (0, 4536)
  94. """
  95. result = gf_crt2(v, m, mm, e, s, ZZ)
  96. if symmetric:
  97. return symmetric_residue(result, mm), mm
  98. return result, mm
  99. def solve_congruence(*remainder_modulus_pairs, **hint):
  100. """Compute the integer ``n`` that has the residual ``ai`` when it is
  101. divided by ``mi`` where the ``ai`` and ``mi`` are given as pairs to
  102. this function: ((a1, m1), (a2, m2), ...). If there is no solution,
  103. return None. Otherwise return ``n`` and its modulus.
  104. The ``mi`` values need not be co-prime. If it is known that the moduli are
  105. not co-prime then the hint ``check`` can be set to False (default=True) and
  106. the check for a quicker solution via crt() (valid when the moduli are
  107. co-prime) will be skipped.
  108. If the hint ``symmetric`` is True (default is False), the value of ``n``
  109. will be within 1/2 of the modulus, possibly negative.
  110. Examples
  111. ========
  112. >>> from sympy.ntheory.modular import solve_congruence
  113. What number is 2 mod 3, 3 mod 5 and 2 mod 7?
  114. >>> solve_congruence((2, 3), (3, 5), (2, 7))
  115. (23, 105)
  116. >>> [23 % m for m in [3, 5, 7]]
  117. [2, 3, 2]
  118. If you prefer to work with all remainder in one list and
  119. all moduli in another, send the arguments like this:
  120. >>> solve_congruence(*zip((2, 3, 2), (3, 5, 7)))
  121. (23, 105)
  122. The moduli need not be co-prime; in this case there may or
  123. may not be a solution:
  124. >>> solve_congruence((2, 3), (4, 6)) is None
  125. True
  126. >>> solve_congruence((2, 3), (5, 6))
  127. (5, 6)
  128. The symmetric flag will make the result be within 1/2 of the modulus:
  129. >>> solve_congruence((2, 3), (5, 6), symmetric=True)
  130. (-1, 6)
  131. See Also
  132. ========
  133. crt : high level routine implementing the Chinese Remainder Theorem
  134. """
  135. def combine(c1, c2):
  136. """Return the tuple (a, m) which satisfies the requirement
  137. that n = a + i*m satisfy n = a1 + j*m1 and n = a2 = k*m2.
  138. References
  139. ==========
  140. .. [1] https://en.wikipedia.org/wiki/Method_of_successive_substitution
  141. """
  142. a1, m1 = c1
  143. a2, m2 = c2
  144. a, b, c = m1, a2 - a1, m2
  145. g = reduce(igcd, [a, b, c])
  146. a, b, c = [i//g for i in [a, b, c]]
  147. if a != 1:
  148. inv_a, _, g = igcdex(a, c)
  149. if g != 1:
  150. return None
  151. b *= inv_a
  152. a, m = a1 + m1*b, m1*c
  153. return a, m
  154. rm = remainder_modulus_pairs
  155. symmetric = hint.get('symmetric', False)
  156. if hint.get('check', True):
  157. rm = [(as_int(r), as_int(m)) for r, m in rm]
  158. # ignore redundant pairs but raise an error otherwise; also
  159. # make sure that a unique set of bases is sent to gf_crt if
  160. # they are all prime.
  161. #
  162. # The routine will work out less-trivial violations and
  163. # return None, e.g. for the pairs (1,3) and (14,42) there
  164. # is no answer because 14 mod 42 (having a gcd of 14) implies
  165. # (14/2) mod (42/2), (14/7) mod (42/7) and (14/14) mod (42/14)
  166. # which, being 0 mod 3, is inconsistent with 1 mod 3. But to
  167. # preprocess the input beyond checking of another pair with 42
  168. # or 3 as the modulus (for this example) is not necessary.
  169. uniq = {}
  170. for r, m in rm:
  171. r %= m
  172. if m in uniq:
  173. if r != uniq[m]:
  174. return None
  175. continue
  176. uniq[m] = r
  177. rm = [(r, m) for m, r in uniq.items()]
  178. del uniq
  179. # if the moduli are co-prime, the crt will be significantly faster;
  180. # checking all pairs for being co-prime gets to be slow but a prime
  181. # test is a good trade-off
  182. if all(isprime(m) for r, m in rm):
  183. r, m = list(zip(*rm))
  184. return crt(m, r, symmetric=symmetric, check=False)
  185. rv = (0, 1)
  186. for rmi in rm:
  187. rv = combine(rv, rmi)
  188. if rv is None:
  189. break
  190. n, m = rv
  191. n = n % m
  192. else:
  193. if symmetric:
  194. return symmetric_residue(n, m), m
  195. return n, m