123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- """Implementation of :class:`ModularInteger` class. """
- from typing import Any, Dict as tDict, Tuple as tTuple, Type
- import operator
- from sympy.polys.polyutils import PicklableWithSlots
- from sympy.polys.polyerrors import CoercionFailed
- from sympy.polys.domains.domainelement import DomainElement
- from sympy.utilities import public
- @public
- class ModularInteger(PicklableWithSlots, DomainElement):
- """A class representing a modular integer. """
- mod, dom, sym, _parent = None, None, None, None
- __slots__ = ('val',)
- def parent(self):
- return self._parent
- def __init__(self, val):
- if isinstance(val, self.__class__):
- self.val = val.val % self.mod
- else:
- self.val = self.dom.convert(val) % self.mod
- def __hash__(self):
- return hash((self.val, self.mod))
- def __repr__(self):
- return "%s(%s)" % (self.__class__.__name__, self.val)
- def __str__(self):
- return "%s mod %s" % (self.val, self.mod)
- def __int__(self):
- return int(self.to_int())
- def to_int(self):
- if self.sym:
- if self.val <= self.mod // 2:
- return self.val
- else:
- return self.val - self.mod
- else:
- return self.val
- def __pos__(self):
- return self
- def __neg__(self):
- return self.__class__(-self.val)
- @classmethod
- def _get_val(cls, other):
- if isinstance(other, cls):
- return other.val
- else:
- try:
- return cls.dom.convert(other)
- except CoercionFailed:
- return None
- def __add__(self, other):
- val = self._get_val(other)
- if val is not None:
- return self.__class__(self.val + val)
- else:
- return NotImplemented
- def __radd__(self, other):
- return self.__add__(other)
- def __sub__(self, other):
- val = self._get_val(other)
- if val is not None:
- return self.__class__(self.val - val)
- else:
- return NotImplemented
- def __rsub__(self, other):
- return (-self).__add__(other)
- def __mul__(self, other):
- val = self._get_val(other)
- if val is not None:
- return self.__class__(self.val * val)
- else:
- return NotImplemented
- def __rmul__(self, other):
- return self.__mul__(other)
- def __truediv__(self, other):
- val = self._get_val(other)
- if val is not None:
- return self.__class__(self.val * self._invert(val))
- else:
- return NotImplemented
- def __rtruediv__(self, other):
- return self.invert().__mul__(other)
- def __mod__(self, other):
- val = self._get_val(other)
- if val is not None:
- return self.__class__(self.val % val)
- else:
- return NotImplemented
- def __rmod__(self, other):
- val = self._get_val(other)
- if val is not None:
- return self.__class__(val % self.val)
- else:
- return NotImplemented
- def __pow__(self, exp):
- if not exp:
- return self.__class__(self.dom.one)
- if exp < 0:
- val, exp = self.invert().val, -exp
- else:
- val = self.val
- return self.__class__(pow(val, int(exp), self.mod))
- def _compare(self, other, op):
- val = self._get_val(other)
- if val is not None:
- return op(self.val, val % self.mod)
- else:
- return NotImplemented
- def __eq__(self, other):
- return self._compare(other, operator.eq)
- def __ne__(self, other):
- return self._compare(other, operator.ne)
- def __lt__(self, other):
- return self._compare(other, operator.lt)
- def __le__(self, other):
- return self._compare(other, operator.le)
- def __gt__(self, other):
- return self._compare(other, operator.gt)
- def __ge__(self, other):
- return self._compare(other, operator.ge)
- def __bool__(self):
- return bool(self.val)
- @classmethod
- def _invert(cls, value):
- return cls.dom.invert(value, cls.mod)
- def invert(self):
- return self.__class__(self._invert(self.val))
- _modular_integer_cache = {} # type: tDict[tTuple[Any, Any, Any], Type[ModularInteger]]
- def ModularIntegerFactory(_mod, _dom, _sym, parent):
- """Create custom class for specific integer modulus."""
- try:
- _mod = _dom.convert(_mod)
- except CoercionFailed:
- ok = False
- else:
- ok = True
- if not ok or _mod < 1:
- raise ValueError("modulus must be a positive integer, got %s" % _mod)
- key = _mod, _dom, _sym
- try:
- cls = _modular_integer_cache[key]
- except KeyError:
- class cls(ModularInteger):
- mod, dom, sym = _mod, _dom, _sym
- _parent = parent
- if _sym:
- cls.__name__ = "SymmetricModularIntegerMod%s" % _mod
- else:
- cls.__name__ = "ModularIntegerMod%s" % _mod
- _modular_integer_cache[key] = cls
- return cls
|