"""
This module collects classes useful for approximate rewriting of expressions.
This can be beneficial when generating numeric code for which performance is
of greater importance than precision (e.g. for preconditioners used in
iterative methods).
"""

import math
from sympy.sets.sets import Interval
from sympy.calculus.singularities import is_increasing, is_decreasing
from sympy.codegen.rewriting import Optimization
from sympy.core.function import UndefinedFunction


class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    If terms are expressions which can be determined to be monotonic, then
    bounds for those expressions are added.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
        Note that bounds derived for monotonic terms are stored back into
        this very dictionary (it is not copied).
    reltol : number
        Threshold for when to ignore a term. Taken relative to the largest
        lower bound among bounds.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol

    def __call__(self, expr):
        """Factor ``expr`` and rewrite every sub-expression accepted by ``query``."""
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        """Select the sub-expressions to rewrite: sums only."""
        return expr.is_Add

    def value(self, add):
        """
        Return ``add`` with its negligible terms removed.

        ``add`` is returned unchanged when bounds cannot be established for
        every one of its terms.
        """
        # First pass: derive bounds for terms in a single bounded symbol by
        # evaluating them at the end points of that symbol's interval, which
        # is valid only when the term is monotonic over the interval.
        for term in add.args:
            if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
                continue
            fs, = term.free_symbols
            if fs not in self.bounds:
                continue
            sym_lo, sym_hi = self.bounds[fs][0], self.bounds[fs][1]
            intrvl = Interval(*self.bounds[fs])
            if is_increasing(term, intrvl, fs):
                self.bounds[term] = (
                    term.subs({fs: sym_lo}),
                    term.subs({fs: sym_hi})
                )
            elif is_decreasing(term, intrvl, fs):
                # Decreasing: the lower bound is attained at the upper end.
                self.bounds[term] = (
                    term.subs({fs: sym_hi}),
                    term.subs({fs: sym_lo})
                )
            else:
                # Not monotonic: no bounds can be derived, give up.
                return add

        if not all(term.is_number or term in self.bounds for term in add.args):
            return add

        # A number is bounded by itself.
        bounds = [(term, term) if term.is_number else self.bounds[term]
                  for term in add.args]

        # Largest magnitude that some term is guaranteed to reach. Intervals
        # containing zero guarantee nothing and are skipped.
        largest_abs_guarantee = 0
        for lo, hi in bounds:
            if lo <= 0 <= hi:
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))

        # Keep only the terms which may reach the relative threshold.
        threshold = largest_abs_guarantee*self.reltol
        new_terms = [term for term, (lo, hi) in zip(add.args, bounds)
                     if max(abs(lo), abs(hi)) >= threshold]
        return add.func(*new_terms)


class SeriesApprox(Optimization):
    """
    Approximates functions by expanding them as a series.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest relative deviation between the series expansion and the
        original function which is tolerated at the points checked.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.
        Needs to be even and at least 2, otherwise ``ValueError`` is raised.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """

    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            # With an odd number of linearly spaced points the middle one
            # coincides with the expansion point (the interval midpoint).
            raise ValueError("Checking the solution at expansion point is not helpful")
        if n_point_checks < 2:
            # Without any check points every expansion would be accepted
            # unverified in ``value``.
            raise ValueError("n_point_checks must be at least 2")
        self.n_point_checks = n_point_checks
        # Precision handed to ``evalf`` when comparing against ``reltol``.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        """Factor ``expr`` and rewrite every sub-expression accepted by ``query``."""
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        """Select applied, defined functions which take a single argument."""
        return (expr.is_Function and not isinstance(expr, UndefinedFunction)
                and len(expr.args) == 1)

    def value(self, fexpr):
        """
        Return the lowest order series expansion of ``fexpr`` which satisfies
        ``reltol`` at all check points, or ``fexpr`` itself if there is none.

        ``fexpr`` is also returned unchanged unless it depends on exactly one
        symbol and that symbol has an entry in ``bounds``.
        """
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr

        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2  # expand around the midpoint of the interval
        cheapest = None
        # Try successively shorter expansions; stop at the first one failing.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            n_ok = True
            for idx in range(self.n_point_checks):
                # Linearly spaced points from lo to hi (both included).
                x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
                val = fseri.xreplace({symb: x})
                ref = fexpr.xreplace({symb: x})
                if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
                    n_ok = False
                    break

            if n_ok:
                cheapest = fseri
            else:
                break

        if cheapest is None:
            return fexpr
        return cheapest