123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- """ Generic Rules for SymPy
- This file assumes knowledge of Basic and little else.
- """
- from sympy.utilities.iterables import sift
- from .util import new
- # Functions that create rules
def rm_id(isid, new=new):
    """ Create a rule to remove identities.

    isid - fn :: x -> Bool  --- whether or not this element is an identity.
    The predicate's result is interpreted by truthiness, so predicates that
    return non-bool values (``None``, ``0``/``1``, numpy bools) are handled
    correctly.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic, S
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(S(1), S(0), S(2)))
    Basic(1, 2)
    >>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one
    Basic(0)

    See Also
    ========

    unpack
    """
    def ident_remove(expr):
        """ Remove identities from ``expr.args``.

        Returns ``expr`` unchanged when no argument is an identity. Otherwise
        rebuilds with ``new`` from the non-identity arguments. If every
        argument is an identity, it rebuilds with just the first one.
        """
        # Evaluate the predicate exactly once per argument, in order.
        ids = [isid(arg) for arg in expr.args]
        if not any(ids):  # No identities. Common case
            return expr
        if not all(ids):  # there is at least one non-identity
            return new(expr.__class__,
                       *[arg for arg, is_id in zip(expr.args, ids)
                         if not is_id])
        # Every argument is an identity: keep a single representative so the
        # result is not an empty container.
        return new(expr.__class__, expr.args[0])
    return ident_remove
def glom(key, count, combine):
    """ Build a rule that merges args sharing the same ``key``.

    Each arg is classified by ``key(arg)``. Within one class, the
    multiplicities ``count(arg)`` are summed, and the class is then
    rebuilt as ``combine(total, key_value)``. The rule returns the
    original expression when the merged argument set equals the
    original one.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x
    >>> key = lambda x: x.as_coeff_Mul()[1]
    >>> count = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)
    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    How key, count and combine behave individually:

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """
    def conglomerate(expr):
        """ Conglomerate together identical args x + x -> 2x """
        buckets = sift(expr.args, key)
        totals = {}
        for term, members in buckets.items():
            totals[term] = sum(map(count, members))
        merged = [combine(total, term) for term, total in totals.items()]
        if set(merged) == set(expr.args):
            # Nothing was actually merged; keep the original object.
            return expr
        return new(type(expr), *merged)
    return conglomerate
def sort(key, new=new):
    """ Build a rule that reorders an expression's args by ``key``.

    The returned rule always rebuilds the expression (via ``new``) with
    its args sorted according to ``key``.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic, S
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(S(3), S(1), S(2)))
    Basic(1, 2, 3)
    """
    def sort_rl(expr):
        ordered = sorted(expr.args, key=key)
        return new(expr.__class__, *ordered)
    return sort_rl
def distribute(A, B):
    """ Build a rule turning an ``A`` that contains a ``B`` into a ``B`` of ``A``s.

    ``A`` and ``B`` are container types. Only the first arg of type ``B``
    is distributed over. If no such arg exists, the expression is
    returned unchanged.

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """
    def distribute_rl(expr):
        args = expr.args
        for idx, candidate in enumerate(args):
            if not isinstance(candidate, B):
                continue
            head = args[:idx]
            tail = args[idx + 1:]
            pieces = [A(*(head + (term,) + tail)) for term in candidate.args]
            return B(*pieces)
        return expr
    return distribute_rl
def subs(a, b):
    """ Build a rule that swaps an expression equal to ``a`` for ``b``.

    Only the whole expression is compared, using ``==``. Subexpressions
    are not searched.
    """
    def subs_rl(expr):
        return b if expr == a else expr
    return subs_rl
- # Functions that are rules
def unpack(expr):
    """ Rule that replaces a single-argument expression by that argument.

    Expressions with zero or several args are returned as-is.

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic, S
    >>> unpack(Basic(S(2)))
    2
    """
    args = expr.args
    if len(args) != 1:
        return expr
    return args[0]
def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e))

    Only direct children whose class is exactly ``expr``'s class are
    spliced in, one level deep. Subclasses and other types are kept
    as single args.
    """
    cls = expr.__class__
    flat = []
    for child in expr.args:
        flat.extend(child.args if child.__class__ == cls else (child,))
    return new(cls, *flat)
def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    Every non-atomic node is reconstructed bottom-up through its
    ``func``. Running the real constructors again canonicalizes
    expressions that were assembled with ``Basic.__new__``. Atoms are
    returned as-is.
    """
    if not expr.is_Atom:
        rebuilt_args = [rebuild(child) for child in expr.args]
        return expr.func(*rebuilt_args)
    return expr
|