functions.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. from sympy.core.singleton import S
  2. from sympy.sets.sets import Set
  3. from sympy.calculus.singularities import singularities
  4. from sympy.core import Expr, Add
  5. from sympy.core.function import Lambda, FunctionClass, diff, expand_mul
  6. from sympy.core.numbers import Float, oo
  7. from sympy.core.symbol import Dummy, symbols, Wild
  8. from sympy.functions.elementary.exponential import exp, log
  9. from sympy.functions.elementary.miscellaneous import Min, Max
  10. from sympy.logic.boolalg import true
  11. from sympy.multipledispatch import Dispatcher
  12. from sympy.sets import (imageset, Interval, FiniteSet, Union, ImageSet,
  13. Intersection, Range, Complement)
  14. from sympy.sets.sets import EmptySet, is_function_invertible_in_set
  15. from sympy.sets.fancysets import Integers, Naturals, Reals
  16. from sympy.functions.elementary.exponential import match_real_imag
  # Multiple-dispatch registry: each handler below is registered for a
  # (function kind, set kind) pair and computes the image of that set.
  _set_function = Dispatcher('_set_function')

  # The two kinds of SymPy callable a handler may receive: a function class
  # (e.g. ``exp``) or a ``Lambda``.
  FunctionUnion = (FunctionClass, Lambda)

  # Generic symbols; ``_x`` is used below to wrap callables as
  # ``Lambda(_x, f(_x))``.  ``_y`` is not used in this part of the file.
  _x, _y = symbols("x y")
  @_set_function.register(FunctionClass, Set)
  def _(f, x):
      """Fallback for a function class applied to an arbitrary Set.

      No closed form is attempted for this combination, so the handler
      always returns None.
      """
      return
  @_set_function.register(FunctionUnion, FiniteSet)
  def _(f, x):
      """Image of a finite set: apply ``f`` to each element in turn."""
      images = [f(element) for element in x]
      return FiniteSet(*images)
  @_set_function.register(Lambda, Interval)
  def _(f, x):
      """Image of an Interval under a univariate Lambda.

      Piecewise expressions are handled piece by piece.  Otherwise the
      interval is split at the singularities of ``f`` (if any) and each piece
      is mapped recursively through ``imageset``.  With no singularities, the
      image is bounded by the values at the endpoints (one-sided limits for
      open ends) and at the real critical points lying inside ``x``.

      Returns None when no closed form is attempted: multivariate ``f``,
      non-comparable endpoints, ``singularities`` raising
      NotImplementedError, or a critical-point set that is not finite.
      """
      from sympy.solvers.solveset import solveset
      from sympy.series import limit
      # TODO: handle functions with infinitely many solutions (eg, sin, tan)
      # TODO: handle multivariate functions

      expr = f.expr
      if len(expr.free_symbols) > 1 or len(f.variables) != 1:
          return
      var = f.variables[0]
      if not var.is_real:
          # give up if f is known to be non-real on real arguments
          if expr.subs(var, Dummy(real=True)).is_real is False:
              return

      if expr.is_Piecewise:
          result = S.EmptySet
          domain_set = x
          for (p_expr, p_cond) in expr.args:
              if p_cond is true:
                  intrvl = domain_set
              else:
                  intrvl = p_cond.as_set()
                  intrvl = Intersection(domain_set, intrvl)

              if p_expr.is_Number:
                  image = FiniteSet(p_expr)
              else:
                  image = imageset(Lambda(var, p_expr), intrvl)
              result = Union(result, image)

              # remove the part which has been `imaged`
              domain_set = Complement(domain_set, intrvl)
              if domain_set is S.EmptySet:
                  break
          return result

      if not x.start.is_comparable or not x.end.is_comparable:
          return

      try:
          from sympy.polys.polyutils import _nsort
          sing = list(singularities(expr, var, x))
          if len(sing) > 1:
              sing = _nsort(sing)
      except NotImplementedError:
          return

      # Edge values: one-sided limits at open ends, plain evaluation at
      # closed ends that are not singular.  When an edge is singular the
      # split branch below is taken, and it does not use these values.
      if x.left_open:
          _start = limit(expr, var, x.start, dir="+")
      elif x.start not in sing:
          _start = f(x.start)
      if x.right_open:
          _end = limit(expr, var, x.end, dir="-")
      elif x.end not in sing:
          _end = f(x.end)

      if len(sing) == 0:
          soln_expr = solveset(diff(expr, var), var)
          if not (isinstance(soln_expr, FiniteSet)
                  or soln_expr is S.EmptySet):
              return
          solns = list(soln_expr)

          # values of f at the real critical points inside the interval;
          # these are attained, so an extremum equal to one of them is closed
          interior = [f(i) for i in solns if i.is_real and i in x]
          extr = [_start, _end] + interior
          start, end = Min(*extr), Max(*extr)

          left_open, right_open = False, False
          if _start <= _end:
              # the minimum or maximum value can occur simultaneously
              # on both the edge of the interval and in some interior
              # point; only an edge-only extremum inherits the edge openness
              if start == _start and start not in interior:
                  left_open = x.left_open
              if end == _end and end not in interior:
                  right_open = x.right_open
          else:
              if start == _end and start not in interior:
                  left_open = x.right_open
              if end == _start and end not in interior:
                  right_open = x.left_open

          return Interval(start, end, left_open, right_open)

      # Split at the singularities (excluded, hence open) and union the
      # images of the pieces; ``+`` on sets is union.
      head = imageset(f, Interval(x.start, sing[0], x.left_open, True))
      middle = Union(*[imageset(f, Interval(lo, hi, True, True))
                       for lo, hi in zip(sing, sing[1:])])
      tail = imageset(f, Interval(sing[-1], x.end, True, x.right_open))
      return head + middle + tail
  @_set_function.register(FunctionClass, Interval)
  def _(f, x):
      """Image of an Interval under a function class.

      ``exp`` maps the endpoints and keeps their openness.  ``log`` does the
      same, except in two cases:

      * An interval starting at 0 maps to one that is unbounded below and
        open at -oo, because ``log(0)`` is ``zoo``.
      * An interval known to start at a negative number, whose image is not
        real, is left unevaluated.

      Every other function class yields an unevaluated ImageSet.
      """
      if f == exp:
          return Interval(exp(x.start), exp(x.end), x.left_open, x.right_open)
      elif f == log:
          if x.start.is_zero:
              # log(0) is zoo, not a valid Interval endpoint
              return Interval(-oo, log(x.end), True, x.right_open)
          if not x.start.is_negative:
              return Interval(log(x.start), log(x.end),
                              x.left_open, x.right_open)
          # negative arguments give complex logs: no real Interval exists
      return ImageSet(Lambda(_x, f(_x)), x)
  @_set_function.register(FunctionUnion, Union)
  def _(f, x):
      """Image of a Union: the union of the images of its members."""
      images = [imageset(f, member) for member in x.args]
      return Union(*images)
  @_set_function.register(FunctionUnion, Intersection)
  def _(f, x):
      """Image of an Intersection.

      When ``f`` is invertible on ``x``, the image of the intersection is the
      intersection of the member images.  Otherwise an unevaluated ImageSet
      is returned.
      """
      if not is_function_invertible_in_set(f, x):
          return ImageSet(Lambda(_x, f(_x)), x)
      images = [imageset(f, member) for member in x.args]
      return Intersection(*images)
  @_set_function.register(FunctionUnion, EmptySet)
  def _(f, x):
      """Image of the empty set, which is the empty set itself."""
      # EmptySet is a singleton, so the argument is S.EmptySet
      return S.EmptySet
  @_set_function.register(FunctionUnion, Set)
  def _(f, x):
      """Generic fallback: wrap ``f`` in a Lambda and return an ImageSet."""
      mapping = Lambda(_x, f(_x))
      return ImageSet(mapping, x)
  @_set_function.register(FunctionUnion, Range)
  def _(f, self):
      """Image of a Range under ``f``.

      Handled cases:

      * an empty Range gives EmptySet;
      * a one-element Range gives a FiniteSet;
      * the identity function returns the Range unchanged;
      * ``f`` linear in its variable: re-index the image over
        ``Range(self.size)``.

      Returns None otherwise, including for callables without an ``expr``
      Expr (e.g. function classes such as ``exp``).
      """
      if not self:
          return S.EmptySet
      # FunctionUnion admits FunctionClass objects, which have no ``expr``;
      # treat them as unhandled instead of raising AttributeError.
      expr = getattr(f, 'expr', None)
      if not isinstance(expr, Expr):
          return
      if self.size == 1:
          return FiniteSet(f(self[0]))
      if f is S.IdentityFunction:
          return self

      x = f.variables[0]
      # handle f that is linear in f's variable
      if x not in expr.free_symbols or x in expr.diff(x).free_symbols:
          return
      if self.start.is_finite:
          F = f(self.step*x + self.start)  # for i in range(len(self))
      else:
          F = f(-self.step*x + self[-1])
      F = expand_mul(F)
      # NOTE(review): presumably this avoids re-dispatching with an
      # unchanged expression when the Range is already 0-based step-1
      if F != expr:
          return imageset(x, F, Range(self.size))
  @_set_function.register(FunctionUnion, Integers)
  def _(f, self):
      """Image of the Integers under ``f``, canonicalised for linear maps.

      * ``abs(n)`` maps onto Naturals0.
      * For ``a*n + b`` (no Floats), the function is put in a canonical form:
        - choose the orientation (n or -n) with fewer negative terms;
        - drop integer addends of ``b`` when ``a`` is +-1;
        - when decidable, reduce ``b`` modulo ``a`` for real ``a``, or its
          imaginary part modulo ``a/I`` for imaginary ``a``.

      Returns an ImageSet if the expression changed.  Returns None when
      nothing changed or when ``f`` has no ``expr`` Expr (e.g. function
      classes such as ``exp``).
      """
      # FunctionUnion admits FunctionClass objects, which have no ``expr``;
      # treat them as unhandled instead of raising AttributeError.
      expr = getattr(f, 'expr', None)
      if not isinstance(expr, Expr):
          return

      n = f.variables[0]
      if expr == abs(n):
          return S.Naturals0

      # f(x) + c and f(-x) + c cover the same integers
      # so choose the form that has the fewest negatives
      c = f(0)
      fx = f(n) - c
      f_x = f(-n) - c

      def neg_count(e):
          # number of additive terms that can extract a minus sign
          return sum(t.could_extract_minus_sign() for t in Add.make_args(e))

      if neg_count(f_x) < neg_count(fx):
          expr = f_x + c

      a = Wild('a', exclude=[n])
      b = Wild('b', exclude=[n])
      match = expr.match(a*n + b)
      if match and match[a] and (
              not match[a].atoms(Float) and
              not match[b].atoms(Float)):
          # canonical shift
          a, b = match[a], match[b]
          if a in [1, -1]:
              # with unit slope an integer addend only re-indexes n: drop it
              b = Add(*[t for t in Add.make_args(b) if not t.is_integer])
          if b.is_number and a.is_real:
              # avoid Mod for complex numbers, #11391
              br, bi = match_real_imag(b)
              if br and br.is_comparable and a.is_comparable:
                  br %= a
                  b = br + S.ImaginaryUnit*bi
          elif b.is_number and a.is_imaginary:
              br, bi = match_real_imag(b)
              ai = a/S.ImaginaryUnit
              if bi and bi.is_comparable and ai.is_comparable:
                  bi %= ai
                  b = br + S.ImaginaryUnit*bi
          expr = a*n + b

      if expr != f.expr:
          return ImageSet(Lambda(n, expr), S.Integers)
  @_set_function.register(FunctionUnion, Naturals)
  def _(f, self):
      """Image of Naturals (or Naturals0) under ``f``.

      Handled only when ``f`` depends on no symbol other than its variable:

      * ``abs(x)`` returns the set itself for Naturals, else Naturals0.
      * A constant integer map gives a one-element FiniteSet.
      * An integer affine map ``step*x + c`` gives the matching Naturals,
        Naturals0 or Range.

      Returns None otherwise, including for callables without an ``expr``
      Expr (e.g. function classes such as ``exp``).
      """
      # FunctionUnion admits FunctionClass objects, which have no ``expr``;
      # treat them as unhandled instead of raising AttributeError.
      expr = getattr(f, 'expr', None)
      if not isinstance(expr, Expr):
          return

      x = f.variables[0]
      if expr.free_symbols - {x}:
          return
      if expr == abs(x):
          return self if self is S.Naturals else S.Naturals0

      step = expr.coeff(x)
      c = expr.subs(x, 0)
      if not (c.is_Integer and step.is_Integer and expr == step*x + c):
          return
      if self is S.Naturals:
          # Naturals starts at 1, so the first image is c + step
          c += step
      if step == 0:
          # constant map; Range(c, -oo, 0) would raise ValueError
          return FiniteSet(c)
      if step > 0:
          if step == 1:
              if c == 0:
                  return S.Naturals0
              if c == 1:
                  return S.Naturals
          return Range(c, oo, step)
      return Range(c, -oo, step)
  @_set_function.register(FunctionUnion, Reals)
  def _(f, self):
      """Image of the Reals: re-dispatch on the equivalent Interval(-oo, oo).

      Returns None for callables without an ``expr`` Expr (e.g. function
      classes such as ``exp``).
      """
      # FunctionUnion admits FunctionClass objects, which have no ``expr``;
      # treat them as unhandled instead of raising AttributeError.
      expr = getattr(f, 'expr', None)
      if not isinstance(expr, Expr):
          return
      return _set_function(f, Interval(-oo, oo))