rewriting.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. """
  2. Classes and functions useful for rewriting expressions for optimized code
  3. generation. Some languages (or standards thereof), e.g. C99, offer specialized
  4. math functions for better performance and/or precision.
  5. Using the ``optimize`` function in this module, together with a collection of
  6. rules (represented as instances of ``Optimization``), one can rewrite the
  7. expressions for this purpose::
  8. >>> from sympy import Symbol, exp, log
  9. >>> from sympy.codegen.rewriting import optimize, optims_c99
  10. >>> x = Symbol('x')
  11. >>> optimize(3*exp(2*x) - 3, optims_c99)
  12. 3*expm1(2*x)
  13. >>> optimize(exp(2*x) - 1 - exp(-33), optims_c99)
  14. expm1(2*x) - exp(-33)
  15. >>> optimize(log(3*x + 3), optims_c99)
  16. log1p(x) + log(3)
  17. >>> optimize(log(2*x + 3), optims_c99)
  18. log(2*x + 3)
  19. The ``optims_c99`` imported above is tuple containing the following instances
  20. (which may be imported from ``sympy.codegen.rewriting``):
  21. - ``expm1_opt``
  22. - ``log1p_opt``
  23. - ``exp2_opt``
  24. - ``log2_opt``
  25. - ``log2const_opt``
  26. """
  27. from sympy.core.function import expand_log
  28. from sympy.core.singleton import S
  29. from sympy.core.symbol import Wild
  30. from sympy.functions.elementary.complexes import sign
  31. from sympy.functions.elementary.exponential import (exp, log)
  32. from sympy.functions.elementary.miscellaneous import (Max, Min)
  33. from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
  34. from sympy.assumptions import Q, ask
  35. from sympy.codegen.cfunctions import log1p, log2, exp2, expm1
  36. from sympy.codegen.matrix_nodes import MatrixSolve
  37. from sympy.core.expr import UnevaluatedExpr
  38. from sympy.core.power import Pow
  39. from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
  40. from sympy.codegen.scipy_nodes import cosm1
  41. from sympy.core.mul import Mul
  42. from sympy.matrices.expressions.matexpr import MatrixSymbol
  43. from sympy.utilities.iterables import sift
  44. class Optimization:
  45. """ Abstract base class for rewriting optimization.
  46. Subclasses should implement ``__call__`` taking an expression
  47. as argument.
  48. Parameters
  49. ==========
  50. cost_function : callable returning number
  51. priority : number
  52. """
  53. def __init__(self, cost_function=None, priority=1):
  54. self.cost_function = cost_function
  55. self.priority=priority
  56. def cheapest(self, *args):
  57. return sorted(args, key=self.cost_function)[0]
  58. class ReplaceOptim(Optimization):
  59. """ Rewriting optimization calling replace on expressions.
  60. Explanation
  61. ===========
  62. The instance can be used as a function on expressions for which
  63. it will apply the ``replace`` method (see
  64. :meth:`sympy.core.basic.Basic.replace`).
  65. Parameters
  66. ==========
  67. query :
  68. First argument passed to replace.
  69. value :
  70. Second argument passed to replace.
  71. Examples
  72. ========
  73. >>> from sympy import Symbol
  74. >>> from sympy.codegen.rewriting import ReplaceOptim
  75. >>> from sympy.codegen.cfunctions import exp2
  76. >>> x = Symbol('x')
  77. >>> exp2_opt = ReplaceOptim(lambda p: p.is_Pow and p.base == 2,
  78. ... lambda p: exp2(p.exp))
  79. >>> exp2_opt(2**x)
  80. exp2(x)
  81. """
  82. def __init__(self, query, value, **kwargs):
  83. super().__init__(**kwargs)
  84. self.query = query
  85. self.value = value
  86. def __call__(self, expr):
  87. return expr.replace(self.query, self.value)
  88. def optimize(expr, optimizations):
  89. """ Apply optimizations to an expression.
  90. Parameters
  91. ==========
  92. expr : expression
  93. optimizations : iterable of ``Optimization`` instances
  94. The optimizations will be sorted with respect to ``priority`` (highest first).
  95. Examples
  96. ========
  97. >>> from sympy import log, Symbol
  98. >>> from sympy.codegen.rewriting import optims_c99, optimize
  99. >>> x = Symbol('x')
  100. >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
  101. log1p(x**2) + log2(x + 3)
  102. """
  103. for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
  104. new_expr = optim(expr)
  105. if optim.cost_function is None:
  106. expr = new_expr
  107. else:
  108. expr = optim.cheapest(expr, new_expr)
  109. return expr
  110. exp2_opt = ReplaceOptim(
  111. lambda p: p.is_Pow and p.base == 2,
  112. lambda p: exp2(p.exp)
  113. )
  114. _d = Wild('d', properties=[lambda x: x.is_Dummy])
  115. _u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add])
  116. _v = Wild('v')
  117. _w = Wild('w')
  118. _n = Wild('n', properties=[lambda x: x.is_number])
  119. sinc_opt1 = ReplaceOptim(
  120. sin(_w)/_w, sinc(_w)
  121. )
  122. sinc_opt2 = ReplaceOptim(
  123. sin(_n*_w)/_w, _n*sinc(_n*_w)
  124. )
  125. sinc_opts = (sinc_opt1, sinc_opt2)
  126. log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count(
  127. lambda e: ( # division & eval of transcendentals are expensive floating point operations...
  128. e.is_Pow and e.exp.is_negative # division
  129. or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental
  130. )
  131. )
  132. log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w))
  133. logsumexp_2terms_opt = ReplaceOptim(
  134. lambda l: (isinstance(l, log)
  135. and l.args[0].is_Add
  136. and len(l.args[0].args) == 2
  137. and all(isinstance(t, exp) for t in l.args[0].args)),
  138. lambda l: (
  139. Max(*[e.args[0] for e in l.args[0].args]) +
  140. log1p(exp(Min(*[e.args[0] for e in l.args[0].args])))
  141. )
  142. )
  143. class FuncMinusOneOptim(ReplaceOptim):
  144. """Specialization of ReplaceOptim for functions evaluating "f(x) - 1".
  145. Explanation
  146. ===========
  147. Numerical functions which go toward one as x go toward zero is often best
  148. implemented by a dedicated function in order to avoid catastrophic
  149. cancellation. One such example is ``expm1(x)`` in the C standard library
  150. which evaluates ``exp(x) - 1``. Such functions preserves many more
  151. significant digits when its argument is much smaller than one, compared
  152. to subtracting one afterwards.
  153. Parameters
  154. ==========
  155. func :
  156. The function which is subtracted by one.
  157. func_m_1 :
  158. The specialized function evaluating ``func(x) - 1``.
  159. opportunistic : bool
  160. When ``True``, apply the transformation as long as the magnitude of the
  161. remaining number terms decreases. When ``False``, only apply the
  162. transformation if it completely eliminates the number term.
  163. Examples
  164. ========
  165. >>> from sympy import symbols, exp
  166. >>> from sympy.codegen.rewriting import FuncMinusOneOptim
  167. >>> from sympy.codegen.cfunctions import expm1
  168. >>> x, y = symbols('x y')
  169. >>> expm1_opt = FuncMinusOneOptim(exp, expm1)
  170. >>> expm1_opt(exp(x) + 2*exp(5*y) - 3)
  171. expm1(x) + 2*expm1(5*y)
  172. """
  173. def __init__(self, func, func_m_1, opportunistic=True):
  174. weight = 10 # <-- this is an arbitrary number (heuristic)
  175. super().__init__(lambda e: e.is_Add, self.replace_in_Add,
  176. cost_function=lambda expr: expr.count_ops() - weight*expr.count(func_m_1))
  177. self.func = func
  178. self.func_m_1 = func_m_1
  179. self.opportunistic = opportunistic
  180. def _group_Add_terms(self, add):
  181. numbers, non_num = sift(add.args, lambda arg: arg.is_number, binary=True)
  182. numsum = sum(numbers)
  183. terms_with_func, other = sift(non_num, lambda arg: arg.has(self.func), binary=True)
  184. return numsum, terms_with_func, other
  185. def replace_in_Add(self, e):
  186. """ passed as second argument to Basic.replace(...) """
  187. numsum, terms_with_func, other_non_num_terms = self._group_Add_terms(e)
  188. if numsum == 0:
  189. return e
  190. substituted, untouched = [], []
  191. for with_func in terms_with_func:
  192. if with_func.is_Mul:
  193. func, coeff = sift(with_func.args, lambda arg: arg.func == self.func, binary=True)
  194. if len(func) == 1 and len(coeff) == 1:
  195. func, coeff = func[0], coeff[0]
  196. else:
  197. coeff = None
  198. elif with_func.func == self.func:
  199. func, coeff = with_func, S.One
  200. else:
  201. coeff = None
  202. if coeff is not None and coeff.is_number and sign(coeff) == -sign(numsum):
  203. if self.opportunistic:
  204. do_substitute = abs(coeff+numsum) < abs(numsum)
  205. else:
  206. do_substitute = coeff+numsum == 0
  207. if do_substitute: # advantageous substitution
  208. numsum += coeff
  209. substituted.append(coeff*self.func_m_1(*func.args))
  210. continue
  211. untouched.append(with_func)
  212. return e.func(numsum, *substituted, *untouched, *other_non_num_terms)
  213. def __call__(self, expr):
  214. alt1 = super().__call__(expr)
  215. alt2 = super().__call__(expr.factor())
  216. return self.cheapest(alt1, alt2)
  217. expm1_opt = FuncMinusOneOptim(exp, expm1)
  218. cosm1_opt = FuncMinusOneOptim(cos, cosm1)
  219. log1p_opt = ReplaceOptim(
  220. lambda e: isinstance(e, log),
  221. lambda l: expand_log(l.replace(
  222. log, lambda arg: log(arg.factor())
  223. )).replace(log(_u+1), log1p(_u))
  224. )
  225. def create_expand_pow_optimization(limit, *, base_req=lambda b: b.is_symbol):
  226. """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.
  227. Explanation
  228. ===========
  229. The requirements for expansions are that the base needs to be a symbol
  230. and the exponent needs to be an Integer (and be less than or equal to
  231. ``limit``).
  232. Parameters
  233. ==========
  234. limit : int
  235. The highest power which is expanded into multiplication.
  236. base_req : function returning bool
  237. Requirement on base for expansion to happen, default is to return
  238. the ``is_symbol`` attribute of the base.
  239. Examples
  240. ========
  241. >>> from sympy import Symbol, sin
  242. >>> from sympy.codegen.rewriting import create_expand_pow_optimization
  243. >>> x = Symbol('x')
  244. >>> expand_opt = create_expand_pow_optimization(3)
  245. >>> expand_opt(x**5 + x**3)
  246. x**5 + x*x*x
  247. >>> expand_opt(x**5 + x**3 + sin(x)**3)
  248. x**5 + sin(x)**3 + x*x*x
  249. >>> opt2 = create_expand_pow_optimization(3, base_req=lambda b: not b.is_Function)
  250. >>> opt2((x+1)**2 + sin(x)**2)
  251. sin(x)**2 + (x + 1)*(x + 1)
  252. """
  253. return ReplaceOptim(
  254. lambda e: e.is_Pow and base_req(e.base) and e.exp.is_Integer and abs(e.exp) <= limit,
  255. lambda p: (
  256. UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else
  257. 1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))
  258. ))
  259. # Optimization procedures for turning A**(-1) * x into MatrixSolve(A, x)
  260. def _matinv_predicate(expr):
  261. # TODO: We should be able to support more than 2 elements
  262. if expr.is_MatMul and len(expr.args) == 2:
  263. left, right = expr.args
  264. if left.is_Inverse and right.shape[1] == 1:
  265. inv_arg = left.arg
  266. if isinstance(inv_arg, MatrixSymbol):
  267. return bool(ask(Q.fullrank(left.arg)))
  268. return False
  269. def _matinv_transform(expr):
  270. left, right = expr.args
  271. inv_arg = left.arg
  272. return MatrixSolve(inv_arg, right)
  273. matinv_opt = ReplaceOptim(_matinv_predicate, _matinv_transform)
  274. logaddexp_opt = ReplaceOptim(log(exp(_v)+exp(_w)), logaddexp(_v, _w))
  275. logaddexp2_opt = ReplaceOptim(log(Pow(2, _v)+Pow(2, _w)), logaddexp2(_v, _w)*log(2))
  276. # Collections of optimizations:
  277. optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt)
  278. optims_numpy = optims_c99 + (logaddexp_opt, logaddexp2_opt,) + sinc_opts
  279. optims_scipy = (cosm1_opt,)