cse_opts.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """ Optimizations of the expression tree representation for better CSE
  2. opportunities.
  3. """
  4. from sympy.core import Add, Basic, Mul
  5. from sympy.core.singleton import S
  6. from sympy.core.sorting import default_sort_key
  7. from sympy.core.traversal import preorder_traversal
  def sub_pre(e):
      """Rewrite negatable Adds as explicit, unevaluated negations.

      An Add ``A`` for which ``A.could_extract_minus_sign()`` is true (e.g.
      ``y - x``) is replaced by the product ``-1*(-A)`` (e.g. ``-1*(x - y)``).
      The aim, per this module's docstring, is to give CSE more chances to
      match ``x - y`` against ``y - x``.

      The rewrite happens in two passes:

      1. Every negatable Add found in ``e`` is replaced by ``-1*(-A)``.
         An Add whose negation comes back as a Mul (e.g. a matrix
         expression) is recorded and left alone by both passes.
      2. Pass 2 runs only if the result is still a ``Basic``. Any Add that
         survived the first substitution is rewritten again. If pass 1
         recorded a replacement for it, that replacement is reused.
         Otherwise, if a minus sign can be extracted, it becomes the marker
         product ``1*-1*(-A)``.

      Products are built with ``Mul._from_args``. That skips normal Mul
      evaluation, so the ``-1`` stays a separate factor instead of being
      distributed back into the Add.
      """
      # Pass 1: wrap each negatable Add A as the unevaluated product -1*(-A).
      reps = {}
      ignore = set()
      for add in e.atoms(Add):
          if not add.could_extract_minus_sign():
              continue
          negated = -add
          if negated.is_Mul:  # e.g. MatExpr: negation produced a Mul
              ignore.add(add)
              continue
          reps[add] = Mul._from_args([S.NegativeOne, negated])
      e = e.xreplace(reps)

      # Pass 2 is only run on a Basic result; anything else is returned as is.
      if not isinstance(e, Basic):
          return e

      # Adds still present after pass 1 are rewritten again. Newly negatable
      # ones are tagged with a leading 1, -1 (e.g. y - x -> 1*-1*(x - y)).
      negs = {}
      for add in sorted(e.atoms(Add), key=default_sort_key):
          if add in ignore:
              continue
          if add in reps:
              negs[add] = reps[add]
          elif add.could_extract_minus_sign():
              negs[add] = Mul._from_args([S.One, S.NegativeOne, -add])
      return e.xreplace(negs)
  def sub_post(e):
      """Turn marker products ``1*-1*x`` back into plain negations ``-x``.

      A Mul in ``e`` counts as a marker when its first two args are exactly
      ``S.One`` and ``S.NegativeOne``; this is the shape ``sub_pre`` builds in
      its second pass. Each marker is replaced by the negation of the product
      of its remaining args.
      """
      def _is_marker(node):
          # Identity checks against the singletons, matching how the
          # marker product is assembled.
          return (isinstance(node, Mul)
                  and node.args[0] is S.One
                  and node.args[1] is S.NegativeOne)

      # Collect all (marker, replacement) pairs up front, in preorder
      # (outer nodes before the nodes nested inside them).
      pairs = [(node, -Mul._from_args(node.args[2:]))
               for node in preorder_traversal(e) if _is_marker(node)]

      # Substitute one pair at a time. xreplace does not descend into the
      # value it substitutes, so a marker nested inside an earlier
      # replacement is only rewritten by its own later xreplace call.
      for marker, plain in pairs:
          e = e.xreplace({marker: plain})
      return e