conditionset.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. from sympy.core.singleton import S
  2. from sympy.core.basic import Basic
  3. from sympy.core.containers import Tuple
  4. from sympy.core.function import Lambda, BadSignatureError
  5. from sympy.core.logic import fuzzy_bool
  6. from sympy.core.relational import Eq
  7. from sympy.core.symbol import Dummy
  8. from sympy.core.sympify import _sympify
  9. from sympy.logic.boolalg import And, as_Boolean
  10. from sympy.utilities.iterables import sift, flatten, has_dups
  11. from sympy.utilities.exceptions import sympy_deprecation_warning
  12. from .contains import Contains
  13. from .sets import Set, Union, FiniteSet
  14. adummy = Dummy('conditionset')
  15. class ConditionSet(Set):
  16. r"""
  17. Set of elements which satisfies a given condition.
  18. .. math:: \{x \mid \textrm{condition}(x) = \texttt{True}, x \in S\}
  19. Examples
  20. ========
  21. >>> from sympy import Symbol, S, ConditionSet, pi, Eq, sin, Interval
  22. >>> from sympy.abc import x, y, z
  23. >>> sin_sols = ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi))
  24. >>> 2*pi in sin_sols
  25. True
  26. >>> pi/2 in sin_sols
  27. False
  28. >>> 3*pi in sin_sols
  29. False
  30. >>> 5 in ConditionSet(x, x**2 > 4, S.Reals)
  31. True
  32. If the value is not in the base set, the result is false:
  33. >>> 5 in ConditionSet(x, x**2 > 4, Interval(2, 4))
  34. False
  35. Notes
  36. =====
  37. Symbols with assumptions should be avoided or else the
  38. condition may evaluate without consideration of the set:
  39. >>> n = Symbol('n', negative=True)
  40. >>> cond = (n > 0); cond
  41. False
  42. >>> ConditionSet(n, cond, S.Integers)
  43. EmptySet
  44. Only free symbols can be changed by using `subs`:
  45. >>> c = ConditionSet(x, x < 1, {x, z})
  46. >>> c.subs(x, y)
  47. ConditionSet(x, x < 1, {y, z})
  48. To check if ``pi`` is in ``c`` use:
  49. >>> pi in c
  50. False
  51. If no base set is specified, the universal set is implied:
  52. >>> ConditionSet(x, x < 1).base_set
  53. UniversalSet
  54. Only symbols or symbol-like expressions can be used:
  55. >>> ConditionSet(x + 1, x + 1 < 1, S.Integers)
  56. Traceback (most recent call last):
  57. ...
  58. ValueError: non-symbol dummy not recognized in condition
  59. When the base set is a ConditionSet, the symbols will be
  60. unified if possible with preference for the outermost symbols:
  61. >>> ConditionSet(x, x < y, ConditionSet(z, z + y < 2, S.Integers))
  62. ConditionSet(x, (x < y) & (x + y < 2), Integers)
  63. """
  64. def __new__(cls, sym, condition, base_set=S.UniversalSet):
  65. sym = _sympify(sym)
  66. flat = flatten([sym])
  67. if has_dups(flat):
  68. raise BadSignatureError("Duplicate symbols detected")
  69. base_set = _sympify(base_set)
  70. if not isinstance(base_set, Set):
  71. raise TypeError(
  72. 'base set should be a Set object, not %s' % base_set)
  73. condition = _sympify(condition)
  74. if isinstance(condition, FiniteSet):
  75. condition_orig = condition
  76. temp = (Eq(lhs, 0) for lhs in condition)
  77. condition = And(*temp)
  78. sympy_deprecation_warning(
  79. f"""
  80. Using a set for the condition in ConditionSet is deprecated. Use a boolean
  81. instead.
  82. In this case, replace
  83. {condition_orig}
  84. with
  85. {condition}
  86. """,
  87. deprecated_since_version='1.5',
  88. active_deprecations_target="deprecated-conditionset-set",
  89. )
  90. condition = as_Boolean(condition)
  91. if condition is S.true:
  92. return base_set
  93. if condition is S.false:
  94. return S.EmptySet
  95. if base_set is S.EmptySet:
  96. return S.EmptySet
  97. # no simple answers, so now check syms
  98. for i in flat:
  99. if not getattr(i, '_diff_wrt', False):
  100. raise ValueError('`%s` is not symbol-like' % i)
  101. if base_set.contains(sym) is S.false:
  102. raise TypeError('sym `%s` is not in base_set `%s`' % (sym, base_set))
  103. know = None
  104. if isinstance(base_set, FiniteSet):
  105. sifted = sift(
  106. base_set, lambda _: fuzzy_bool(condition.subs(sym, _)))
  107. if sifted[None]:
  108. know = FiniteSet(*sifted[True])
  109. base_set = FiniteSet(*sifted[None])
  110. else:
  111. return FiniteSet(*sifted[True])
  112. if isinstance(base_set, cls):
  113. s, c, b = base_set.args
  114. def sig(s):
  115. return cls(s, Eq(adummy, 0)).as_dummy().sym
  116. sa, sb = map(sig, (sym, s))
  117. if sa != sb:
  118. raise BadSignatureError('sym does not match sym of base set')
  119. reps = dict(zip(flatten([sym]), flatten([s])))
  120. if s == sym:
  121. condition = And(condition, c)
  122. base_set = b
  123. elif not c.free_symbols & sym.free_symbols:
  124. reps = {v: k for k, v in reps.items()}
  125. condition = And(condition, c.xreplace(reps))
  126. base_set = b
  127. elif not condition.free_symbols & s.free_symbols:
  128. sym = sym.xreplace(reps)
  129. condition = And(condition.xreplace(reps), c)
  130. base_set = b
  131. # flatten ConditionSet(Contains(ConditionSet())) expressions
  132. if isinstance(condition, Contains) and (sym == condition.args[0]):
  133. if isinstance(condition.args[1], Set):
  134. return condition.args[1].intersect(base_set)
  135. rv = Basic.__new__(cls, sym, condition, base_set)
  136. return rv if know is None else Union(know, rv)
  137. sym = property(lambda self: self.args[0])
  138. condition = property(lambda self: self.args[1])
  139. base_set = property(lambda self: self.args[2])
  140. @property
  141. def free_symbols(self):
  142. cond_syms = self.condition.free_symbols - self.sym.free_symbols
  143. return cond_syms | self.base_set.free_symbols
  144. @property
  145. def bound_symbols(self):
  146. return flatten([self.sym])
  147. def _contains(self, other):
  148. def ok_sig(a, b):
  149. tuples = [isinstance(i, Tuple) for i in (a, b)]
  150. c = tuples.count(True)
  151. if c == 1:
  152. return False
  153. if c == 0:
  154. return True
  155. return len(a) == len(b) and all(
  156. ok_sig(i, j) for i, j in zip(a, b))
  157. if not ok_sig(self.sym, other):
  158. return S.false
  159. # try doing base_cond first and return
  160. # False immediately if it is False
  161. base_cond = Contains(other, self.base_set)
  162. if base_cond is S.false:
  163. return S.false
  164. # Substitute other into condition. This could raise e.g. for
  165. # ConditionSet(x, 1/x >= 0, Reals).contains(0)
  166. lamda = Lambda((self.sym,), self.condition)
  167. try:
  168. lambda_cond = lamda(other)
  169. except TypeError:
  170. return Contains(other, self, evaluate=False)
  171. else:
  172. return And(base_cond, lambda_cond)
  173. def as_relational(self, other):
  174. f = Lambda(self.sym, self.condition)
  175. if isinstance(self.sym, Tuple):
  176. f = f(*other)
  177. else:
  178. f = f(other)
  179. return And(f, self.base_set.contains(other))
  180. def _eval_subs(self, old, new):
  181. sym, cond, base = self.args
  182. dsym = sym.subs(old, adummy)
  183. insym = dsym.has(adummy)
  184. # prioritize changing a symbol in the base
  185. newbase = base.subs(old, new)
  186. if newbase != base:
  187. if not insym:
  188. cond = cond.subs(old, new)
  189. return self.func(sym, cond, newbase)
  190. if insym:
  191. pass # no change of bound symbols via subs
  192. elif getattr(new, '_diff_wrt', False):
  193. cond = cond.subs(old, new)
  194. else:
  195. pass # let error about the symbol raise from __new__
  196. return self.func(sym, cond, base)