cnf.py 12 KB


"""
The classes used here are for internal use by the assumptions system
only and should not be used anywhere else, as they do not possess the
signatures common to SymPy objects. For general use of logic constructs
please refer to the sympy.logic classes And, Or, Not, etc.
"""
  7. from itertools import combinations, product
  8. from sympy.core.singleton import S
  9. from sympy.logic.boolalg import (Equivalent, ITE, Implies, Nand, Nor, Xor)
  10. from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le
  11. from sympy.logic.boolalg import Or, And, Not, Xnor
  12. from itertools import zip_longest
  13. class Literal:
  14. """
  15. The smallest element of a CNF object.
  16. Parameters
  17. ==========
  18. lit : Boolean expression
  19. is_Not : bool
  20. Examples
  21. ========
  22. >>> from sympy import Q
  23. >>> from sympy.assumptions.cnf import Literal
  24. >>> from sympy.abc import x
  25. >>> Literal(Q.even(x))
  26. Literal(Q.even(x), False)
  27. >>> Literal(~Q.even(x))
  28. Literal(Q.even(x), True)
  29. """
  30. def __new__(cls, lit, is_Not=False):
  31. if isinstance(lit, Not):
  32. lit = lit.args[0]
  33. is_Not = True
  34. elif isinstance(lit, (AND, OR, Literal)):
  35. return ~lit if is_Not else lit
  36. obj = super().__new__(cls)
  37. obj.lit = lit
  38. obj.is_Not = is_Not
  39. return obj
  40. @property
  41. def arg(self):
  42. return self.lit
  43. def rcall(self, expr):
  44. if callable(self.lit):
  45. lit = self.lit(expr)
  46. else:
  47. try:
  48. lit = self.lit.apply(expr)
  49. except AttributeError:
  50. lit = self.lit.rcall(expr)
  51. return type(self)(lit, self.is_Not)
  52. def __invert__(self):
  53. is_Not = not self.is_Not
  54. return Literal(self.lit, is_Not)
  55. def __str__(self):
  56. return '{}({}, {})'.format(type(self).__name__, self.lit, self.is_Not)
  57. __repr__ = __str__
  58. def __eq__(self, other):
  59. return self.arg == other.arg and self.is_Not == other.is_Not
  60. def __hash__(self):
  61. h = hash((type(self).__name__, self.arg, self.is_Not))
  62. return h
  63. class OR:
  64. """
  65. A low-level implementation for Or
  66. """
  67. def __init__(self, *args):
  68. self._args = args
  69. @property
  70. def args(self):
  71. return sorted(self._args, key=str)
  72. def rcall(self, expr):
  73. return type(self)(*[arg.rcall(expr)
  74. for arg in self._args
  75. ])
  76. def __invert__(self):
  77. return AND(*[~arg for arg in self._args])
  78. def __hash__(self):
  79. return hash((type(self).__name__,) + tuple(self.args))
  80. def __eq__(self, other):
  81. return self.args == other.args
  82. def __str__(self):
  83. s = '(' + ' | '.join([str(arg) for arg in self.args]) + ')'
  84. return s
  85. __repr__ = __str__
  86. class AND:
  87. """
  88. A low-level implementation for And
  89. """
  90. def __init__(self, *args):
  91. self._args = args
  92. def __invert__(self):
  93. return OR(*[~arg for arg in self._args])
  94. @property
  95. def args(self):
  96. return sorted(self._args, key=str)
  97. def rcall(self, expr):
  98. return type(self)(*[arg.rcall(expr)
  99. for arg in self._args
  100. ])
  101. def __hash__(self):
  102. return hash((type(self).__name__,) + tuple(self.args))
  103. def __eq__(self, other):
  104. return self.args == other.args
  105. def __str__(self):
  106. s = '('+' & '.join([str(arg) for arg in self.args])+')'
  107. return s
  108. __repr__ = __str__
  109. def to_NNF(expr, composite_map=None):
  110. """
  111. Generates the Negation Normal Form of any boolean expression in terms
  112. of AND, OR, and Literal objects.
  113. Examples
  114. ========
  115. >>> from sympy import Q, Eq
  116. >>> from sympy.assumptions.cnf import to_NNF
  117. >>> from sympy.abc import x, y
  118. >>> expr = Q.even(x) & ~Q.positive(x)
  119. >>> to_NNF(expr)
  120. (Literal(Q.even(x), False) & Literal(Q.positive(x), True))
  121. Supported boolean objects are converted to corresponding predicates.
  122. >>> to_NNF(Eq(x, y))
  123. Literal(Q.eq(x, y), False)
  124. If ``composite_map`` argument is given, ``to_NNF`` decomposes the
  125. specified predicate into a combination of primitive predicates.
  126. >>> cmap = {Q.nonpositive: Q.negative | Q.zero}
  127. >>> to_NNF(Q.nonpositive, cmap)
  128. (Literal(Q.negative, False) | Literal(Q.zero, False))
  129. >>> to_NNF(Q.nonpositive(x), cmap)
  130. (Literal(Q.negative(x), False) | Literal(Q.zero(x), False))
  131. """
  132. from sympy.assumptions.ask import Q
  133. from sympy.assumptions.assume import AppliedPredicate, Predicate
  134. if composite_map is None:
  135. composite_map = dict()
  136. binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le}
  137. if type(expr) in binrelpreds:
  138. pred = binrelpreds[type(expr)]
  139. expr = pred(*expr.args)
  140. if isinstance(expr, Not):
  141. arg = expr.args[0]
  142. tmp = to_NNF(arg, composite_map) # Strategy: negate the NNF of expr
  143. return ~tmp
  144. if isinstance(expr, Or):
  145. return OR(*[to_NNF(x, composite_map) for x in Or.make_args(expr)])
  146. if isinstance(expr, And):
  147. return AND(*[to_NNF(x, composite_map) for x in And.make_args(expr)])
  148. if isinstance(expr, Nand):
  149. tmp = AND(*[to_NNF(x, composite_map) for x in expr.args])
  150. return ~tmp
  151. if isinstance(expr, Nor):
  152. tmp = OR(*[to_NNF(x, composite_map) for x in expr.args])
  153. return ~tmp
  154. if isinstance(expr, Xor):
  155. cnfs = []
  156. for i in range(0, len(expr.args) + 1, 2):
  157. for neg in combinations(expr.args, i):
  158. clause = [~to_NNF(s, composite_map) if s in neg else to_NNF(s, composite_map)
  159. for s in expr.args]
  160. cnfs.append(OR(*clause))
  161. return AND(*cnfs)
  162. if isinstance(expr, Xnor):
  163. cnfs = []
  164. for i in range(0, len(expr.args) + 1, 2):
  165. for neg in combinations(expr.args, i):
  166. clause = [~to_NNF(s, composite_map) if s in neg else to_NNF(s, composite_map)
  167. for s in expr.args]
  168. cnfs.append(OR(*clause))
  169. return ~AND(*cnfs)
  170. if isinstance(expr, Implies):
  171. L, R = to_NNF(expr.args[0], composite_map), to_NNF(expr.args[1], composite_map)
  172. return OR(~L, R)
  173. if isinstance(expr, Equivalent):
  174. cnfs = []
  175. for a, b in zip_longest(expr.args, expr.args[1:], fillvalue=expr.args[0]):
  176. a = to_NNF(a, composite_map)
  177. b = to_NNF(b, composite_map)
  178. cnfs.append(OR(~a, b))
  179. return AND(*cnfs)
  180. if isinstance(expr, ITE):
  181. L = to_NNF(expr.args[0], composite_map)
  182. M = to_NNF(expr.args[1], composite_map)
  183. R = to_NNF(expr.args[2], composite_map)
  184. return AND(OR(~L, M), OR(L, R))
  185. if isinstance(expr, AppliedPredicate):
  186. pred, args = expr.function, expr.arguments
  187. newpred = composite_map.get(pred, None)
  188. if newpred is not None:
  189. return to_NNF(newpred.rcall(*args), composite_map)
  190. if isinstance(expr, Predicate):
  191. newpred = composite_map.get(expr, None)
  192. if newpred is not None:
  193. return to_NNF(newpred, composite_map)
  194. return Literal(expr)
  195. def distribute_AND_over_OR(expr):
  196. """
  197. Distributes AND over OR in the NNF expression.
  198. Returns the result( Conjunctive Normal Form of expression)
  199. as a CNF object.
  200. """
  201. if not isinstance(expr, (AND, OR)):
  202. tmp = set()
  203. tmp.add(frozenset((expr,)))
  204. return CNF(tmp)
  205. if isinstance(expr, OR):
  206. return CNF.all_or(*[distribute_AND_over_OR(arg)
  207. for arg in expr._args])
  208. if isinstance(expr, AND):
  209. return CNF.all_and(*[distribute_AND_over_OR(arg)
  210. for arg in expr._args])
  211. class CNF:
  212. """
  213. Class to represent CNF of a Boolean expression.
  214. Consists of set of clauses, which themselves are stored as
  215. frozenset of Literal objects.
  216. Examples
  217. ========
  218. >>> from sympy import Q
  219. >>> from sympy.assumptions.cnf import CNF
  220. >>> from sympy.abc import x
  221. >>> cnf = CNF.from_prop(Q.real(x) & ~Q.zero(x))
  222. >>> cnf.clauses
  223. {frozenset({Literal(Q.zero(x), True)}),
  224. frozenset({Literal(Q.negative(x), False),
  225. Literal(Q.positive(x), False), Literal(Q.zero(x), False)})}
  226. """
  227. def __init__(self, clauses=None):
  228. if not clauses:
  229. clauses = set()
  230. self.clauses = clauses
  231. def add(self, prop):
  232. clauses = CNF.to_CNF(prop).clauses
  233. self.add_clauses(clauses)
  234. def __str__(self):
  235. s = ' & '.join(
  236. ['(' + ' | '.join([str(lit) for lit in clause]) +')'
  237. for clause in self.clauses]
  238. )
  239. return s
  240. def extend(self, props):
  241. for p in props:
  242. self.add(p)
  243. return self
  244. def copy(self):
  245. return CNF(set(self.clauses))
  246. def add_clauses(self, clauses):
  247. self.clauses |= clauses
  248. @classmethod
  249. def from_prop(cls, prop):
  250. res = cls()
  251. res.add(prop)
  252. return res
  253. def __iand__(self, other):
  254. self.add_clauses(other.clauses)
  255. return self
  256. def all_predicates(self):
  257. predicates = set()
  258. for c in self.clauses:
  259. predicates |= {arg.lit for arg in c}
  260. return predicates
  261. def _or(self, cnf):
  262. clauses = set()
  263. for a, b in product(self.clauses, cnf.clauses):
  264. tmp = set(a)
  265. for t in b:
  266. tmp.add(t)
  267. clauses.add(frozenset(tmp))
  268. return CNF(clauses)
  269. def _and(self, cnf):
  270. clauses = self.clauses.union(cnf.clauses)
  271. return CNF(clauses)
  272. def _not(self):
  273. clss = list(self.clauses)
  274. ll = set()
  275. for x in clss[-1]:
  276. ll.add(frozenset((~x,)))
  277. ll = CNF(ll)
  278. for rest in clss[:-1]:
  279. p = set()
  280. for x in rest:
  281. p.add(frozenset((~x,)))
  282. ll = ll._or(CNF(p))
  283. return ll
  284. def rcall(self, expr):
  285. clause_list = list()
  286. for clause in self.clauses:
  287. lits = [arg.rcall(expr) for arg in clause]
  288. clause_list.append(OR(*lits))
  289. expr = AND(*clause_list)
  290. return distribute_AND_over_OR(expr)
  291. @classmethod
  292. def all_or(cls, *cnfs):
  293. b = cnfs[0].copy()
  294. for rest in cnfs[1:]:
  295. b = b._or(rest)
  296. return b
  297. @classmethod
  298. def all_and(cls, *cnfs):
  299. b = cnfs[0].copy()
  300. for rest in cnfs[1:]:
  301. b = b._and(rest)
  302. return b
  303. @classmethod
  304. def to_CNF(cls, expr):
  305. from sympy.assumptions.facts import get_composite_predicates
  306. expr = to_NNF(expr, get_composite_predicates())
  307. expr = distribute_AND_over_OR(expr)
  308. return expr
  309. @classmethod
  310. def CNF_to_cnf(cls, cnf):
  311. """
  312. Converts CNF object to SymPy's boolean expression
  313. retaining the form of expression.
  314. """
  315. def remove_literal(arg):
  316. return Not(arg.lit) if arg.is_Not else arg.lit
  317. return And(*(Or(*(remove_literal(arg) for arg in clause)) for clause in cnf.clauses))
  318. class EncodedCNF:
  319. """
  320. Class for encoding the CNF expression.
  321. """
  322. def __init__(self, data=None, encoding=None):
  323. if not data and not encoding:
  324. data = list()
  325. encoding = dict()
  326. self.data = data
  327. self.encoding = encoding
  328. self._symbols = list(encoding.keys())
  329. def from_cnf(self, cnf):
  330. self._symbols = list(cnf.all_predicates())
  331. n = len(self._symbols)
  332. self.encoding = dict(list(zip(self._symbols, list(range(1, n + 1)))))
  333. self.data = [self.encode(clause) for clause in cnf.clauses]
  334. @property
  335. def symbols(self):
  336. return self._symbols
  337. @property
  338. def variables(self):
  339. return range(1, len(self._symbols) + 1)
  340. def copy(self):
  341. new_data = [set(clause) for clause in self.data]
  342. return EncodedCNF(new_data, dict(self.encoding))
  343. def add_prop(self, prop):
  344. cnf = CNF.from_prop(prop)
  345. self.add_from_cnf(cnf)
  346. def add_from_cnf(self, cnf):
  347. clauses = [self.encode(clause) for clause in cnf.clauses]
  348. self.data += clauses
  349. def encode_arg(self, arg):
  350. literal = arg.lit
  351. value = self.encoding.get(literal, None)
  352. if value is None:
  353. n = len(self._symbols)
  354. self._symbols.append(literal)
  355. value = self.encoding[literal] = n + 1
  356. if arg.is_Not:
  357. return -value
  358. else:
  359. return value
  360. def encode(self, clause):
  361. return {self.encode_arg(arg) if not arg.lit == S.false else 0 for arg in clause}