power.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. from sympy.core import Basic, Expr
  2. from sympy.core.function import Lambda
  3. from sympy.core.numbers import oo, Infinity, NegativeInfinity, Zero, Integer
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import symbols
  6. from sympy.functions.elementary.miscellaneous import (Max, Min)
  7. from sympy.sets.fancysets import ImageSet
  8. from sympy.sets.setexpr import set_div
  9. from sympy.sets.sets import Set, Interval, FiniteSet, Union
  10. from sympy.multipledispatch import Dispatcher
# Placeholder symbols used as the bound variables of the Lambda built by the
# generic (Set, Set) rule below.
_x, _y = symbols("x y")
# Type-dispatched implementation of set exponentiation: each ``register``
# call in this module adds a handler for one (base type, exponent type) pair.
_set_pow = Dispatcher('_set_pow')
  13. @_set_pow.register(Basic, Basic)
  14. def _(x, y):
  15. return None
  16. @_set_pow.register(Set, Set)
  17. def _(x, y):
  18. return ImageSet(Lambda((_x, _y), (_x ** _y)), x, y)
  19. @_set_pow.register(Expr, Expr)
  20. def _(x, y):
  21. return x**y
  22. @_set_pow.register(Interval, Zero)
  23. def _(x, z):
  24. return FiniteSet(S.One)
  25. @_set_pow.register(Interval, Integer)
  26. def _(x, exponent):
  27. """
  28. Powers in interval arithmetic
  29. https://en.wikipedia.org/wiki/Interval_arithmetic
  30. """
  31. s1 = x.start**exponent
  32. s2 = x.end**exponent
  33. if ((s2 > s1) if exponent > 0 else (x.end > -x.start)) == True:
  34. left_open = x.left_open
  35. right_open = x.right_open
  36. # TODO: handle unevaluated condition.
  37. sleft = s2
  38. else:
  39. # TODO: `s2 > s1` could be unevaluated.
  40. left_open = x.right_open
  41. right_open = x.left_open
  42. sleft = s1
  43. if x.start.is_positive:
  44. return Interval(
  45. Min(s1, s2),
  46. Max(s1, s2), left_open, right_open)
  47. elif x.end.is_negative:
  48. return Interval(
  49. Min(s1, s2),
  50. Max(s1, s2), left_open, right_open)
  51. # Case where x.start < 0 and x.end > 0:
  52. if exponent.is_odd:
  53. if exponent.is_negative:
  54. if x.start.is_zero:
  55. return Interval(s2, oo, x.right_open)
  56. if x.end.is_zero:
  57. return Interval(-oo, s1, True, x.left_open)
  58. return Union(Interval(-oo, s1, True, x.left_open), Interval(s2, oo, x.right_open))
  59. else:
  60. return Interval(s1, s2, x.left_open, x.right_open)
  61. elif exponent.is_even:
  62. if exponent.is_negative:
  63. if x.start.is_zero:
  64. return Interval(s2, oo, x.right_open)
  65. if x.end.is_zero:
  66. return Interval(s1, oo, x.left_open)
  67. return Interval(0, oo)
  68. else:
  69. return Interval(S.Zero, sleft, S.Zero not in x, left_open)
  70. @_set_pow.register(Interval, Infinity)
  71. def _(b, e):
  72. # TODO: add logic for open intervals?
  73. if b.start.is_nonnegative:
  74. if b.end < 1:
  75. return FiniteSet(S.Zero)
  76. if b.start > 1:
  77. return FiniteSet(S.Infinity)
  78. return Interval(0, oo)
  79. elif b.end.is_negative:
  80. if b.start > -1:
  81. return FiniteSet(S.Zero)
  82. if b.end < -1:
  83. return FiniteSet(-oo, oo)
  84. return Interval(-oo, oo)
  85. else:
  86. if b.start > -1:
  87. if b.end < 1:
  88. return FiniteSet(S.Zero)
  89. return Interval(0, oo)
  90. return Interval(-oo, oo)
  91. @_set_pow.register(Interval, NegativeInfinity)
  92. def _(b, e):
  93. return _set_pow(set_div(S.One, b), oo)