satask.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. """
  2. Module to evaluate the proposition with assumptions using SAT algorithm.
  3. """
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import Symbol
  6. from sympy.assumptions.ask_generated import get_all_known_facts
  7. from sympy.assumptions.assume import global_assumptions, AppliedPredicate
  8. from sympy.assumptions.sathandlers import class_fact_registry
  9. from sympy.core import oo
  10. from sympy.logic.inference import satisfiable
  11. from sympy.assumptions.cnf import CNF, EncodedCNF
  12. def satask(proposition, assumptions=True, context=global_assumptions,
  13. use_known_facts=True, iterations=oo):
  14. """
  15. Function to evaluate the proposition with assumptions using SAT algorithm.
  16. This function extracts every fact relevant to the expressions composing
  17. proposition and assumptions. For example, if a predicate containing
  18. ``Abs(x)`` is proposed, then ``Q.zero(Abs(x)) | Q.positive(Abs(x))``
  19. will be found and passed to SAT solver because ``Q.nonnegative`` is
  20. registered as a fact for ``Abs``.
  21. Proposition is evaluated to ``True`` or ``False`` if the truth value can be
  22. determined. If not, ``None`` is returned.
  23. Parameters
  24. ==========
  25. proposition : Any boolean expression.
  26. Proposition which will be evaluated to boolean value.
  27. assumptions : Any boolean expression, optional.
  28. Local assumptions to evaluate the *proposition*.
  29. context : AssumptionsContext, optional.
  30. Default assumptions to evaluate the *proposition*. By default,
  31. this is ``sympy.assumptions.global_assumptions`` variable.
  32. use_known_facts : bool, optional.
  33. If ``True``, facts from ``sympy.assumptions.ask_generated``
  34. module are passed to SAT solver as well.
  35. iterations : int, optional.
  36. Number of times that relevant facts are recursively extracted.
  37. Default is infinite times until no new fact is found.
  38. Returns
  39. =======
  40. ``True``, ``False``, or ``None``
  41. Examples
  42. ========
  43. >>> from sympy import Abs, Q
  44. >>> from sympy.assumptions.satask import satask
  45. >>> from sympy.abc import x
  46. >>> satask(Q.zero(Abs(x)), Q.zero(x))
  47. True
  48. """
  49. props = CNF.from_prop(proposition)
  50. _props = CNF.from_prop(~proposition)
  51. assumptions = CNF.from_prop(assumptions)
  52. context_cnf = CNF()
  53. if context:
  54. context_cnf = context_cnf.extend(context)
  55. sat = get_all_relevant_facts(props, assumptions, context_cnf,
  56. use_known_facts=use_known_facts, iterations=iterations)
  57. sat.add_from_cnf(assumptions)
  58. if context:
  59. sat.add_from_cnf(context_cnf)
  60. return check_satisfiability(props, _props, sat)
  61. def check_satisfiability(prop, _prop, factbase):
  62. sat_true = factbase.copy()
  63. sat_false = factbase.copy()
  64. sat_true.add_from_cnf(prop)
  65. sat_false.add_from_cnf(_prop)
  66. can_be_true = satisfiable(sat_true)
  67. can_be_false = satisfiable(sat_false)
  68. if can_be_true and can_be_false:
  69. return None
  70. if can_be_true and not can_be_false:
  71. return True
  72. if not can_be_true and can_be_false:
  73. return False
  74. if not can_be_true and not can_be_false:
  75. # TODO: Run additional checks to see which combination of the
  76. # assumptions, global_assumptions, and relevant_facts are
  77. # inconsistent.
  78. raise ValueError("Inconsistent assumptions")
  79. def extract_predargs(proposition, assumptions=None, context=None):
  80. """
  81. Extract every expression in the argument of predicates from *proposition*,
  82. *assumptions* and *context*.
  83. Parameters
  84. ==========
  85. proposition : sympy.assumptions.cnf.CNF
  86. assumptions : sympy.assumptions.cnf.CNF, optional.
  87. context : sympy.assumptions.cnf.CNF, optional.
  88. CNF generated from assumptions context.
  89. Examples
  90. ========
  91. >>> from sympy import Q, Abs
  92. >>> from sympy.assumptions.cnf import CNF
  93. >>> from sympy.assumptions.satask import extract_predargs
  94. >>> from sympy.abc import x, y
  95. >>> props = CNF.from_prop(Q.zero(Abs(x*y)))
  96. >>> assump = CNF.from_prop(Q.zero(x) & Q.zero(y))
  97. >>> extract_predargs(props, assump)
  98. {x, y, Abs(x*y)}
  99. """
  100. req_keys = find_symbols(proposition)
  101. keys = proposition.all_predicates()
  102. # XXX: We need this since True/False are not Basic
  103. lkeys = set()
  104. if assumptions:
  105. lkeys |= assumptions.all_predicates()
  106. if context:
  107. lkeys |= context.all_predicates()
  108. lkeys = lkeys - {S.true, S.false}
  109. tmp_keys = None
  110. while tmp_keys != set():
  111. tmp = set()
  112. for l in lkeys:
  113. syms = find_symbols(l)
  114. if (syms & req_keys) != set():
  115. tmp |= syms
  116. tmp_keys = tmp - req_keys
  117. req_keys |= tmp_keys
  118. keys |= {l for l in lkeys if find_symbols(l) & req_keys != set()}
  119. exprs = set()
  120. for key in keys:
  121. if isinstance(key, AppliedPredicate):
  122. exprs |= set(key.arguments)
  123. else:
  124. exprs.add(key)
  125. return exprs
  126. def find_symbols(pred):
  127. """
  128. Find every :obj:`~.Symbol` in *pred*.
  129. Parameters
  130. ==========
  131. pred : sympy.assumptions.cnf.CNF, or any Expr.
  132. """
  133. if isinstance(pred, CNF):
  134. symbols = set()
  135. for a in pred.all_predicates():
  136. symbols |= find_symbols(a)
  137. return symbols
  138. return pred.atoms(Symbol)
  139. def get_relevant_clsfacts(exprs, relevant_facts=None):
  140. """
  141. Extract relevant facts from the items in *exprs*. Facts are defined in
  142. ``assumptions.sathandlers`` module.
  143. This function is recursively called by ``get_all_relevant_facts()``.
  144. Parameters
  145. ==========
  146. exprs : set
  147. Expressions whose relevant facts are searched.
  148. relevant_facts : sympy.assumptions.cnf.CNF, optional.
  149. Pre-discovered relevant facts.
  150. Returns
  151. =======
  152. exprs : set
  153. Candidates for next relevant fact searching.
  154. relevant_facts : sympy.assumptions.cnf.CNF
  155. Updated relevant facts.
  156. Examples
  157. ========
  158. Here, we will see how facts relevant to ``Abs(x*y)`` are recursively
  159. extracted. On the first run, set containing the expression is passed
  160. without pre-discovered relevant facts. The result is a set containig
  161. candidates for next run, and ``CNF()`` instance containing facts
  162. which are relevant to ``Abs`` and its argument.
  163. >>> from sympy import Abs
  164. >>> from sympy.assumptions.satask import get_relevant_clsfacts
  165. >>> from sympy.abc import x, y
  166. >>> exprs = {Abs(x*y)}
  167. >>> exprs, facts = get_relevant_clsfacts(exprs)
  168. >>> exprs
  169. {x*y}
  170. >>> facts.clauses #doctest: +SKIP
  171. {frozenset({Literal(Q.odd(Abs(x*y)), False), Literal(Q.odd(x*y), True)}),
  172. frozenset({Literal(Q.zero(Abs(x*y)), False), Literal(Q.zero(x*y), True)}),
  173. frozenset({Literal(Q.even(Abs(x*y)), False), Literal(Q.even(x*y), True)}),
  174. frozenset({Literal(Q.zero(Abs(x*y)), True), Literal(Q.zero(x*y), False)}),
  175. frozenset({Literal(Q.even(Abs(x*y)), False),
  176. Literal(Q.odd(Abs(x*y)), False),
  177. Literal(Q.odd(x*y), True)}),
  178. frozenset({Literal(Q.even(Abs(x*y)), False),
  179. Literal(Q.even(x*y), True),
  180. Literal(Q.odd(Abs(x*y)), False)}),
  181. frozenset({Literal(Q.positive(Abs(x*y)), False),
  182. Literal(Q.zero(Abs(x*y)), False)})}
  183. We pass the first run's results to the second run, and get the expressions
  184. for next run and updated facts.
  185. >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts)
  186. >>> exprs
  187. {x, y}
  188. On final run, no more candidate is returned thus we know that all
  189. relevant facts are successfully retrieved.
  190. >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts)
  191. >>> exprs
  192. set()
  193. """
  194. if not relevant_facts:
  195. relevant_facts = CNF()
  196. newexprs = set()
  197. for expr in exprs:
  198. for fact in class_fact_registry(expr):
  199. newfact = CNF.to_CNF(fact)
  200. relevant_facts = relevant_facts._and(newfact)
  201. for key in newfact.all_predicates():
  202. if isinstance(key, AppliedPredicate):
  203. newexprs |= set(key.arguments)
  204. return newexprs - exprs, relevant_facts
  205. def get_all_relevant_facts(proposition, assumptions, context,
  206. use_known_facts=True, iterations=oo):
  207. """
  208. Extract all relevant facts from *proposition* and *assumptions*.
  209. This function extracts the facts by recursively calling
  210. ``get_relevant_clsfacts()``. Extracted facts are converted to
  211. ``EncodedCNF`` and returned.
  212. Parameters
  213. ==========
  214. proposition : sympy.assumptions.cnf.CNF
  215. CNF generated from proposition expression.
  216. assumptions : sympy.assumptions.cnf.CNF
  217. CNF generated from assumption expression.
  218. context : sympy.assumptions.cnf.CNF
  219. CNF generated from assumptions context.
  220. use_known_facts : bool, optional.
  221. If ``True``, facts from ``sympy.assumptions.ask_generated``
  222. module are encoded as well.
  223. iterations : int, optional.
  224. Number of times that relevant facts are recursively extracted.
  225. Default is infinite times until no new fact is found.
  226. Returns
  227. =======
  228. sympy.assumptions.cnf.EncodedCNF
  229. Examples
  230. ========
  231. >>> from sympy import Q
  232. >>> from sympy.assumptions.cnf import CNF
  233. >>> from sympy.assumptions.satask import get_all_relevant_facts
  234. >>> from sympy.abc import x, y
  235. >>> props = CNF.from_prop(Q.nonzero(x*y))
  236. >>> assump = CNF.from_prop(Q.nonzero(x))
  237. >>> context = CNF.from_prop(Q.nonzero(y))
  238. >>> get_all_relevant_facts(props, assump, context) #doctest: +SKIP
  239. <sympy.assumptions.cnf.EncodedCNF at 0x7f09faa6ccd0>
  240. """
  241. # The relevant facts might introduce new keys, e.g., Q.zero(x*y) will
  242. # introduce the keys Q.zero(x) and Q.zero(y), so we need to run it until
  243. # we stop getting new things. Hopefully this strategy won't lead to an
  244. # infinite loop in the future.
  245. i = 0
  246. relevant_facts = CNF()
  247. all_exprs = set()
  248. while True:
  249. if i == 0:
  250. exprs = extract_predargs(proposition, assumptions, context)
  251. all_exprs |= exprs
  252. exprs, relevant_facts = get_relevant_clsfacts(exprs, relevant_facts)
  253. i += 1
  254. if i >= iterations:
  255. break
  256. if not exprs:
  257. break
  258. if use_known_facts:
  259. known_facts_CNF = CNF()
  260. known_facts_CNF.add_clauses(get_all_known_facts())
  261. kf_encoded = EncodedCNF()
  262. kf_encoded.from_cnf(known_facts_CNF)
  263. def translate_literal(lit, delta):
  264. if lit > 0:
  265. return lit + delta
  266. else:
  267. return lit - delta
  268. def translate_data(data, delta):
  269. return [{translate_literal(i, delta) for i in clause} for clause in data]
  270. data = []
  271. symbols = []
  272. n_lit = len(kf_encoded.symbols)
  273. for i, expr in enumerate(all_exprs):
  274. symbols += [pred(expr) for pred in kf_encoded.symbols]
  275. data += translate_data(kf_encoded.data, i * n_lit)
  276. encoding = dict(list(zip(symbols, range(1, len(symbols)+1))))
  277. ctx = EncodedCNF(data, encoding)
  278. else:
  279. ctx = EncodedCNF()
  280. ctx.add_from_cnf(relevant_facts)
  281. return ctx