approximations.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import math
  2. from sympy.sets.sets import Interval
  3. from sympy.calculus.singularities import is_increasing, is_decreasing
  4. from sympy.codegen.rewriting import Optimization
  5. from sympy.core.function import UndefinedFunction
"""
This module collects classes useful for approximate rewriting of expressions.
This can be beneficial when generating numeric code for which performance is
of greater importance than precision (e.g. for preconditioners used in iterative
methods).
"""
  12. class SumApprox(Optimization):
  13. """
  14. Approximates sum by neglecting small terms.
  15. Explanation
  16. ===========
  17. If terms are expressions which can be determined to be monotonic, then
  18. bounds for those expressions are added.
  19. Parameters
  20. ==========
  21. bounds : dict
  22. Mapping expressions to length 2 tuple of bounds (low, high).
  23. reltol : number
  24. Threshold for when to ignore a term. Taken relative to the largest
  25. lower bound among bounds.
  26. Examples
  27. ========
  28. >>> from sympy import exp
  29. >>> from sympy.abc import x, y, z
  30. >>> from sympy.codegen.rewriting import optimize
  31. >>> from sympy.codegen.approximations import SumApprox
  32. >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
  33. >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
  34. >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
  35. >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
  36. >>> expr = 3*(x + y + exp(z))
  37. >>> optimize(expr, [sum_approx3])
  38. 3*(x + y + exp(z))
  39. >>> optimize(expr, [sum_approx2])
  40. 3*y + 3*exp(z)
  41. >>> optimize(expr, [sum_approx1])
  42. 3*y
  43. """
  44. def __init__(self, bounds, reltol, **kwargs):
  45. super().__init__(**kwargs)
  46. self.bounds = bounds
  47. self.reltol = reltol
  48. def __call__(self, expr):
  49. return expr.factor().replace(self.query, lambda arg: self.value(arg))
  50. def query(self, expr):
  51. return expr.is_Add
  52. def value(self, add):
  53. for term in add.args:
  54. if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
  55. continue
  56. fs, = term.free_symbols
  57. if fs not in self.bounds:
  58. continue
  59. intrvl = Interval(*self.bounds[fs])
  60. if is_increasing(term, intrvl, fs):
  61. self.bounds[term] = (
  62. term.subs({fs: self.bounds[fs][0]}),
  63. term.subs({fs: self.bounds[fs][1]})
  64. )
  65. elif is_decreasing(term, intrvl, fs):
  66. self.bounds[term] = (
  67. term.subs({fs: self.bounds[fs][1]}),
  68. term.subs({fs: self.bounds[fs][0]})
  69. )
  70. else:
  71. return add
  72. if all(term.is_number or term in self.bounds for term in add.args):
  73. bounds = [(term, term) if term.is_number else self.bounds[term] for term in add.args]
  74. largest_abs_guarantee = 0
  75. for lo, hi in bounds:
  76. if lo <= 0 <= hi:
  77. continue
  78. largest_abs_guarantee = max(largest_abs_guarantee,
  79. min(abs(lo), abs(hi)))
  80. new_terms = []
  81. for term, (lo, hi) in zip(add.args, bounds):
  82. if max(abs(lo), abs(hi)) >= largest_abs_guarantee*self.reltol:
  83. new_terms.append(term)
  84. return add.func(*new_terms)
  85. else:
  86. return add
  87. class SeriesApprox(Optimization):
  88. """ Approximates functions by expanding them as a series.
  89. Parameters
  90. ==========
  91. bounds : dict
  92. Mapping expressions to length 2 tuple of bounds (low, high).
  93. reltol : number
  94. Threshold for when to ignore a term. Taken relative to the largest
  95. lower bound among bounds.
  96. max_order : int
  97. Largest order to include in series expansion
  98. n_point_checks : int (even)
  99. The validity of an expansion (with respect to reltol) is checked at
  100. discrete points (linearly spaced over the bounds of the variable). The
  101. number of points used in this numerical check is given by this number.
  102. Examples
  103. ========
  104. >>> from sympy import sin, pi
  105. >>> from sympy.abc import x, y
  106. >>> from sympy.codegen.rewriting import optimize
  107. >>> from sympy.codegen.approximations import SeriesApprox
  108. >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
  109. >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
  110. >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
  111. >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
  112. >>> expr = sin(x)*sin(y)
  113. >>> optimize(expr, [series_approx2])
  114. x*(-y + (y - pi)**3/6 + pi)
  115. >>> optimize(expr, [series_approx3])
  116. (-x**3/6 + x)*sin(y)
  117. >>> optimize(expr, [series_approx8])
  118. sin(x)*sin(y)
  119. """
  120. def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
  121. super().__init__(**kwargs)
  122. self.bounds = bounds
  123. self.reltol = reltol
  124. self.max_order = max_order
  125. if n_point_checks % 2 == 1:
  126. raise ValueError("Checking the solution at expansion point is not helpful")
  127. self.n_point_checks = n_point_checks
  128. self._prec = math.ceil(-math.log10(self.reltol))
  129. def __call__(self, expr):
  130. return expr.factor().replace(self.query, lambda arg: self.value(arg))
  131. def query(self, expr):
  132. return (expr.is_Function and not isinstance(expr, UndefinedFunction)
  133. and len(expr.args) == 1)
  134. def value(self, fexpr):
  135. free_symbols = fexpr.free_symbols
  136. if len(free_symbols) != 1:
  137. return fexpr
  138. symb, = free_symbols
  139. if symb not in self.bounds:
  140. return fexpr
  141. lo, hi = self.bounds[symb]
  142. x0 = (lo + hi)/2
  143. cheapest = None
  144. for n in range(self.max_order+1, 0, -1):
  145. fseri = fexpr.series(symb, x0=x0, n=n).removeO()
  146. n_ok = True
  147. for idx in range(self.n_point_checks):
  148. x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
  149. val = fseri.xreplace({symb: x})
  150. ref = fexpr.xreplace({symb: x})
  151. if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
  152. n_ok = False
  153. break
  154. if n_ok:
  155. cheapest = fseri
  156. else:
  157. break
  158. if cheapest is None:
  159. return fexpr
  160. else:
  161. return cheapest