modularinteger.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. """Implementation of :class:`ModularInteger` class. """
  2. from typing import Any, Dict as tDict, Tuple as tTuple, Type
  3. import operator
  4. from sympy.polys.polyutils import PicklableWithSlots
  5. from sympy.polys.polyerrors import CoercionFailed
  6. from sympy.polys.domains.domainelement import DomainElement
  7. from sympy.utilities import public
  8. @public
  9. class ModularInteger(PicklableWithSlots, DomainElement):
  10. """A class representing a modular integer. """
  11. mod, dom, sym, _parent = None, None, None, None
  12. __slots__ = ('val',)
  13. def parent(self):
  14. return self._parent
  15. def __init__(self, val):
  16. if isinstance(val, self.__class__):
  17. self.val = val.val % self.mod
  18. else:
  19. self.val = self.dom.convert(val) % self.mod
  20. def __hash__(self):
  21. return hash((self.val, self.mod))
  22. def __repr__(self):
  23. return "%s(%s)" % (self.__class__.__name__, self.val)
  24. def __str__(self):
  25. return "%s mod %s" % (self.val, self.mod)
  26. def __int__(self):
  27. return int(self.to_int())
  28. def to_int(self):
  29. if self.sym:
  30. if self.val <= self.mod // 2:
  31. return self.val
  32. else:
  33. return self.val - self.mod
  34. else:
  35. return self.val
  36. def __pos__(self):
  37. return self
  38. def __neg__(self):
  39. return self.__class__(-self.val)
  40. @classmethod
  41. def _get_val(cls, other):
  42. if isinstance(other, cls):
  43. return other.val
  44. else:
  45. try:
  46. return cls.dom.convert(other)
  47. except CoercionFailed:
  48. return None
  49. def __add__(self, other):
  50. val = self._get_val(other)
  51. if val is not None:
  52. return self.__class__(self.val + val)
  53. else:
  54. return NotImplemented
  55. def __radd__(self, other):
  56. return self.__add__(other)
  57. def __sub__(self, other):
  58. val = self._get_val(other)
  59. if val is not None:
  60. return self.__class__(self.val - val)
  61. else:
  62. return NotImplemented
  63. def __rsub__(self, other):
  64. return (-self).__add__(other)
  65. def __mul__(self, other):
  66. val = self._get_val(other)
  67. if val is not None:
  68. return self.__class__(self.val * val)
  69. else:
  70. return NotImplemented
  71. def __rmul__(self, other):
  72. return self.__mul__(other)
  73. def __truediv__(self, other):
  74. val = self._get_val(other)
  75. if val is not None:
  76. return self.__class__(self.val * self._invert(val))
  77. else:
  78. return NotImplemented
  79. def __rtruediv__(self, other):
  80. return self.invert().__mul__(other)
  81. def __mod__(self, other):
  82. val = self._get_val(other)
  83. if val is not None:
  84. return self.__class__(self.val % val)
  85. else:
  86. return NotImplemented
  87. def __rmod__(self, other):
  88. val = self._get_val(other)
  89. if val is not None:
  90. return self.__class__(val % self.val)
  91. else:
  92. return NotImplemented
  93. def __pow__(self, exp):
  94. if not exp:
  95. return self.__class__(self.dom.one)
  96. if exp < 0:
  97. val, exp = self.invert().val, -exp
  98. else:
  99. val = self.val
  100. return self.__class__(pow(val, int(exp), self.mod))
  101. def _compare(self, other, op):
  102. val = self._get_val(other)
  103. if val is not None:
  104. return op(self.val, val % self.mod)
  105. else:
  106. return NotImplemented
  107. def __eq__(self, other):
  108. return self._compare(other, operator.eq)
  109. def __ne__(self, other):
  110. return self._compare(other, operator.ne)
  111. def __lt__(self, other):
  112. return self._compare(other, operator.lt)
  113. def __le__(self, other):
  114. return self._compare(other, operator.le)
  115. def __gt__(self, other):
  116. return self._compare(other, operator.gt)
  117. def __ge__(self, other):
  118. return self._compare(other, operator.ge)
  119. def __bool__(self):
  120. return bool(self.val)
  121. @classmethod
  122. def _invert(cls, value):
  123. return cls.dom.invert(value, cls.mod)
  124. def invert(self):
  125. return self.__class__(self._invert(self.val))
# Cache of the ModularInteger subclasses generated by ModularIntegerFactory,
# keyed by the (modulus, domain, symmetric-flag) tuple it builds; reusing a
# cached class means equal parameters yield the very same class object.
_modular_integer_cache = {} # type: tDict[tTuple[Any, Any, Any], Type[ModularInteger]]
  127. def ModularIntegerFactory(_mod, _dom, _sym, parent):
  128. """Create custom class for specific integer modulus."""
  129. try:
  130. _mod = _dom.convert(_mod)
  131. except CoercionFailed:
  132. ok = False
  133. else:
  134. ok = True
  135. if not ok or _mod < 1:
  136. raise ValueError("modulus must be a positive integer, got %s" % _mod)
  137. key = _mod, _dom, _sym
  138. try:
  139. cls = _modular_integer_cache[key]
  140. except KeyError:
  141. class cls(ModularInteger):
  142. mod, dom, sym = _mod, _dom, _sym
  143. _parent = parent
  144. if _sym:
  145. cls.__name__ = "SymmetricModularIntegerMod%s" % _mod
  146. else:
  147. cls.__name__ = "ModularIntegerMod%s" % _mod
  148. _modular_integer_cache[key] = cls
  149. return cls