12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920 |
- """ Tools for doing common subexpression elimination.
- """
- from sympy.core import Basic, Mul, Add, Pow, sympify
- from sympy.core.containers import Tuple, OrderedSet
- from sympy.core.exprtools import factor_terms
- from sympy.core.singleton import S
- from sympy.core.sorting import ordered
- from sympy.core.symbol import symbols, Symbol
- from sympy.utilities.iterables import numbered_symbols, sift, \
- topological_sort, iterable
- from . import cse_opts
# (preprocessor, postprocessor) pairs which are commonly useful. They should
# each take a SymPy expression and return a possibly transformed expression.
# When used in the function ``cse()``, the target expressions will be transformed
# by each of the preprocessor functions in order. After the common
# subexpressions are eliminated, each resulting expression will have the
# postprocessor functions transform them in *reverse* order to undo the
# transformation if necessary. This allows the algorithm to operate on
# a representation of the expressions that allows for more optimization
# opportunities.
# ``None`` can be used to specify no transformation for either the preprocessor or
# postprocessor.
#
# This list is what ``cse(..., optimizations='basic')`` uses.
basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
                       (factor_terms, None)]
# Sometimes we want the output of ``cse`` in a different format; non-trivial
# transformations for users (e.g. functions to pass as ``postprocess``) can
# be put here.
# ===============================================================
def reps_toposort(r):
    """Sort replacements ``r`` so (k1, v1) appears before (k2, v2)
    if k1 is in v2's free symbols, i.e. every symbol is defined before any
    replacement whose value uses it. This orders items in the
    way that cse returns its results (hence, in order to use the
    replacements in a substitution option it would make sense
    to reverse the order).

    Examples
    ========

    >>> from sympy.simplify.cse_main import reps_toposort
    >>> from sympy.abc import x, y
    >>> from sympy import Eq
    >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]):
    ...     print(Eq(l, r))
    ...
    Eq(y, 2)
    Eq(x, y + 1)

    """
    r = sympify(r)
    # Edge (c1, c2): entry c1 defines a symbol that the value of entry c2
    # uses, so c1 must come first.
    edges = [(c1, c2)
             for c1, (k1, _) in enumerate(r)
             for c2, (_, v2) in enumerate(r)
             if k1 in v2.free_symbols]
    return [r[i] for i in topological_sort((range(len(r)), edges))]
def cse_separate(r, e):
    """Move equations of the form ``Eq(symbol, expr)`` out of the reduced
    expressions ``e`` and into the replacements ``r``.

    The ``(symbol, expr)`` pairs of such equations are appended to ``r`` and
    the combined replacements are ordered with ``reps_toposort``. The return
    value is ``[ordered_replacements, remaining_expressions]``, where the
    remaining expressions keep their original relative order.

    Examples
    ========

    >>> from sympy.simplify.cse_main import cse_separate
    >>> from sympy.abc import x, y, z
    >>> from sympy import cos, exp, cse, Eq, symbols
    >>> x0, x1 = symbols('x:2')
    >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
    >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
    ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
    ...  [x1 + exp(x1/x0) + cos(x0), z - 2]],
    ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
    ...  [x0 + exp(x0/x1) + cos(x1), z - 2]]]
    ...
    True
    """
    groups = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol)
    extra_replacements = [equation.args for equation in groups[True]]
    leftover = groups[False]
    return [reps_toposort(r + extra_replacements), leftover]
def cse_release_variables(r, e):
    """
    Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is
    either an expression or None. The value of None is used when a
    symbol is no longer needed for subsequent expressions.

    Use of such output can reduce the memory footprint of lambdified
    expressions that contain large, repeated subexpressions.

    The result is ``(definitions, esyms)`` where ``esyms`` are the fresh
    symbols ``_0, _1, ...`` naming the expressions of ``e`` in their original
    order. If ``r`` is empty, ``(r, e)`` is returned unchanged.

    Examples
    ========

    >>> from sympy import cse
    >>> from sympy.simplify.cse_main import cse_release_variables
    >>> from sympy.abc import x, y
    >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)]
    >>> defs, rvs = cse_release_variables(*cse(eqs))
    >>> for i in defs:
    ...   print(i)
    ...
    (x0, x + y)
    (x1, (x0 - 1)**2)
    (x2, 2*x + 1)
    (_3, x0/x2 + x1)
    (_4, x2**x0)
    (x2, None)
    (_0, x1)
    (x1, None)
    (_2, x0)
    (x0, None)
    (_1, x)
    >>> print(rvs)
    (_0, _1, _2, _3, _4)
    """
    if not r:
        return r, e

    s, p = zip(*r)
    esyms = symbols('_:%d' % len(e))
    s = list(s)
    in_use = set(s)
    p = list(p)
    # Position of each replacement symbol, so the sort key below does an O(1)
    # lookup instead of a linear ``s.index`` scan per free symbol.
    position = {sym: k for k, sym in enumerate(s)}

    def _weight(expr):
        # Total operation count of the replacements ``expr`` refers to.
        return sum(p[position[sym]].count_ops()
                   for sym in expr.free_symbols & in_use)

    # sort e so those with most sub-expressions appear first; sorting indices
    # (stable) also works when ``e`` is empty.
    order = sorted(range(len(e)), key=lambda k: -_weight(e[k]))
    syms = [esyms[k] for k in order]
    p += [e[k] for k in order]

    rv = []
    i = len(p) - 1
    while i >= 0:
        _p = p.pop()
        c = in_use & _p.free_symbols
        if c:  # sorting for canonical results
            rv.extend([(sym, None) for sym in sorted(c, key=str)])
        if i >= len(r):
            # one of the original expressions: label it with its ``_k`` symbol
            rv.append((syms.pop(), _p))
        else:
            rv.append((s[i], _p))
        in_use -= c
        i -= 1
    rv.reverse()
    return rv, esyms
- # ====end of cse postprocess idioms===========================
def preprocess_for_cse(expr, optimizations):
    """ Preprocess an expression to optimize for common subexpression
    elimination.

    Every non-None preprocessor is applied in list order, each one receiving
    the output of the previous one; postprocessors are ignored here.

    Parameters
    ==========

    expr : SymPy expression
        The target expression to optimize.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs.

    Returns
    =======

    expr : SymPy expression
        The transformed expression.
    """
    forward_steps = (pre for pre, _ in optimizations)
    for step in forward_steps:
        if step is None:
            continue
        expr = step(expr)
    return expr
def postprocess_for_cse(expr, optimizations):
    """Postprocess an expression after common subexpression elimination to
    return the expression to canonical SymPy form.

    The non-None postprocessors are applied last-pair-first, so that each
    one undoes its matching preprocessor in the opposite order to which the
    preprocessors were applied.

    Parameters
    ==========

    expr : SymPy expression
        The target expression to transform.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. The postprocessors will be
        applied in reversed order to undo the effects of the preprocessors
        correctly.

    Returns
    =======

    expr : SymPy expression
        The transformed expression.
    """
    undo_steps = (post for _, post in reversed(optimizations))
    for undo in undo_steps:
        if undo is None:
            continue
        expr = undo(expr)
    return expr
class FuncArgTracker:
    """
    A class which manages a mapping from functions to arguments and an inverse
    mapping from arguments to functions.

    Functions are identified by their index in ``funcs``; arguments are
    identified by an integer "value number" assigned the first time each
    distinct argument value is seen.
    """

    def __init__(self, funcs):
        # To minimize the number of symbolic comparisons, all function arguments
        # get assigned a value number.
        self.value_numbers = {}
        self.value_number_to_value = []

        # Both of these maps use integer indices for arguments / functions.
        self.arg_to_funcset = []
        self.func_to_argset = []

        for func_i, func in enumerate(funcs):
            func_argset = OrderedSet()

            for func_arg in func.args:
                arg_number = self.get_or_add_value_number(func_arg)
                func_argset.add(arg_number)
                self.arg_to_funcset[arg_number].add(func_i)

            self.func_to_argset.append(func_argset)

    def get_args_in_value_order(self, argset):
        """
        Return the list of arguments in sorted order according to their value
        numbers.
        """
        return [self.value_number_to_value[argn] for argn in sorted(argset)]

    def get_or_add_value_number(self, value):
        """
        Return the value number for the given argument, assigning the next
        free number (with an empty function set) if it has not been seen yet.
        """
        nvalues = len(self.value_numbers)
        value_number = self.value_numbers.setdefault(value, nvalues)
        if value_number == nvalues:
            # ``value`` is new: extend the parallel per-value lists.
            self.value_number_to_value.append(value)
            self.arg_to_funcset.append(OrderedSet())
        return value_number

    def stop_arg_tracking(self, func_i):
        """
        Remove the function func_i from the argument to function mapping.
        """
        for arg in self.func_to_argset[func_i]:
            self.arg_to_funcset[arg].remove(func_i)

    def get_common_arg_candidates(self, argset, min_func_i=0):
        """Return a dict whose keys are function numbers. The entries of the dict are
        the number of arguments said function has in common with
        ``argset``. Entries have at least 2 items in common. All keys have
        value at least ``min_func_i``.
        """
        from collections import defaultdict

        if not argset:
            return {}

        count_map = defaultdict(int)

        funcsets = [self.arg_to_funcset[arg] for arg in argset]
        # As an optimization below, we handle the largest funcset separately from
        # the others.
        largest_funcset = max(funcsets, key=len)

        for funcset in funcsets:
            if largest_funcset is funcset:
                continue
            for func_i in funcset:
                if func_i >= min_func_i:
                    count_map[func_i] += 1

        # We pick the smaller of the two containers (count_map, largest_funcset)
        # to iterate over to reduce the number of iterations needed.
        (smaller_funcs_container,
         larger_funcs_container) = sorted(
                 [largest_funcset, count_map],
                 key=len)

        for func_i in smaller_funcs_container:
            # Not already in count_map? It can't possibly be in the output, so
            # skip it. ``.get`` avoids inserting zero entries while probing.
            if count_map.get(func_i, 0) < 1:
                continue

            if func_i in larger_funcs_container:
                count_map[func_i] += 1

        return {k: v for k, v in count_map.items() if v >= 2}

    def get_subset_candidates(self, argset, restrict_to_funcset=None):
        """
        Return a set of functions each of which whose argument list contains
        ``argset``, optionally filtered only to contain functions in
        ``restrict_to_funcset``.

        ``argset`` must be non-empty.
        """
        iarg = iter(argset)

        indices = OrderedSet(
            fi for fi in self.arg_to_funcset[next(iarg)])

        if restrict_to_funcset is not None:
            indices &= restrict_to_funcset

        for arg in iarg:
            indices &= self.arg_to_funcset[arg]

        return indices

    def update_func_argset(self, func_i, new_argset):
        """
        Update a function with a new set of arguments, keeping the inverse
        argument-to-function mapping consistent.
        """
        new_args = OrderedSet(new_argset)
        old_args = self.func_to_argset[func_i]

        for deleted_arg in old_args - new_args:
            self.arg_to_funcset[deleted_arg].remove(func_i)
        for added_arg in new_args - old_args:
            self.arg_to_funcset[added_arg].add(func_i)

        # Mutate in place rather than rebinding.
        self.func_to_argset[func_i].clear()
        self.func_to_argset[func_i].update(new_args)
class Unevaluated:
    """Record of a pending call ``func(*args)``.

    ``as_unevaluated_basic`` builds the call with ``evaluate=False``, and
    ``free_symbols`` is the union of the ``free_symbols`` of ``args``.
    """

    def __init__(self, func, args):
        self.func = func
        self.args = args

    def __str__(self):
        rendered_args = ", ".join(map(str, self.args))
        return f"Uneval<{self.func}>({rendered_args})"

    __repr__ = __str__

    def as_unevaluated_basic(self):
        return self.func(*self.args, evaluate=False)

    @property
    def free_symbols(self):
        collected = set()
        for arg in self.args:
            collected.update(arg.free_symbols)
        return collected
def match_common_args(func_class, funcs, opt_subs):
    """
    Recognize and extract common subexpressions of function arguments within a
    set of function calls. For instance, for the following function calls::

        x + z + y
        sin(x + y)

    this will extract a common subexpression of `x + y`::

        w = x + y
        w + z
        sin(w)

    The function we work with is assumed to be associative and commutative.

    Parameters
    ==========

    func_class: class
        The function class (e.g. Add, Mul)
    funcs: list of functions
        A list of function calls.
    opt_subs: dict
        A dictionary of substitutions which this function may update.
        Every call in ``funcs`` whose argument set was rewritten is mapped to
        an ``Unevaluated(func_class, ...)`` over its new arguments. Nothing
        is returned.
    """

    # Sort to ensure that whole-function subexpressions come before the items
    # that use them.
    funcs = sorted(funcs, key=lambda f: len(f.args))
    arg_tracker = FuncArgTracker(funcs)

    # indices of functions whose argument sets have been rewritten
    changed = OrderedSet()

    for i in range(len(funcs)):
        # Only later functions (index > i) are candidates; earlier ones have
        # already been processed and removed from tracking below.
        common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
                arg_tracker.func_to_argset[i], min_func_i=i + 1)

        # Sort the candidates in order of match size.
        # This makes us try combining smaller matches first.
        common_arg_candidates = OrderedSet(sorted(
                common_arg_candidates_counts.keys(),
                key=lambda k: (common_arg_candidates_counts[k], k)))

        while common_arg_candidates:
            # take the candidate with the smallest match first
            j = common_arg_candidates.pop(last=False)

            com_args = arg_tracker.func_to_argset[i].intersection(
                    arg_tracker.func_to_argset[j])

            if len(com_args) <= 1:
                # This may happen if a set of common arguments was already
                # combined in a previous iteration.
                continue

            # For all sets, replace the common symbols by the function
            # over them, to allow recursive matches.

            diff_i = arg_tracker.func_to_argset[i].difference(com_args)
            if diff_i:
                # com_func needs to be unevaluated to allow for recursive matches.
                com_func = Unevaluated(
                        func_class, arg_tracker.get_args_in_value_order(com_args))
                com_func_number = arg_tracker.get_or_add_value_number(com_func)
                arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number]))
                changed.add(i)
            else:
                # Treat the whole expression as a CSE.
                #
                # The reason this needs to be done is somewhat subtle. Within
                # tree_cse(), to_eliminate only contains expressions that are
                # seen more than once. The problem is unevaluated expressions
                # do not compare equal to the evaluated equivalent. So
                # tree_cse() won't mark funcs[i] as a CSE if we use an
                # unevaluated version.
                com_func_number = arg_tracker.get_or_add_value_number(funcs[i])

            diff_j = arg_tracker.func_to_argset[j].difference(com_args)
            arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number]))
            changed.add(j)

            # Also rewrite every remaining candidate whose arguments contain
            # all of com_args.
            for k in arg_tracker.get_subset_candidates(
                    com_args, common_arg_candidates):
                diff_k = arg_tracker.func_to_argset[k].difference(com_args)
                arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number]))
                changed.add(k)

        if i in changed:
            opt_subs[funcs[i]] = Unevaluated(func_class,
                arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]))

        # funcs[i] is final; stop offering it as a candidate to later funcs
        arg_tracker.stop_arg_tracking(i)
def opt_cse(exprs, order='canonical'):
    """Find optimization opportunities in Adds, Muls, Pows and negative
    coefficient Muls.

    Parameters
    ==========

    exprs : list of SymPy expressions
        The expressions to optimize.
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. For large
        expressions where speed is a concern, use the setting order='none'.
        (Accepted for interface compatibility; not used by this function.)

    Returns
    =======

    opt_subs : dictionary of expression substitutions
        The expression substitutions which can be useful to optimize CSE.

    Examples
    ========

    >>> from sympy.simplify.cse_main import opt_cse
    >>> from sympy.abc import x
    >>> opt_subs = opt_cse([x**-2])
    >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0]
    >>> print((k, v.as_unevaluated_basic()))
    (x**(-2), 1/(x**2))
    """
    from sympy.matrices.expressions import MatAdd, MatMul
    opt_subs = dict()

    adds = OrderedSet()
    muls = OrderedSet()

    seen_subexp = set()

    def _find_opts(expr):
        # Walk ``expr`` once, recording rewrite opportunities in opt_subs and
        # collecting the Adds/Muls to be matched afterwards.

        if not isinstance(expr, (Basic, Unevaluated)):
            return

        if expr.is_Atom or expr.is_Order:
            return

        if iterable(expr):
            list(map(_find_opts, expr))
            return

        if expr in seen_subexp:
            return
        seen_subexp.add(expr)

        list(map(_find_opts, expr.args))

        if expr.could_extract_minus_sign():
            # Represent expr as -1 * (-expr) so -expr can be shared.
            neg_expr = -expr
            if not neg_expr.is_Atom:
                opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr))
                seen_subexp.add(neg_expr)
                expr = neg_expr

        if isinstance(expr, (Mul, MatMul)):
            muls.add(expr)

        elif isinstance(expr, (Add, MatAdd)):
            adds.add(expr)

        elif isinstance(expr, Pow):
            # Rewrite b**(-e) as (b**e)**(-1) so b**e can be shared. Only
            # scalar powers qualify: the rewrite builds a scalar ``Pow``,
            # which would be the wrong type for a matrix power.
            base, exp = expr.base, expr.exp
            if exp.could_extract_minus_sign():
                opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1))

    for e in exprs:
        if isinstance(e, (Basic, Unevaluated)):
            _find_opts(e)

    # split muls into commutative
    commutative_muls = OrderedSet()
    for m in muls:
        c, nc = m.args_cnc(cset=False)
        if c:
            c_mul = m.func(*c)
            if nc:
                # keep the commutative part as one factor in front of the
                # non-commutative product
                if c_mul == 1:
                    new_obj = m.func(*nc)
                else:
                    new_obj = m.func(c_mul, m.func(*nc), evaluate=False)
                opt_subs[m] = new_obj
            if len(c) > 1:
                commutative_muls.add(c_mul)

    match_common_args(Add, adds, opt_subs)
    match_common_args(Mul, commutative_muls, opt_subs)

    return opt_subs
def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
    """Perform raw CSE on expression tree, taking opt_subs into account.

    Parameters
    ==========

    exprs : list of SymPy expressions
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. Symbols whose name is already used by a symbol (or matrix
        symbol) in ``exprs`` are skipped.
    opt_subs : dictionary of expression substitutions
        The expressions to be substituted before any CSE action is performed.
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. For large
        expressions where speed is a concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        The extracted subexpressions, in the order they were created.
    reduced_exprs : list
        ``exprs`` rewritten in terms of the replacement symbols; entries that
        are not ``Basic`` are passed through unchanged.

    Raises
    ======

    ValueError
        If ``symbols`` runs out before every subexpression is labelled.
    """
    from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
    from sympy.matrices.expressions.matexpr import MatrixElement
    from sympy.polys.rootoftools import RootOf

    if opt_subs is None:
        opt_subs = dict()

    ## Find repeated sub-expressions

    to_eliminate = set()

    seen_subexp = set()
    # Names of the symbols appearing in the input. Exclusion is by *name*: a
    # generated label must not print like an input symbol, even one that
    # compares unequal to it (e.g. same name, different assumptions).
    excluded_symbols = set()

    def _find_repeated(expr):
        if not isinstance(expr, (Basic, Unevaluated)):
            return

        if isinstance(expr, RootOf):
            return

        if isinstance(expr, Basic) and (
                expr.is_Atom or
                expr.is_Order or
                isinstance(expr, (MatrixSymbol, MatrixElement))):
            if expr.is_Symbol or isinstance(expr, MatrixSymbol):
                excluded_symbols.add(expr.name)
            return

        if iterable(expr):
            args = expr

        else:
            if expr in seen_subexp:
                # Second sighting: mark for elimination unless it involves an
                # ignored symbol, in which case keep descending into it.
                for ign in ignore:
                    if ign in expr.free_symbols:
                        break
                else:
                    to_eliminate.add(expr)
                    return

            seen_subexp.add(expr)

            if expr in opt_subs:
                expr = opt_subs[expr]

            args = expr.args

        list(map(_find_repeated, args))

    for e in exprs:
        if isinstance(e, Basic):
            _find_repeated(e)

    ## Rebuild tree

    # Remove symbols from the generator that conflict with names in the expressions.
    symbols = (sym for sym in symbols if sym.name not in excluded_symbols)

    replacements = []

    subs = dict()

    def _rebuild(expr):
        if not isinstance(expr, (Basic, Unevaluated)):
            return expr

        if not expr.args:
            return expr

        if iterable(expr):
            new_args = [_rebuild(arg) for arg in expr]
            return expr.func(*new_args)

        if expr in subs:
            return subs[expr]

        orig_expr = expr
        if expr in opt_subs:
            expr = opt_subs[expr]

        # If enabled, parse Muls and Adds arguments by order to ensure
        # replacement order independent from hashes
        if order != 'none':
            if isinstance(expr, (Mul, MatMul)):
                c, nc = expr.args_cnc()
                if c == [1]:
                    args = nc
                else:
                    args = list(ordered(c)) + nc
            elif isinstance(expr, (Add, MatAdd)):
                args = list(ordered(expr.args))
            else:
                args = expr.args
        else:
            args = expr.args

        new_args = list(map(_rebuild, args))
        if isinstance(expr, Unevaluated) or new_args != args:
            new_expr = expr.func(*new_args)
        else:
            new_expr = expr

        if orig_expr in to_eliminate:
            try:
                sym = next(symbols)
            except StopIteration:
                raise ValueError("Symbols iterator ran out of symbols.") from None

            if isinstance(orig_expr, MatrixExpr):
                sym = MatrixSymbol(sym.name, orig_expr.rows,
                    orig_expr.cols)

            subs[orig_expr] = sym
            replacements.append((sym, new_expr))
            return sym

        else:
            return new_expr

    reduced_exprs = []
    for e in exprs:
        if isinstance(e, Basic):
            reduced_e = _rebuild(e)
        else:
            reduced_e = e
        reduced_exprs.append(reduced_e)
    return replacements, reduced_exprs
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical', ignore=(), list=True):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of SymPy expressions, or a single SymPy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator; a ValueError is raised if it runs out. Symbols
        that already appear in ``exprs`` are skipped.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (r[::-1], e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.
    list : bool, (default True)
        Returns expression in list or else with same type as input (when False).

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of SymPy expressions
        The reduced expressions with all of the replacements above.

    If ``postprocess`` is given, its result is returned instead.

    Examples
    ========

    >>> from sympy import cse, SparseMatrix
    >>> from sympy.abc import x, y, z, w
    >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
    ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])


    List of expressions with recursive substitutions:

    >>> m = SparseMatrix([x + y, x + y + z])
    >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
    ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
    [x0],
    [x1]])])

    Note: the type and mutability of input matrices is retained.

    >>> isinstance(_[1][-1], SparseMatrix)
    True

    The user may disallow substitutions containing certain symbols:

    >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
    ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])

    The default return value for the reduced expression(s) is a list, even if there is only
    one expression. The `list` flag preserves the type of the input in the output:

    >>> cse(x)
    ([], [x])
    >>> cse(x, list=False)
    ([], x)
    """
    # NOTE: the ``list`` parameter shadows the builtin inside this function,
    # so the builtin ``list()`` must not be called here.
    from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
                                SparseMatrix, ImmutableSparseMatrix)

    if not list:
        return _cse_homogeneous(exprs,
            symbols=symbols, optimizations=optimizations,
            postprocess=postprocess, order=order, ignore=ignore)

    if isinstance(exprs, (int, float)):
        exprs = sympify(exprs)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # The input is traversed twice (flattening matrices here, restoring them
    # at the end), so materialize it once: a one-shot iterator such as a
    # generator would otherwise be exhausted before the restore step.
    inputs = [*exprs]

    # Flatten matrices into Tuples so their entries take part in CSE.
    flattened = []
    for e in inputs:
        if isinstance(e, (Matrix, ImmutableMatrix)):
            flattened.append(Tuple(*e.flat()))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            flattened.append(Tuple(*e.todok().items()))
        else:
            flattened.append(e)

    if optimizations is None:
        optimizations = []
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in flattened]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back
    for i, e in enumerate(inputs):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def _cse_homogeneous(exprs, **kwargs):
    """
    Same as ``cse`` but the ``reduced_exprs`` are returned
    with the same type as ``exprs`` or a sympified version of the same.

    Parameters
    ==========

    exprs : an Expr, iterable of Expr or dictionary with Expr values
        the expressions in which repeated subexpressions will be identified
    kwargs : additional arguments for the ``cse`` function

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : same kind of object as ``exprs``
        The reduced expression(s): a string input gives back a string, a
        list/tuple/set gives back the same container type, a dict gives back
        a dict with the same keys, and a single expression is returned bare.

    Examples
    ========

    >>> from sympy.simplify.cse_main import cse
    >>> from sympy import cos, Tuple, Matrix
    >>> from sympy.abc import x
    >>> output = lambda x: type(cse(x, list=False)[1])
    >>> output(1)
    <class 'sympy.core.numbers.One'>
    >>> output('cos(x)')
    <class 'str'>
    >>> output(cos(x))
    cos
    >>> output(Tuple(1, x))
    <class 'sympy.core.containers.Tuple'>
    >>> output(Matrix([[1,0], [0,1]]))
    <class 'sympy.matrices.dense.MutableDenseMatrix'>
    >>> output([1, x])
    <class 'list'>
    >>> output((1, x))
    <class 'tuple'>
    >>> output({1, x})
    <class 'set'>
    """
    if isinstance(exprs, str):
        # Work on the parsed expression, then render the result as a string.
        replacements, reduced = _cse_homogeneous(sympify(exprs), **kwargs)
        return replacements, repr(reduced)

    if isinstance(exprs, dict):
        # Freeze the key order once so keys and reduced values line up.
        ordered_keys = [*exprs]
        replacements, reduced_values = cse(
            [exprs[key] for key in ordered_keys], **kwargs)
        return replacements, dict(zip(ordered_keys, reduced_values))

    if isinstance(exprs, (list, tuple, set)):
        replacements, reduced = cse(exprs, **kwargs)
        return replacements, type(exprs)(reduced)

    try:
        replacements, (reduced,) = cse(exprs, **kwargs)
    except TypeError:  # For example 'mpf' objects
        return [], exprs
    return replacements, reduced
|