rl.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. """ Generic Rules for SymPy
  2. This file assumes knowledge of Basic and little else.
  3. """
  4. from sympy.utilities.iterables import sift
  5. from .util import new
  6. # Functions that create rules
  7. def rm_id(isid, new=new):
  8. """ Create a rule to remove identities.
  9. isid - fn :: x -> Bool --- whether or not this element is an identity.
  10. Examples
  11. ========
  12. >>> from sympy.strategies import rm_id
  13. >>> from sympy import Basic, S
  14. >>> remove_zeros = rm_id(lambda x: x==0)
  15. >>> remove_zeros(Basic(S(1), S(0), S(2)))
  16. Basic(1, 2)
  17. >>> remove_zeros(Basic(S(0), S(0))) # If only identites then we keep one
  18. Basic(0)
  19. See Also:
  20. unpack
  21. """
  22. def ident_remove(expr):
  23. """ Remove identities """
  24. ids = list(map(isid, expr.args))
  25. if sum(ids) == 0: # No identities. Common case
  26. return expr
  27. elif sum(ids) != len(ids): # there is at least one non-identity
  28. return new(expr.__class__,
  29. *[arg for arg, x in zip(expr.args, ids) if not x])
  30. else:
  31. return new(expr.__class__, expr.args[0])
  32. return ident_remove
  33. def glom(key, count, combine):
  34. """ Create a rule to conglomerate identical args.
  35. Examples
  36. ========
  37. >>> from sympy.strategies import glom
  38. >>> from sympy import Add
  39. >>> from sympy.abc import x
  40. >>> key = lambda x: x.as_coeff_Mul()[1]
  41. >>> count = lambda x: x.as_coeff_Mul()[0]
  42. >>> combine = lambda cnt, arg: cnt * arg
  43. >>> rl = glom(key, count, combine)
  44. >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
  45. 3*x + 5
  46. Wait, how are key, count and combine supposed to work?
  47. >>> key(2*x)
  48. x
  49. >>> count(2*x)
  50. 2
  51. >>> combine(2, x)
  52. 2*x
  53. """
  54. def conglomerate(expr):
  55. """ Conglomerate together identical args x + x -> 2x """
  56. groups = sift(expr.args, key)
  57. counts = {k: sum(map(count, args)) for k, args in groups.items()}
  58. newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
  59. if set(newargs) != set(expr.args):
  60. return new(type(expr), *newargs)
  61. else:
  62. return expr
  63. return conglomerate
  64. def sort(key, new=new):
  65. """ Create a rule to sort by a key function.
  66. Examples
  67. ========
  68. >>> from sympy.strategies import sort
  69. >>> from sympy import Basic, S
  70. >>> sort_rl = sort(str)
  71. >>> sort_rl(Basic(S(3), S(1), S(2)))
  72. Basic(1, 2, 3)
  73. """
  74. def sort_rl(expr):
  75. return new(expr.__class__, *sorted(expr.args, key=key))
  76. return sort_rl
  77. def distribute(A, B):
  78. """ Turns an A containing Bs into a B of As
  79. where A, B are container types
  80. >>> from sympy.strategies import distribute
  81. >>> from sympy import Add, Mul, symbols
  82. >>> x, y = symbols('x,y')
  83. >>> dist = distribute(Mul, Add)
  84. >>> expr = Mul(2, x+y, evaluate=False)
  85. >>> expr
  86. 2*(x + y)
  87. >>> dist(expr)
  88. 2*x + 2*y
  89. """
  90. def distribute_rl(expr):
  91. for i, arg in enumerate(expr.args):
  92. if isinstance(arg, B):
  93. first, b, tail = expr.args[:i], expr.args[i], expr.args[i+1:]
  94. return B(*[A(*(first + (arg,) + tail)) for arg in b.args])
  95. return expr
  96. return distribute_rl
  97. def subs(a, b):
  98. """ Replace expressions exactly """
  99. def subs_rl(expr):
  100. if expr == a:
  101. return b
  102. else:
  103. return expr
  104. return subs_rl
  105. # Functions that are rules
  106. def unpack(expr):
  107. """ Rule to unpack singleton args
  108. >>> from sympy.strategies import unpack
  109. >>> from sympy import Basic, S
  110. >>> unpack(Basic(S(2)))
  111. 2
  112. """
  113. if len(expr.args) == 1:
  114. return expr.args[0]
  115. else:
  116. return expr
  117. def flatten(expr, new=new):
  118. """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)) """
  119. cls = expr.__class__
  120. args = []
  121. for arg in expr.args:
  122. if arg.__class__ == cls:
  123. args.extend(arg.args)
  124. else:
  125. args.append(arg)
  126. return new(expr.__class__, *args)
  127. def rebuild(expr):
  128. """ Rebuild a SymPy tree.
  129. Explanation
  130. ===========
  131. This function recursively calls constructors in the expression tree.
  132. This forces canonicalization and removes ugliness introduced by the use of
  133. Basic.__new__
  134. """
  135. if expr.is_Atom:
  136. return expr
  137. else:
  138. return expr.func(*list(map(rebuild, expr.args)))