assume.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. """A module which implements predicates and assumption context."""
  2. from contextlib import contextmanager
  3. import inspect
  4. from sympy.core.assumptions import ManagedProperties
  5. from sympy.core.symbol import Str
  6. from sympy.core.sympify import _sympify
  7. from sympy.logic.boolalg import Boolean, false, true
  8. from sympy.multipledispatch.dispatcher import Dispatcher, str_signature
  9. from sympy.utilities.exceptions import sympy_deprecation_warning
  10. from sympy.utilities.iterables import is_sequence
  11. from sympy.utilities.source import get_class
  12. class AssumptionsContext(set):
  13. """
  14. Set containing default assumptions which are applied to the ``ask()``
  15. function.
  16. Explanation
  17. ===========
  18. This is used to represent global assumptions, but you can also use this
  19. class to create your own local assumptions contexts. It is basically a thin
  20. wrapper to Python's set, so see its documentation for advanced usage.
  21. Examples
  22. ========
  23. The default assumption context is ``global_assumptions``, which is initially empty:
  24. >>> from sympy import ask, Q
  25. >>> from sympy.assumptions import global_assumptions
  26. >>> global_assumptions
  27. AssumptionsContext()
  28. You can add default assumptions:
  29. >>> from sympy.abc import x
  30. >>> global_assumptions.add(Q.real(x))
  31. >>> global_assumptions
  32. AssumptionsContext({Q.real(x)})
  33. >>> ask(Q.real(x))
  34. True
  35. And remove them:
  36. >>> global_assumptions.remove(Q.real(x))
  37. >>> print(ask(Q.real(x)))
  38. None
  39. The ``clear()`` method removes every assumption:
  40. >>> global_assumptions.add(Q.positive(x))
  41. >>> global_assumptions
  42. AssumptionsContext({Q.positive(x)})
  43. >>> global_assumptions.clear()
  44. >>> global_assumptions
  45. AssumptionsContext()
  46. See Also
  47. ========
  48. assuming
  49. """
  50. def add(self, *assumptions):
  51. """Add assumptions."""
  52. for a in assumptions:
  53. super().add(a)
  54. def _sympystr(self, printer):
  55. if not self:
  56. return "%s()" % self.__class__.__name__
  57. return "{}({})".format(self.__class__.__name__, printer._print_set(self))
  58. global_assumptions = AssumptionsContext()
  59. class AppliedPredicate(Boolean):
  60. """
  61. The class of expressions resulting from applying ``Predicate`` to
  62. the arguments. ``AppliedPredicate`` merely wraps its argument and
  63. remain unevaluated. To evaluate it, use the ``ask()`` function.
  64. Examples
  65. ========
  66. >>> from sympy import Q, ask
  67. >>> Q.integer(1)
  68. Q.integer(1)
  69. The ``function`` attribute returns the predicate, and the ``arguments``
  70. attribute returns the tuple of arguments.
  71. >>> type(Q.integer(1))
  72. <class 'sympy.assumptions.assume.AppliedPredicate'>
  73. >>> Q.integer(1).function
  74. Q.integer
  75. >>> Q.integer(1).arguments
  76. (1,)
  77. Applied predicates can be evaluated to a boolean value with ``ask``:
  78. >>> ask(Q.integer(1))
  79. True
  80. """
  81. __slots__ = ()
  82. def __new__(cls, predicate, *args):
  83. if not isinstance(predicate, Predicate):
  84. raise TypeError("%s is not a Predicate." % predicate)
  85. args = map(_sympify, args)
  86. return super().__new__(cls, predicate, *args)
  87. @property
  88. def arg(self):
  89. """
  90. Return the expression used by this assumption.
  91. Examples
  92. ========
  93. >>> from sympy import Q, Symbol
  94. >>> x = Symbol('x')
  95. >>> a = Q.integer(x + 1)
  96. >>> a.arg
  97. x + 1
  98. """
  99. # Will be deprecated
  100. args = self._args
  101. if len(args) == 2:
  102. # backwards compatibility
  103. return args[1]
  104. raise TypeError("'arg' property is allowed only for unary predicates.")
  105. @property
  106. def function(self):
  107. """
  108. Return the predicate.
  109. """
  110. # Will be changed to self.args[0] after args overridding is removed
  111. return self._args[0]
  112. @property
  113. def arguments(self):
  114. """
  115. Return the arguments which are applied to the predicate.
  116. """
  117. # Will be changed to self.args[1:] after args overridding is removed
  118. return self._args[1:]
  119. def _eval_ask(self, assumptions):
  120. return self.function.eval(self.arguments, assumptions)
  121. @property
  122. def binary_symbols(self):
  123. from .ask import Q
  124. if self.function == Q.is_true:
  125. i = self.arguments[0]
  126. if i.is_Boolean or i.is_Symbol:
  127. return i.binary_symbols
  128. if self.function in (Q.eq, Q.ne):
  129. if true in self.arguments or false in self.arguments:
  130. if self.arguments[0].is_Symbol:
  131. return {self.arguments[0]}
  132. elif self.arguments[1].is_Symbol:
  133. return {self.arguments[1]}
  134. return set()
  135. class PredicateMeta(ManagedProperties):
  136. def __new__(cls, clsname, bases, dct):
  137. # If handler is not defined, assign empty dispatcher.
  138. if "handler" not in dct:
  139. name = f"Ask{clsname.capitalize()}Handler"
  140. handler = Dispatcher(name, doc="Handler for key %s" % name)
  141. dct["handler"] = handler
  142. dct["_orig_doc"] = dct.get("__doc__", "")
  143. return super().__new__(cls, clsname, bases, dct)
  144. @property
  145. def __doc__(cls):
  146. handler = cls.handler
  147. doc = cls._orig_doc
  148. if cls is not Predicate and handler is not None:
  149. doc += "Handler\n"
  150. doc += " =======\n\n"
  151. # Append the handler's doc without breaking sphinx documentation.
  152. docs = [" Multiply dispatched method: %s" % handler.name]
  153. if handler.doc:
  154. for line in handler.doc.splitlines():
  155. if not line:
  156. continue
  157. docs.append(" %s" % line)
  158. other = []
  159. for sig in handler.ordering[::-1]:
  160. func = handler.funcs[sig]
  161. if func.__doc__:
  162. s = ' Inputs: <%s>' % str_signature(sig)
  163. lines = []
  164. for line in func.__doc__.splitlines():
  165. lines.append(" %s" % line)
  166. s += "\n".join(lines)
  167. docs.append(s)
  168. else:
  169. other.append(str_signature(sig))
  170. if other:
  171. othersig = " Other signatures:"
  172. for line in other:
  173. othersig += "\n * %s" % line
  174. docs.append(othersig)
  175. doc += '\n\n'.join(docs)
  176. return doc
  177. class Predicate(Boolean, metaclass=PredicateMeta):
  178. """
  179. Base class for mathematical predicates. It also serves as a
  180. constructor for undefined predicate objects.
  181. Explanation
  182. ===========
  183. Predicate is a function that returns a boolean value [1].
  184. Predicate function is object, and it is instance of predicate class.
  185. When a predicate is applied to arguments, ``AppliedPredicate``
  186. instance is returned. This merely wraps the argument and remain
  187. unevaluated. To obtain the truth value of applied predicate, use the
  188. function ``ask``.
  189. Evaluation of predicate is done by multiple dispatching. You can
  190. register new handler to the predicate to support new types.
  191. Every predicate in SymPy can be accessed via the property of ``Q``.
  192. For example, ``Q.even`` returns the predicate which checks if the
  193. argument is even number.
  194. To define a predicate which can be evaluated, you must subclass this
  195. class, make an instance of it, and register it to ``Q``. After then,
  196. dispatch the handler by argument types.
  197. If you directly construct predicate using this class, you will get
  198. ``UndefinedPredicate`` which cannot be dispatched. This is useful
  199. when you are building boolean expressions which do not need to be
  200. evaluated.
  201. Examples
  202. ========
  203. Applying and evaluating to boolean value:
  204. >>> from sympy import Q, ask
  205. >>> ask(Q.prime(7))
  206. True
  207. You can define a new predicate by subclassing and dispatching. Here,
  208. we define a predicate for sexy primes [2] as an example.
  209. >>> from sympy import Predicate, Integer
  210. >>> class SexyPrimePredicate(Predicate):
  211. ... name = "sexyprime"
  212. >>> Q.sexyprime = SexyPrimePredicate()
  213. >>> @Q.sexyprime.register(Integer, Integer)
  214. ... def _(int1, int2, assumptions):
  215. ... args = sorted([int1, int2])
  216. ... if not all(ask(Q.prime(a), assumptions) for a in args):
  217. ... return False
  218. ... return args[1] - args[0] == 6
  219. >>> ask(Q.sexyprime(5, 11))
  220. True
  221. Direct constructing returns ``UndefinedPredicate``, which can be
  222. applied but cannot be dispatched.
  223. >>> from sympy import Predicate, Integer
  224. >>> Q.P = Predicate("P")
  225. >>> type(Q.P)
  226. <class 'sympy.assumptions.assume.UndefinedPredicate'>
  227. >>> Q.P(1)
  228. Q.P(1)
  229. >>> Q.P.register(Integer)(lambda expr, assump: True)
  230. Traceback (most recent call last):
  231. ...
  232. TypeError: <class 'sympy.assumptions.assume.UndefinedPredicate'> cannot be dispatched.
  233. References
  234. ==========
  235. .. [1] https://en.wikipedia.org/wiki/Predicate_(mathematical_logic)
  236. .. [2] https://en.wikipedia.org/wiki/Sexy_prime
  237. """
  238. is_Atom = True
  239. def __new__(cls, *args, **kwargs):
  240. if cls is Predicate:
  241. return UndefinedPredicate(*args, **kwargs)
  242. obj = super().__new__(cls, *args)
  243. return obj
  244. @property
  245. def name(self):
  246. # May be overridden
  247. return type(self).__name__
  248. @classmethod
  249. def register(cls, *types, **kwargs):
  250. """
  251. Register the signature to the handler.
  252. """
  253. if cls.handler is None:
  254. raise TypeError("%s cannot be dispatched." % type(cls))
  255. return cls.handler.register(*types, **kwargs)
  256. @classmethod
  257. def register_many(cls, *types, **kwargs):
  258. """
  259. Register multiple signatures to same handler.
  260. """
  261. def _(func):
  262. for t in types:
  263. if not is_sequence(t):
  264. t = (t,) # for convenience, allow passing `type` to mean `(type,)`
  265. cls.register(*t, **kwargs)(func)
  266. return _
  267. def __call__(self, *args):
  268. return AppliedPredicate(self, *args)
  269. def eval(self, args, assumptions=True):
  270. """
  271. Evaluate ``self(*args)`` under the given assumptions.
  272. This uses only direct resolution methods, not logical inference.
  273. """
  274. result = None
  275. try:
  276. result = self.handler(*args, assumptions=assumptions)
  277. except NotImplementedError:
  278. pass
  279. return result
  280. def _eval_refine(self, assumptions):
  281. # When Predicate is no longer Boolean, delete this method
  282. return self
  283. class UndefinedPredicate(Predicate):
  284. """
  285. Predicate without handler.
  286. Explanation
  287. ===========
  288. This predicate is generated by using ``Predicate`` directly for
  289. construction. It does not have a handler, and evaluating this with
  290. arguments is done by SAT solver.
  291. Examples
  292. ========
  293. >>> from sympy import Predicate, Q
  294. >>> Q.P = Predicate('P')
  295. >>> Q.P.func
  296. <class 'sympy.assumptions.assume.UndefinedPredicate'>
  297. >>> Q.P.name
  298. Str('P')
  299. """
  300. handler = None
  301. def __new__(cls, name, handlers=None):
  302. # "handlers" parameter supports old design
  303. if not isinstance(name, Str):
  304. name = Str(name)
  305. obj = super(Boolean, cls).__new__(cls, name)
  306. obj.handlers = handlers or []
  307. return obj
  308. @property
  309. def name(self):
  310. return self.args[0]
  311. def _hashable_content(self):
  312. return (self.name,)
  313. def __getnewargs__(self):
  314. return (self.name,)
  315. def __call__(self, expr):
  316. return AppliedPredicate(self, expr)
  317. def add_handler(self, handler):
  318. sympy_deprecation_warning(
  319. """
  320. The AskHandler system is deprecated. Predicate.add_handler()
  321. should be replaced with the multipledispatch handler of Predicate.
  322. """,
  323. deprecated_since_version="1.8",
  324. active_deprecations_target='deprecated-askhandler',
  325. )
  326. self.handlers.append(handler)
  327. def remove_handler(self, handler):
  328. sympy_deprecation_warning(
  329. """
  330. The AskHandler system is deprecated. Predicate.remove_handler()
  331. should be replaced with the multipledispatch handler of Predicate.
  332. """,
  333. deprecated_since_version="1.8",
  334. active_deprecations_target='deprecated-askhandler',
  335. )
  336. self.handlers.remove(handler)
  337. def eval(self, args, assumptions=True):
  338. # Support for deprecated design
  339. # When old design is removed, this will always return None
  340. sympy_deprecation_warning(
  341. """
  342. The AskHandler system is deprecated. Evaluating UndefinedPredicate
  343. objects should be replaced with the multipledispatch handler of
  344. Predicate.
  345. """,
  346. deprecated_since_version="1.8",
  347. active_deprecations_target='deprecated-askhandler',
  348. stacklevel=5,
  349. )
  350. expr, = args
  351. res, _res = None, None
  352. mro = inspect.getmro(type(expr))
  353. for handler in self.handlers:
  354. cls = get_class(handler)
  355. for subclass in mro:
  356. eval_ = getattr(cls, subclass.__name__, None)
  357. if eval_ is None:
  358. continue
  359. res = eval_(expr, assumptions)
  360. # Do not stop if value returned is None
  361. # Try to check for higher classes
  362. if res is None:
  363. continue
  364. if _res is None:
  365. _res = res
  366. else:
  367. # only check consistency if both resolutors have concluded
  368. if _res != res:
  369. raise ValueError('incompatible resolutors')
  370. break
  371. return res
  372. @contextmanager
  373. def assuming(*assumptions):
  374. """
  375. Context manager for assumptions.
  376. Examples
  377. ========
  378. >>> from sympy import assuming, Q, ask
  379. >>> from sympy.abc import x, y
  380. >>> print(ask(Q.integer(x + y)))
  381. None
  382. >>> with assuming(Q.integer(x), Q.integer(y)):
  383. ... print(ask(Q.integer(x + y)))
  384. True
  385. """
  386. old_global_assumptions = global_assumptions.copy()
  387. global_assumptions.update(assumptions)
  388. try:
  389. yield
  390. finally:
  391. global_assumptions.clear()
  392. global_assumptions.update(old_global_assumptions)