123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- import math
- from sympy.sets.sets import Interval
- from sympy.calculus.singularities import is_increasing, is_decreasing
- from sympy.codegen.rewriting import Optimization
- from sympy.core.function import UndefinedFunction
- """
- This module collects classes useful for approximate rewriting of expressions.
- This can be beneficial when generating numeric code for which performance is
- of greater importance than precision (e.g. for preconditioners used in iterative
- methods).
- """
class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    Each term of a sum is compared, via its bounds, with the other terms of
    the same sum; terms whose magnitude is guaranteed to be negligible are
    removed. For a term that has no entry in ``bounds`` but depends on a
    single bounded symbol and can be shown to be monotonic over that symbol's
    interval, bounds are derived from the endpoint values and added to
    ``bounds``.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high). The
        mapping is stored as given (not copied) and is extended in place
        with the bounds derived for monotonic terms.
    reltol : number
        Threshold for when to ignore a term. A term is kept only if its
        largest possible magnitude is at least ``reltol`` times the largest
        guaranteed magnitude found among the terms whose bounds do not
        include zero.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        self.reltol = reltol
        self.bounds = bounds

    def __call__(self, expr):
        """Factor ``expr`` and rewrite every sum found in the result."""
        factored = expr.factor()
        return factored.replace(self.query, self.value)

    def query(self, expr):
        """Select the sub-expressions to rewrite: sums only."""
        return expr.is_Add

    def value(self, add):
        """Return ``add`` with its negligible terms removed.

        ``add`` is returned unchanged when some term is neither monotonic
        over its symbol's interval nor otherwise bounded.
        """
        known = self.bounds

        # First pass: derive bounds for monotonic single-symbol terms.
        for summand in add.args:
            if summand.is_number or summand in known:
                continue
            if len(summand.free_symbols) != 1:
                continue
            (symbol,) = summand.free_symbols
            if symbol not in known:
                continue
            sym_bounds = known[symbol]
            domain = Interval(*sym_bounds)
            # ``ends`` lists which endpoint yields the (low, high) bound.
            if is_increasing(summand, domain, symbol):
                ends = (0, 1)
            elif is_decreasing(summand, domain, symbol):
                ends = (1, 0)
            else:
                return add
            known[summand] = tuple(
                summand.subs({symbol: sym_bounds[end]}) for end in ends
            )

        # Without bounds for every term no comparison is possible.
        if not all(t.is_number or t in known for t in add.args):
            return add

        ranges = [(t, t) if t.is_number else known[t] for t in add.args]

        # Largest magnitude that some term is guaranteed to reach; terms
        # whose interval contains zero guarantee nothing.
        floors = [min(abs(lo), abs(hi)) for lo, hi in ranges
                  if not (lo <= 0 <= hi)]
        guaranteed = max([0, *floors])

        threshold = guaranteed*self.reltol
        kept = [t for t, (lo, hi) in zip(add.args, ranges)
                if max(abs(lo), abs(hi)) >= threshold]
        return add.func(*kept)
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Explanation
    ===========

    Every single-argument function whose only free symbol has an entry in
    ``bounds`` is expanded around the midpoint of that symbol's interval.
    Starting from the highest allowed order, the order is lowered for as
    long as the truncated series stays within ``reltol`` of the function at
    all check points; the lowest accepted order is used. If not even the
    highest order is accepted, the function is left untouched.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest accepted relative deviation between the truncated series
        and the original function at any of the check points.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even, at least 2)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.

    Raises
    ======

    ValueError
        If ``n_point_checks`` is odd (the expansion point itself would be
        among the check points) or smaller than 2 (no meaningful check
        could be performed, so every expansion would be accepted).

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """

    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            raise ValueError("Checking the solution at expansion point is not helpful")
        if n_point_checks < 2:
            # With no check points every truncation would be accepted and
            # functions would silently collapse to their lowest-order series.
            raise ValueError("n_point_checks must be at least 2")
        self.n_point_checks = n_point_checks
        # Number of significant digits needed to resolve a relative error of
        # size ``reltol`` when evaluating numerically.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        """Factor ``expr`` and rewrite every matching function in the result."""
        return expr.factor().replace(self.query, lambda arg: self.value(arg))

    def query(self, expr):
        """Select applied, defined functions taking exactly one argument."""
        return (expr.is_Function and not isinstance(expr, UndefinedFunction)
                and len(expr.args) == 1)

    def value(self, fexpr):
        """Return the cheapest acceptable series for ``fexpr``.

        ``fexpr`` is returned unchanged when it does not depend on exactly
        one symbol, when that symbol has no bounds, or when no expansion up
        to ``max_order`` meets the tolerance.
        """
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr

        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2

        # The check points and the reference values do not depend on the
        # expansion order, so they are computed once.
        points = [lo + idx*(hi - lo)/(self.n_point_checks - 1)
                  for idx in range(self.n_point_checks)]
        references = [fexpr.xreplace({symb: x}) for x in points]

        cheapest = None
        # Try successively shorter expansions; stop at the first one that
        # violates the tolerance at some check point.
        for n in range(self.max_order + 1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            too_inaccurate = any(
                abs((1 - fseri.xreplace({symb: x})/ref).evalf(self._prec)) > self.reltol
                for x, ref in zip(points, references)
            )
            if too_inaccurate:
                break
            cheapest = fseri

        if cheapest is None:
            return fexpr
        return cheapest
|