123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- """
- Classes and functions useful for rewriting expressions for optimized code
- generation. Some languages (or standards thereof), e.g. C99, offer specialized
- math functions for better performance and/or precision.
- Using the ``optimize`` function in this module, together with a collection of
- rules (represented as instances of ``Optimization``), one can rewrite the
- expressions for this purpose::
- >>> from sympy import Symbol, exp, log
- >>> from sympy.codegen.rewriting import optimize, optims_c99
- >>> x = Symbol('x')
- >>> optimize(3*exp(2*x) - 3, optims_c99)
- 3*expm1(2*x)
- >>> optimize(exp(2*x) - 1 - exp(-33), optims_c99)
- expm1(2*x) - exp(-33)
- >>> optimize(log(3*x + 3), optims_c99)
- log1p(x) + log(3)
- >>> optimize(log(2*x + 3), optims_c99)
- log(2*x + 3)
- The ``optims_c99`` imported above is a tuple containing the following instances
- (which may be imported from ``sympy.codegen.rewriting``):
- - ``expm1_opt``
- - ``log1p_opt``
- - ``exp2_opt``
- - ``log2_opt``
- - ``log2const_opt``
- """
- from sympy.core.function import expand_log
- from sympy.core.singleton import S
- from sympy.core.symbol import Wild
- from sympy.functions.elementary.complexes import sign
- from sympy.functions.elementary.exponential import (exp, log)
- from sympy.functions.elementary.miscellaneous import (Max, Min)
- from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
- from sympy.assumptions import Q, ask
- from sympy.codegen.cfunctions import log1p, log2, exp2, expm1
- from sympy.codegen.matrix_nodes import MatrixSolve
- from sympy.core.expr import UnevaluatedExpr
- from sympy.core.power import Pow
- from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
- from sympy.codegen.scipy_nodes import cosm1
- from sympy.core.mul import Mul
- from sympy.matrices.expressions.matexpr import MatrixSymbol
- from sympy.utilities.iterables import sift
class Optimization:
    """ Abstract base class for rewriting optimization.

    Subclasses should implement ``__call__`` taking an expression
    as argument.

    Parameters
    ==========

    cost_function : callable returning number, optional
        Used as the ``key`` when :meth:`cheapest` ranks candidate
        expressions. Defaults to ``None``, in which case the candidates
        are compared with each other directly.
    priority : number
        Stored on the instance as ``priority``. Defaults to ``1``.

    """

    def __init__(self, cost_function=None, priority=1):
        self.cost_function = cost_function
        self.priority = priority

    def cheapest(self, *args):
        """ Return the candidate in ``args`` with the lowest cost.

        The cost of each candidate is ``self.cost_function(candidate)``.
        When several candidates share the lowest cost, the one passed
        first is returned.
        """
        # ``min`` is a single pass and, like a stable sort followed by
        # ``[0]``, yields the first of several equally cheap candidates.
        return min(args, key=self.cost_function)
class ReplaceOptim(Optimization):
    """ Rewriting optimization that delegates to ``replace``.

    Explanation
    ===========

    Calling an instance on an expression invokes that expression's
    ``replace`` method (see :meth:`sympy.core.basic.Basic.replace`) with the
    ``query`` and ``value`` stored on the instance.

    Parameters
    ==========

    query :
        First argument passed to replace.
    value :
        Second argument passed to replace.
    **kwargs :
        Forwarded unchanged to the base class initializer.

    Examples
    ========

    >>> from sympy import Symbol
    >>> from sympy.codegen.rewriting import ReplaceOptim
    >>> from sympy.codegen.cfunctions import exp2
    >>> x = Symbol('x')
    >>> exp2_opt = ReplaceOptim(lambda p: p.is_Pow and p.base == 2,
    ...                         lambda p: exp2(p.exp))
    >>> exp2_opt(2**x)
    exp2(x)

    """

    def __init__(self, query, value, **kwargs):
        super().__init__(**kwargs)
        self.query, self.value = query, value

    def __call__(self, expr):
        pattern, replacement = self.query, self.value
        return expr.replace(pattern, replacement)
def optimize(expr, optimizations):
    """ Apply optimizations to an expression.

    Parameters
    ==========

    expr : expression
    optimizations : iterable of ``Optimization`` instances
        The optimizations will be sorted with respect to ``priority`` (highest first).

    Returns
    =======

    The expression obtained after every optimization has been tried in turn.
    An optimization whose ``cost_function`` is ``None`` always has its result
    accepted; otherwise its ``cheapest`` method chooses between the current
    expression and the rewritten one.

    Examples
    ========

    >>> from sympy import log, Symbol
    >>> from sympy.codegen.rewriting import optims_c99, optimize
    >>> x = Symbol('x')
    >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
    log1p(x**2) + log2(x + 3)

    """
    by_priority = sorted(optimizations, key=lambda rule: rule.priority, reverse=True)
    current = expr
    for rule in by_priority:
        candidate = rule(current)
        if rule.cost_function is not None:
            # Let the rule's own cost model arbitrate between old and new.
            candidate = rule.cheapest(current, candidate)
        current = candidate
    return current
# Rewrite powers with base two, ``2**x``, as ``exp2(x)``.
exp2_opt = ReplaceOptim(
    lambda node: node.is_Pow and node.base == 2,
    lambda node: exp2(node.exp),
)
# Wild placeholders shared by the pattern-based optimizations in this module.
# ``_d`` matches only Dummy symbols.
_d = Wild('d', properties=[lambda cand: cand.is_Dummy])
# ``_u`` matches anything that is neither a number nor an Add.
_u = Wild('u', properties=[lambda cand: not cand.is_number and not cand.is_Add])
# ``_v`` and ``_w`` are unconstrained.
_v = Wild('v')
_w = Wild('w')
# ``_n`` matches only numbers.
_n = Wild('n', properties=[lambda cand: cand.is_number])
# sin(w)/w  ->  sinc(w)
sinc_opt1 = ReplaceOptim(sin(_w)/_w, sinc(_w))

# sin(n*w)/w  ->  n*sinc(n*w), for a numeric factor n
sinc_opt2 = ReplaceOptim(sin(_n*_w)/_w, _n*sinc(_n*_w))

sinc_opts = sinc_opt1, sinc_opt2
# v*log(w)/log(2)  ->  v*log2(w)
# The cost function counts the sub-expressions that stand for expensive
# floating point operations: divisions (powers with a negative exponent) and
# evaluations of ``log``/``log2`` on a non-numeric argument.
log2_opt = ReplaceOptim(
    _v*log(_w)/log(2),
    _v*log2(_w),
    cost_function=lambda expr: expr.count(
        lambda node: (
            # division
            (node.is_Pow and node.exp.is_negative)
            # transcendental
            or (isinstance(node, (log, log2)) and not node.args[0].is_number)
        )
    ),
)

# log(2)*log2(w)  ->  log(w)
log2const_opt = ReplaceOptim(
    log(2)*log2(_w),
    log(_w),
)
def _is_log_of_two_exps(l):
    """ True when ``l`` is ``log(exp(a) + exp(b))`` with exactly two terms. """
    return (isinstance(l, log)
            and l.args[0].is_Add
            and len(l.args[0].args) == 2
            and all(isinstance(t, exp) for t in l.args[0].args))


def _stable_logsumexp_2terms(l):
    """ Rewrite ``log(exp(a) + exp(b))`` without exponentiating the larger argument.

    Uses the identity::

        log(exp(a) + exp(b)) = Max(a, b) + log1p(exp(Min(a, b) - Max(a, b)))

    The exponent ``Min - Max`` is never positive, so the remaining ``exp``
    stays within ``(0, 1]``.
    """
    a, b = (t.args[0] for t in l.args[0].args)
    larger, smaller = Max(a, b), Min(a, b)
    # The smaller argument must be taken relative to the larger one:
    # factoring exp(larger) out of the sum leaves 1 + exp(smaller - larger).
    return larger + log1p(exp(smaller - larger))


logsumexp_2terms_opt = ReplaceOptim(_is_log_of_two_exps, _stable_logsumexp_2terms)
class FuncMinusOneOptim(ReplaceOptim):
    """Specialization of ReplaceOptim for functions evaluating "f(x) - 1".

    Explanation
    ===========

    A numerical function that goes toward one as x goes toward zero is often
    best paired with a dedicated function for "f(x) - 1" in order to avoid
    catastrophic cancellation. One such example is ``expm1(x)`` in the C
    standard library, which evaluates ``exp(x) - 1``. Such a function
    preserves many more significant digits when its argument is much smaller
    than one, compared to subtracting one afterwards.

    Parameters
    ==========

    func :
        The function which is subtracted by one.
    func_m_1 :
        The specialized function evaluating ``func(x) - 1``.
    opportunistic : bool
        When ``True``, apply the transformation as long as the magnitude of the
        remaining number terms decreases. When ``False``, only apply the
        transformation if it completely eliminates the number term.

    Examples
    ========

    >>> from sympy import symbols, exp
    >>> from sympy.codegen.rewriting import FuncMinusOneOptim
    >>> from sympy.codegen.cfunctions import expm1
    >>> x, y = symbols('x y')
    >>> expm1_opt = FuncMinusOneOptim(exp, expm1)
    >>> expm1_opt(exp(x) + 2*exp(5*y) - 3)
    expm1(x) + 2*expm1(5*y)

    """

    def __init__(self, func, func_m_1, opportunistic=True):
        weight = 10  # <-- this is an arbitrary number (heuristic)

        def cost(expr):
            # Every occurrence of the specialized function earns a bonus.
            return expr.count_ops() - weight*expr.count(func_m_1)

        super().__init__(lambda e: e.is_Add, self.replace_in_Add, cost_function=cost)
        self.func = func
        self.func_m_1 = func_m_1
        self.opportunistic = opportunistic

    def _group_Add_terms(self, add):
        """ Split the terms of ``add`` into (sum of numbers, terms containing
        ``self.func``, all remaining terms). """
        numeric, symbolic = sift(add.args, lambda term: term.is_number, binary=True)
        with_func, without_func = sift(
            symbolic, lambda term: term.has(self.func), binary=True)
        return sum(numeric), with_func, without_func

    def _split_coefficient(self, term):
        """ Decompose ``term`` as ``coeff*self.func(...)``.

        Returns ``(call, coeff)`` where ``call`` is the application of
        ``self.func``; ``coeff`` is ``None`` when ``term`` is not of that form.
        """
        if term.is_Mul:
            calls, rest = sift(term.args, lambda factor: factor.func == self.func, binary=True)
            if len(calls) == 1 and len(rest) == 1:
                return calls[0], rest[0]
            return None, None
        if term.func == self.func:
            return term, S.One
        return None, None

    def replace_in_Add(self, e):
        """ passed as second argument to Basic.replace(...) """
        constant, candidates, remaining = self._group_Add_terms(e)
        if constant == 0:
            return e
        rewritten, kept = [], []
        for term in candidates:
            call, coeff = self._split_coefficient(term)
            worthwhile = False
            # Only a numeric coefficient of opposite sign can cancel part of
            # the constant.
            if coeff is not None and coeff.is_number and sign(coeff) == -sign(constant):
                if self.opportunistic:
                    worthwhile = abs(coeff+constant) < abs(constant)
                else:
                    worthwhile = coeff+constant == 0
            if worthwhile:  # advantageous substitution
                constant += coeff
                rewritten.append(coeff*self.func_m_1(*call.args))
            else:
                kept.append(term)
        return e.func(constant, *rewritten, *kept, *remaining)

    def __call__(self, expr):
        # Try both the expression as given and its factored form; keep the
        # one the cost function rates lowest.
        as_given = super().__call__(expr)
        factored = super().__call__(expr.factor())
        return self.cheapest(as_given, factored)
# exp(x) - 1  ->  expm1(x)
expm1_opt = FuncMinusOneOptim(func=exp, func_m_1=expm1)
# cos(x) - 1  ->  cosm1(x)
cosm1_opt = FuncMinusOneOptim(func=cos, func_m_1=cosm1)
# For every ``log``: factor the argument of each contained log, expand the
# result with ``expand_log`` and then turn ``log(u + 1)`` into ``log1p(u)``.
log1p_opt = ReplaceOptim(
    lambda node: isinstance(node, log),
    lambda node: expand_log(
        node.replace(log, lambda arg: log(arg.factor()))
    ).replace(log(_u + 1), log1p(_u)),
)
def create_expand_pow_optimization(limit, *, base_req=lambda b: b.is_symbol):
    """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.

    Explanation
    ===========

    A power is expanded when its base satisfies ``base_req`` (by default:
    the base is a symbol) and its exponent is an Integer whose absolute
    value is less than or equal to ``limit``. Positive exponents become an
    unevaluated product of the base; other exponents become one divided by
    such a product.

    Parameters
    ==========

    limit : int
         The highest power which is expanded into multiplication.
    base_req : function returning bool
        Requirement on base for expansion to happen, default is to return
        the ``is_symbol`` attribute of the base.

    Examples
    ========

    >>> from sympy import Symbol, sin
    >>> from sympy.codegen.rewriting import create_expand_pow_optimization
    >>> x = Symbol('x')
    >>> expand_opt = create_expand_pow_optimization(3)
    >>> expand_opt(x**5 + x**3)
    x**5 + x*x*x
    >>> expand_opt(x**5 + x**3 + sin(x)**3)
    x**5 + sin(x)**3 + x*x*x
    >>> opt2 = create_expand_pow_optimization(3, base_req=lambda b: not b.is_Function)
    >>> opt2((x+1)**2 + sin(x)**2)
    sin(x)**2 + (x + 1)*(x + 1)

    """
    def qualifies(e):
        return e.is_Pow and base_req(e.base) and e.exp.is_Integer and abs(e.exp) <= limit

    def unroll(p):
        if p.exp > 0:
            return UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False))
        return 1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))

    return ReplaceOptim(qualifies, unroll)
- # Optimization procedures for turning A**(-1) * x into MatrixSolve(A, x)
- def _matinv_predicate(expr):
- # TODO: We should be able to support more than 2 elements
- if expr.is_MatMul and len(expr.args) == 2:
- left, right = expr.args
- if left.is_Inverse and right.shape[1] == 1:
- inv_arg = left.arg
- if isinstance(inv_arg, MatrixSymbol):
- return bool(ask(Q.fullrank(left.arg)))
- return False
def _matinv_transform(expr):
    """ Build ``MatrixSolve(A, x)`` from a two-factor product ``A**(-1) * x``.

    ``expr.args`` is unpacked as ``(inverse, rhs)``; the matrix being
    inverted is read from ``inverse.arg``.
    """
    inverse, rhs = expr.args
    return MatrixSolve(inverse.arg, rhs)
# A**(-1) * x  ->  MatrixSolve(A, x)
matinv_opt = ReplaceOptim(
    _matinv_predicate,
    _matinv_transform,
)

# log(exp(v) + exp(w))  ->  logaddexp(v, w)
logaddexp_opt = ReplaceOptim(
    log(exp(_v)+exp(_w)),
    logaddexp(_v, _w),
)

# log(2**v + 2**w)  ->  logaddexp2(v, w)*log(2)
logaddexp2_opt = ReplaceOptim(
    log(Pow(2, _v)+Pow(2, _w)),
    logaddexp2(_v, _w)*log(2),
)
# Collections of optimizations:
optims_c99 = (
    expm1_opt,
    log1p_opt,
    exp2_opt,
    log2_opt,
    log2const_opt,
)

# Everything in ``optims_c99`` followed by the logaddexp and sinc rewrites.
optims_numpy = (*optims_c99, logaddexp_opt, logaddexp2_opt, *sinc_opts)

optims_scipy = (cosm1_opt,)
|