rewrite.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. """ Functions to support rewriting of SymPy expressions """
  2. from sympy.core.expr import Expr
  3. from sympy.assumptions import ask
  4. from sympy.strategies.tools import subs
  5. from sympy.unify.usympy import rebuild, unify
  6. def rewriterule(source, target, variables=(), condition=None, assume=None):
  7. """ Rewrite rule.
  8. Transform expressions that match source into expressions that match target
  9. treating all ``variables`` as wilds.
  10. Examples
  11. ========
  12. >>> from sympy.abc import w, x, y, z
  13. >>> from sympy.unify.rewrite import rewriterule
  14. >>> from sympy import default_sort_key
  15. >>> rl = rewriterule(x + y, x**y, [x, y])
  16. >>> sorted(rl(z + 3), key=default_sort_key)
  17. [3**z, z**3]
  18. Use ``condition`` to specify additional requirements. Inputs are taken in
  19. the same order as is found in variables.
  20. >>> rl = rewriterule(x + y, x**y, [x, y], lambda x, y: x.is_integer)
  21. >>> list(rl(z + 3))
  22. [3**z]
  23. Use ``assume`` to specify additional requirements using new assumptions.
  24. >>> from sympy.assumptions import Q
  25. >>> rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x))
  26. >>> list(rl(z + 3))
  27. [3**z]
  28. Assumptions for the local context are provided at rule runtime
  29. >>> list(rl(w + z, Q.integer(z)))
  30. [z**w]
  31. """
  32. def rewrite_rl(expr, assumptions=True):
  33. for match in unify(source, expr, {}, variables=variables):
  34. if (condition and
  35. not condition(*[match.get(var, var) for var in variables])):
  36. continue
  37. if (assume and not ask(assume.xreplace(match), assumptions)):
  38. continue
  39. expr2 = subs(match)(target)
  40. if isinstance(expr2, Expr):
  41. expr2 = rebuild(expr2)
  42. yield expr2
  43. return rewrite_rl