cse_main.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920
  1. """ Tools for doing common subexpression elimination.
  2. """
  3. from sympy.core import Basic, Mul, Add, Pow, sympify
  4. from sympy.core.containers import Tuple, OrderedSet
  5. from sympy.core.exprtools import factor_terms
  6. from sympy.core.singleton import S
  7. from sympy.core.sorting import ordered
  8. from sympy.core.symbol import symbols, Symbol
  9. from sympy.utilities.iterables import numbered_symbols, sift, \
  10. topological_sort, iterable
  11. from . import cse_opts
# (preprocessor, postprocessor) pairs which are commonly useful. They should
# each take a SymPy expression and return a possibly transformed expression.
# When used in the function ``cse()``, the target expressions will be transformed
# by each of the preprocessor functions in order. After the common
# subexpressions are eliminated, each resulting expression will have the
# postprocessor functions transform them in *reverse* order in order to undo the
# transformation if necessary. This allows the algorithm to operate on
# a representation of the expressions that allows for more optimization
# opportunities.
# ``None`` can be used to specify no transformation for either the preprocessor or
# postprocessor.
# This list is what ``cse()`` substitutes when it is called with
# ``optimizations='basic'``.
basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
                       (factor_terms, None)]
  25. # sometimes we want the output in a different format; non-trivial
  26. # transformations can be put here for users
  27. # ===============================================================
  28. def reps_toposort(r):
  29. """Sort replacements ``r`` so (k1, v1) appears before (k2, v2)
  30. if k2 is in v1's free symbols. This orders items in the
  31. way that cse returns its results (hence, in order to use the
  32. replacements in a substitution option it would make sense
  33. to reverse the order).
  34. Examples
  35. ========
  36. >>> from sympy.simplify.cse_main import reps_toposort
  37. >>> from sympy.abc import x, y
  38. >>> from sympy import Eq
  39. >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]):
  40. ... print(Eq(l, r))
  41. ...
  42. Eq(y, 2)
  43. Eq(x, y + 1)
  44. """
  45. r = sympify(r)
  46. E = []
  47. for c1, (k1, v1) in enumerate(r):
  48. for c2, (k2, v2) in enumerate(r):
  49. if k1 in v2.free_symbols:
  50. E.append((c1, c2))
  51. return [r[i] for i in topological_sort((range(len(r)), E))]
  52. def cse_separate(r, e):
  53. """Move expressions that are in the form (symbol, expr) out of the
  54. expressions and sort them into the replacements using the reps_toposort.
  55. Examples
  56. ========
  57. >>> from sympy.simplify.cse_main import cse_separate
  58. >>> from sympy.abc import x, y, z
  59. >>> from sympy import cos, exp, cse, Eq, symbols
  60. >>> x0, x1 = symbols('x:2')
  61. >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
  62. >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
  63. ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
  64. ... [x1 + exp(x1/x0) + cos(x0), z - 2]],
  65. ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
  66. ... [x0 + exp(x0/x1) + cos(x1), z - 2]]]
  67. ...
  68. True
  69. """
  70. d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol)
  71. r = r + [w.args for w in d[True]]
  72. e = d[False]
  73. return [reps_toposort(r), e]
  74. def cse_release_variables(r, e):
  75. """
  76. Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is
  77. either an expression or None. The value of None is used when a
  78. symbol is no longer needed for subsequent expressions.
  79. Use of such output can reduce the memory footprint of lambdified
  80. expressions that contain large, repeated subexpressions.
  81. Examples
  82. ========
  83. >>> from sympy import cse
  84. >>> from sympy.simplify.cse_main import cse_release_variables
  85. >>> from sympy.abc import x, y
  86. >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)]
  87. >>> defs, rvs = cse_release_variables(*cse(eqs))
  88. >>> for i in defs:
  89. ... print(i)
  90. ...
  91. (x0, x + y)
  92. (x1, (x0 - 1)**2)
  93. (x2, 2*x + 1)
  94. (_3, x0/x2 + x1)
  95. (_4, x2**x0)
  96. (x2, None)
  97. (_0, x1)
  98. (x1, None)
  99. (_2, x0)
  100. (x0, None)
  101. (_1, x)
  102. >>> print(rvs)
  103. (_0, _1, _2, _3, _4)
  104. """
  105. if not r:
  106. return r, e
  107. s, p = zip(*r)
  108. esyms = symbols('_:%d' % len(e))
  109. syms = list(esyms)
  110. s = list(s)
  111. in_use = set(s)
  112. p = list(p)
  113. # sort e so those with most sub-expressions appear first
  114. e = [(e[i], syms[i]) for i in range(len(e))]
  115. e, syms = zip(*sorted(e,
  116. key=lambda x: -sum([p[s.index(i)].count_ops()
  117. for i in x[0].free_symbols & in_use])))
  118. syms = list(syms)
  119. p += e
  120. rv = []
  121. i = len(p) - 1
  122. while i >= 0:
  123. _p = p.pop()
  124. c = in_use & _p.free_symbols
  125. if c: # sorting for canonical results
  126. rv.extend([(s, None) for s in sorted(c, key=str)])
  127. if i >= len(r):
  128. rv.append((syms.pop(), _p))
  129. else:
  130. rv.append((s[i], _p))
  131. in_use -= c
  132. i -= 1
  133. rv.reverse()
  134. return rv, esyms
  135. # ====end of cse postprocess idioms===========================
  136. def preprocess_for_cse(expr, optimizations):
  137. """ Preprocess an expression to optimize for common subexpression
  138. elimination.
  139. Parameters
  140. ==========
  141. expr : SymPy expression
  142. The target expression to optimize.
  143. optimizations : list of (callable, callable) pairs
  144. The (preprocessor, postprocessor) pairs.
  145. Returns
  146. =======
  147. expr : SymPy expression
  148. The transformed expression.
  149. """
  150. for pre, post in optimizations:
  151. if pre is not None:
  152. expr = pre(expr)
  153. return expr
  154. def postprocess_for_cse(expr, optimizations):
  155. """Postprocess an expression after common subexpression elimination to
  156. return the expression to canonical SymPy form.
  157. Parameters
  158. ==========
  159. expr : SymPy expression
  160. The target expression to transform.
  161. optimizations : list of (callable, callable) pairs, optional
  162. The (preprocessor, postprocessor) pairs. The postprocessors will be
  163. applied in reversed order to undo the effects of the preprocessors
  164. correctly.
  165. Returns
  166. =======
  167. expr : SymPy expression
  168. The transformed expression.
  169. """
  170. for pre, post in reversed(optimizations):
  171. if post is not None:
  172. expr = post(expr)
  173. return expr
  174. class FuncArgTracker:
  175. """
  176. A class which manages a mapping from functions to arguments and an inverse
  177. mapping from arguments to functions.
  178. """
  179. def __init__(self, funcs):
  180. # To minimize the number of symbolic comparisons, all function arguments
  181. # get assigned a value number.
  182. self.value_numbers = {}
  183. self.value_number_to_value = []
  184. # Both of these maps use integer indices for arguments / functions.
  185. self.arg_to_funcset = []
  186. self.func_to_argset = []
  187. for func_i, func in enumerate(funcs):
  188. func_argset = OrderedSet()
  189. for func_arg in func.args:
  190. arg_number = self.get_or_add_value_number(func_arg)
  191. func_argset.add(arg_number)
  192. self.arg_to_funcset[arg_number].add(func_i)
  193. self.func_to_argset.append(func_argset)
  194. def get_args_in_value_order(self, argset):
  195. """
  196. Return the list of arguments in sorted order according to their value
  197. numbers.
  198. """
  199. return [self.value_number_to_value[argn] for argn in sorted(argset)]
  200. def get_or_add_value_number(self, value):
  201. """
  202. Return the value number for the given argument.
  203. """
  204. nvalues = len(self.value_numbers)
  205. value_number = self.value_numbers.setdefault(value, nvalues)
  206. if value_number == nvalues:
  207. self.value_number_to_value.append(value)
  208. self.arg_to_funcset.append(OrderedSet())
  209. return value_number
  210. def stop_arg_tracking(self, func_i):
  211. """
  212. Remove the function func_i from the argument to function mapping.
  213. """
  214. for arg in self.func_to_argset[func_i]:
  215. self.arg_to_funcset[arg].remove(func_i)
  216. def get_common_arg_candidates(self, argset, min_func_i=0):
  217. """Return a dict whose keys are function numbers. The entries of the dict are
  218. the number of arguments said function has in common with
  219. ``argset``. Entries have at least 2 items in common. All keys have
  220. value at least ``min_func_i``.
  221. """
  222. from collections import defaultdict
  223. count_map = defaultdict(lambda: 0)
  224. if not argset:
  225. return count_map
  226. funcsets = [self.arg_to_funcset[arg] for arg in argset]
  227. # As an optimization below, we handle the largest funcset separately from
  228. # the others.
  229. largest_funcset = max(funcsets, key=len)
  230. for funcset in funcsets:
  231. if largest_funcset is funcset:
  232. continue
  233. for func_i in funcset:
  234. if func_i >= min_func_i:
  235. count_map[func_i] += 1
  236. # We pick the smaller of the two containers (count_map, largest_funcset)
  237. # to iterate over to reduce the number of iterations needed.
  238. (smaller_funcs_container,
  239. larger_funcs_container) = sorted(
  240. [largest_funcset, count_map],
  241. key=len)
  242. for func_i in smaller_funcs_container:
  243. # Not already in count_map? It can't possibly be in the output, so
  244. # skip it.
  245. if count_map[func_i] < 1:
  246. continue
  247. if func_i in larger_funcs_container:
  248. count_map[func_i] += 1
  249. return {k: v for k, v in count_map.items() if v >= 2}
  250. def get_subset_candidates(self, argset, restrict_to_funcset=None):
  251. """
  252. Return a set of functions each of which whose argument list contains
  253. ``argset``, optionally filtered only to contain functions in
  254. ``restrict_to_funcset``.
  255. """
  256. iarg = iter(argset)
  257. indices = OrderedSet(
  258. fi for fi in self.arg_to_funcset[next(iarg)])
  259. if restrict_to_funcset is not None:
  260. indices &= restrict_to_funcset
  261. for arg in iarg:
  262. indices &= self.arg_to_funcset[arg]
  263. return indices
  264. def update_func_argset(self, func_i, new_argset):
  265. """
  266. Update a function with a new set of arguments.
  267. """
  268. new_args = OrderedSet(new_argset)
  269. old_args = self.func_to_argset[func_i]
  270. for deleted_arg in old_args - new_args:
  271. self.arg_to_funcset[deleted_arg].remove(func_i)
  272. for added_arg in new_args - old_args:
  273. self.arg_to_funcset[added_arg].add(func_i)
  274. self.func_to_argset[func_i].clear()
  275. self.func_to_argset[func_i].update(new_args)
  276. class Unevaluated:
  277. def __init__(self, func, args):
  278. self.func = func
  279. self.args = args
  280. def __str__(self):
  281. return "Uneval<{}>({})".format(
  282. self.func, ", ".join(str(a) for a in self.args))
  283. def as_unevaluated_basic(self):
  284. return self.func(*self.args, evaluate=False)
  285. @property
  286. def free_symbols(self):
  287. return set().union(*[a.free_symbols for a in self.args])
  288. __repr__ = __str__
def match_common_args(func_class, funcs, opt_subs):
    """
    Recognize and extract common subexpressions of function arguments within a
    set of function calls. For instance, for the following function calls::

        x + z + y
        sin(x + y)

    this will extract a common subexpression of `x + y`::

        w = x + y
        w + z
        sin(w)

    The function we work with is assumed to be associative and commutative.

    Nothing is returned; results are recorded by adding entries to
    ``opt_subs`` that map an original function to an ``Unevaluated``
    ``func_class`` call over its regrouped arguments.

    Parameters
    ==========

    func_class: class
        The function class (e.g. Add, Mul)
    funcs: list of functions
        A list of function calls.
    opt_subs: dict
        A dictionary of substitutions which this function may update.
    """

    # Sort to ensure that whole-function subexpressions come before the items
    # that use them.
    funcs = sorted(funcs, key=lambda f: len(f.args))
    arg_tracker = FuncArgTracker(funcs)

    # Indices of functions whose argument set was rewritten at least once.
    changed = OrderedSet()

    for i in range(len(funcs)):
        # Only functions after i are considered (min_func_i=i + 1); tracking
        # for function i itself is stopped at the end of this iteration.
        common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
                arg_tracker.func_to_argset[i], min_func_i=i + 1)

        # Sort the candidates in order of match size.
        # This makes us try combining smaller matches first.
        common_arg_candidates = OrderedSet(sorted(
                common_arg_candidates_counts.keys(),
                key=lambda k: (common_arg_candidates_counts[k], k)))

        while common_arg_candidates:
            # Take the candidate at the front of the ordered set.
            j = common_arg_candidates.pop(last=False)

            # Recomputed here because earlier passes of this loop may have
            # rewritten the argument sets of i and/or j.
            com_args = arg_tracker.func_to_argset[i].intersection(
                    arg_tracker.func_to_argset[j])

            if len(com_args) <= 1:
                # This may happen if a set of common arguments was already
                # combined in a previous iteration.
                continue

            # For all sets, replace the common symbols by the function
            # over them, to allow recursive matches.

            diff_i = arg_tracker.func_to_argset[i].difference(com_args)
            if diff_i:
                # com_func needs to be unevaluated to allow for recursive matches.
                com_func = Unevaluated(
                        func_class, arg_tracker.get_args_in_value_order(com_args))
                com_func_number = arg_tracker.get_or_add_value_number(com_func)
                arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number]))
                changed.add(i)
            else:
                # Treat the whole expression as a CSE.
                #
                # The reason this needs to be done is somewhat subtle. Within
                # tree_cse(), to_eliminate only contains expressions that are
                # seen more than once. The problem is unevaluated expressions
                # do not compare equal to the evaluated equivalent. So
                # tree_cse() won't mark funcs[i] as a CSE if we use an
                # unevaluated version.
                com_func_number = arg_tracker.get_or_add_value_number(funcs[i])

            # In j, swap the common arguments for the single combined value.
            diff_j = arg_tracker.func_to_argset[j].difference(com_args)
            arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number]))
            changed.add(j)

            # Do the same for every remaining candidate whose argument set
            # contains all of com_args.
            for k in arg_tracker.get_subset_candidates(
                    com_args, common_arg_candidates):
                diff_k = arg_tracker.func_to_argset[k].difference(com_args)
                arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number]))
                changed.add(k)

        if i in changed:
            # Record the regrouped form of funcs[i] as an unevaluated call.
            opt_subs[funcs[i]] = Unevaluated(func_class,
                arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]))

        arg_tracker.stop_arg_tracking(i)
  362. def opt_cse(exprs, order='canonical'):
  363. """Find optimization opportunities in Adds, Muls, Pows and negative
  364. coefficient Muls.
  365. Parameters
  366. ==========
  367. exprs : list of SymPy expressions
  368. The expressions to optimize.
  369. order : string, 'none' or 'canonical'
  370. The order by which Mul and Add arguments are processed. For large
  371. expressions where speed is a concern, use the setting order='none'.
  372. Returns
  373. =======
  374. opt_subs : dictionary of expression substitutions
  375. The expression substitutions which can be useful to optimize CSE.
  376. Examples
  377. ========
  378. >>> from sympy.simplify.cse_main import opt_cse
  379. >>> from sympy.abc import x
  380. >>> opt_subs = opt_cse([x**-2])
  381. >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0]
  382. >>> print((k, v.as_unevaluated_basic()))
  383. (x**(-2), 1/(x**2))
  384. """
  385. from sympy.matrices.expressions import MatAdd, MatMul, MatPow
  386. opt_subs = dict()
  387. adds = OrderedSet()
  388. muls = OrderedSet()
  389. seen_subexp = set()
  390. def _find_opts(expr):
  391. if not isinstance(expr, (Basic, Unevaluated)):
  392. return
  393. if expr.is_Atom or expr.is_Order:
  394. return
  395. if iterable(expr):
  396. list(map(_find_opts, expr))
  397. return
  398. if expr in seen_subexp:
  399. return expr
  400. seen_subexp.add(expr)
  401. list(map(_find_opts, expr.args))
  402. if expr.could_extract_minus_sign():
  403. neg_expr = -expr
  404. if not neg_expr.is_Atom:
  405. opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr))
  406. seen_subexp.add(neg_expr)
  407. expr = neg_expr
  408. if isinstance(expr, (Mul, MatMul)):
  409. muls.add(expr)
  410. elif isinstance(expr, (Add, MatAdd)):
  411. adds.add(expr)
  412. elif isinstance(expr, (Pow, MatPow)):
  413. base, exp = expr.base, expr.exp
  414. if exp.could_extract_minus_sign():
  415. opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1))
  416. for e in exprs:
  417. if isinstance(e, (Basic, Unevaluated)):
  418. _find_opts(e)
  419. # split muls into commutative
  420. commutative_muls = OrderedSet()
  421. for m in muls:
  422. c, nc = m.args_cnc(cset=False)
  423. if c:
  424. c_mul = m.func(*c)
  425. if nc:
  426. if c_mul == 1:
  427. new_obj = m.func(*nc)
  428. else:
  429. new_obj = m.func(c_mul, m.func(*nc), evaluate=False)
  430. opt_subs[m] = new_obj
  431. if len(c) > 1:
  432. commutative_muls.add(c_mul)
  433. match_common_args(Add, adds, opt_subs)
  434. match_common_args(Mul, commutative_muls, opt_subs)
  435. return opt_subs
def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
    """Perform raw CSE on expression tree, taking opt_subs into account.

    The work is done in two passes over ``exprs``: the first collects the
    subexpressions that are seen more than once, the second rebuilds every
    expression, replacing each collected subexpression by a new symbol.

    Parameters
    ==========

    exprs : list of SymPy expressions
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out.
    opt_subs : dictionary of expression substitutions
        The expressions to be substituted before any CSE action is performed.
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. For large
        expressions where speed is a concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (symbol, expression) pairs
        One pair per eliminated subexpression, in the order the symbols
        were assigned.
    reduced_exprs : list
        One entry per item of ``exprs``; items that are not ``Basic``
        instances are passed through unchanged.

    Raises
    ======

    ValueError
        If ``symbols`` is exhausted while a replacement symbol is needed.
    """
    from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
    from sympy.matrices.expressions.matexpr import MatrixElement
    from sympy.polys.rootoftools import RootOf

    if opt_subs is None:
        opt_subs = dict()

    ## Find repeated sub-expressions

    # Subexpressions seen at least twice (and not touching ``ignore``).
    to_eliminate = set()

    seen_subexp = set()
    # Symbols already present in the input; they must not be reused as
    # replacement labels.
    excluded_symbols = set()

    def _find_repeated(expr):
        if not isinstance(expr, (Basic, Unevaluated)):
            return

        if isinstance(expr, RootOf):
            return

        if isinstance(expr, Basic) and (
                expr.is_Atom or
                expr.is_Order or
                isinstance(expr, (MatrixSymbol, MatrixElement))):
            if expr.is_Symbol:
                excluded_symbols.add(expr)
            return

        if iterable(expr):
            # Containers are not candidates themselves; only descend.
            args = expr

        else:
            if expr in seen_subexp:
                # for/else: the else branch runs only when no ignored symbol
                # occurs in expr, in which case expr is marked and we stop.
                # Otherwise fall through and keep descending.
                for ign in ignore:
                    if ign in expr.free_symbols:
                        break
                else:
                    to_eliminate.add(expr)
                    return

            seen_subexp.add(expr)

            # Descend into the optimized form when one was registered.
            if expr in opt_subs:
                expr = opt_subs[expr]

            args = expr.args

        list(map(_find_repeated, args))

    for e in exprs:
        if isinstance(e, Basic):
            _find_repeated(e)

    ## Rebuild tree

    # Remove symbols from the generator that conflict with names in the expressions.
    symbols = (symbol for symbol in symbols if symbol not in excluded_symbols)

    replacements = []

    # Maps an eliminated (original) subexpression to its replacement symbol.
    subs = dict()

    def _rebuild(expr):
        if not isinstance(expr, (Basic, Unevaluated)):
            return expr

        if not expr.args:
            return expr

        if iterable(expr):
            new_args = [_rebuild(arg) for arg in expr]
            return expr.func(*new_args)

        if expr in subs:
            return subs[expr]

        # Keep the original around: to_eliminate and subs are keyed by it,
        # not by the opt_subs form.
        orig_expr = expr
        if expr in opt_subs:
            expr = opt_subs[expr]

        # If enabled, parse Muls and Adds arguments by order to ensure
        # replacement order independent from hashes
        if order != 'none':
            if isinstance(expr, (Mul, MatMul)):
                c, nc = expr.args_cnc()
                if c == [1]:
                    args = nc
                else:
                    args = list(ordered(c)) + nc
            elif isinstance(expr, (Add, MatAdd)):
                args = list(ordered(expr.args))
            else:
                args = expr.args
        else:
            args = expr.args

        new_args = list(map(_rebuild, args))
        # NOTE(review): new_args is always a list; if ``expr.args`` is a
        # tuple, ``new_args != args`` is True whenever args was taken
        # straight from expr.args, so the expression is rebuilt even when
        # nothing changed -- confirm whether that is intended.
        if isinstance(expr, Unevaluated) or new_args != args:
            new_expr = expr.func(*new_args)
        else:
            new_expr = expr

        if orig_expr in to_eliminate:
            try:
                sym = next(symbols)
            except StopIteration:
                raise ValueError("Symbols iterator ran out of symbols.")

            if isinstance(orig_expr, MatrixExpr):
                # Matrix-valued subexpressions get a MatrixSymbol with the
                # same name and the dimensions of the original.
                sym = MatrixSymbol(sym.name, orig_expr.rows,
                    orig_expr.cols)

            subs[orig_expr] = sym
            replacements.append((sym, new_expr))
            return sym

        else:
            return new_expr

    reduced_exprs = []
    for e in exprs:
        if isinstance(e, Basic):
            reduced_e = _rebuild(e)
        else:
            reduced_e = e
        reduced_exprs.append(reduced_e)
    return replacements, reduced_exprs
  551. def cse(exprs, symbols=None, optimizations=None, postprocess=None,
  552. order='canonical', ignore=(), list=True):
  553. """ Perform common subexpression elimination on an expression.
  554. Parameters
  555. ==========
  556. exprs : list of SymPy expressions, or a single SymPy expression
  557. The expressions to reduce.
  558. symbols : infinite iterator yielding unique Symbols
  559. The symbols used to label the common subexpressions which are pulled
  560. out. The ``numbered_symbols`` generator is useful. The default is a
  561. stream of symbols of the form "x0", "x1", etc. This must be an
  562. infinite iterator.
  563. optimizations : list of (callable, callable) pairs
  564. The (preprocessor, postprocessor) pairs of external optimization
  565. functions. Optionally 'basic' can be passed for a set of predefined
  566. basic optimizations. Such 'basic' optimizations were used by default
  567. in old implementation, however they can be really slow on larger
  568. expressions. Now, no pre or post optimizations are made by default.
  569. postprocess : a function which accepts the two return values of cse and
  570. returns the desired form of output from cse, e.g. if you want the
  571. replacements reversed the function might be the following lambda:
  572. lambda r, e: return reversed(r), e
  573. order : string, 'none' or 'canonical'
  574. The order by which Mul and Add arguments are processed. If set to
  575. 'canonical', arguments will be canonically ordered. If set to 'none',
  576. ordering will be faster but dependent on expressions hashes, thus
  577. machine dependent and variable. For large expressions where speed is a
  578. concern, use the setting order='none'.
  579. ignore : iterable of Symbols
  580. Substitutions containing any Symbol from ``ignore`` will be ignored.
  581. list : bool, (default True)
  582. Returns expression in list or else with same type as input (when False).
  583. Returns
  584. =======
  585. replacements : list of (Symbol, expression) pairs
  586. All of the common subexpressions that were replaced. Subexpressions
  587. earlier in this list might show up in subexpressions later in this
  588. list.
  589. reduced_exprs : list of SymPy expressions
  590. The reduced expressions with all of the replacements above.
  591. Examples
  592. ========
  593. >>> from sympy import cse, SparseMatrix
  594. >>> from sympy.abc import x, y, z, w
  595. >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
  596. ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])
  597. List of expressions with recursive substitutions:
  598. >>> m = SparseMatrix([x + y, x + y + z])
  599. >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
  600. ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
  601. [x0],
  602. [x1]])])
  603. Note: the type and mutability of input matrices is retained.
  604. >>> isinstance(_[1][-1], SparseMatrix)
  605. True
  606. The user may disallow substitutions containing certain symbols:
  607. >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
  608. ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])
  609. The default return value for the reduced expression(s) is a list, even if there is only
  610. one expression. The `list` flag preserves the type of the input in the output:
  611. >>> cse(x)
  612. ([], [x])
  613. >>> cse(x, list=False)
  614. ([], x)
  615. """
  616. from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
  617. SparseMatrix, ImmutableSparseMatrix)
  618. if not list:
  619. return _cse_homogeneous(exprs,
  620. symbols=symbols, optimizations=optimizations,
  621. postprocess=postprocess, order=order, ignore=ignore)
  622. if isinstance(exprs, (int, float)):
  623. exprs = sympify(exprs)
  624. # Handle the case if just one expression was passed.
  625. if isinstance(exprs, (Basic, MatrixBase)):
  626. exprs = [exprs]
  627. copy = exprs
  628. temp = []
  629. for e in exprs:
  630. if isinstance(e, (Matrix, ImmutableMatrix)):
  631. temp.append(Tuple(*e.flat()))
  632. elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
  633. temp.append(Tuple(*e.todok().items()))
  634. else:
  635. temp.append(e)
  636. exprs = temp
  637. del temp
  638. if optimizations is None:
  639. optimizations = []
  640. elif optimizations == 'basic':
  641. optimizations = basic_optimizations
  642. # Preprocess the expressions to give us better optimization opportunities.
  643. reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
  644. if symbols is None:
  645. symbols = numbered_symbols(cls=Symbol)
  646. else:
  647. # In case we get passed an iterable with an __iter__ method instead of
  648. # an actual iterator.
  649. symbols = iter(symbols)
  650. # Find other optimization opportunities.
  651. opt_subs = opt_cse(reduced_exprs, order)
  652. # Main CSE algorithm.
  653. replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
  654. order, ignore)
  655. # Postprocess the expressions to return the expressions to canonical form.
  656. exprs = copy
  657. for i, (sym, subtree) in enumerate(replacements):
  658. subtree = postprocess_for_cse(subtree, optimizations)
  659. replacements[i] = (sym, subtree)
  660. reduced_exprs = [postprocess_for_cse(e, optimizations)
  661. for e in reduced_exprs]
  662. # Get the matrices back
  663. for i, e in enumerate(exprs):
  664. if isinstance(e, (Matrix, ImmutableMatrix)):
  665. reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
  666. if isinstance(e, ImmutableMatrix):
  667. reduced_exprs[i] = reduced_exprs[i].as_immutable()
  668. elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
  669. m = SparseMatrix(e.rows, e.cols, {})
  670. for k, v in reduced_exprs[i]:
  671. m[k] = v
  672. if isinstance(e, ImmutableSparseMatrix):
  673. m = m.as_immutable()
  674. reduced_exprs[i] = m
  675. if postprocess is None:
  676. return replacements, reduced_exprs
  677. return postprocess(replacements, reduced_exprs)
  678. def _cse_homogeneous(exprs, **kwargs):
  679. """
  680. Same as ``cse`` but the ``reduced_exprs`` are returned
  681. with the same type as ``exprs`` or a sympified version of the same.
  682. Parameters
  683. ==========
  684. exprs : an Expr, iterable of Expr or dictionary with Expr values
  685. the expressions in which repeated subexpressions will be identified
  686. kwargs : additional arguments for the ``cse`` function
  687. Returns
  688. =======
  689. replacements : list of (Symbol, expression) pairs
  690. All of the common subexpressions that were replaced. Subexpressions
  691. earlier in this list might show up in subexpressions later in this
  692. list.
  693. reduced_exprs : list of SymPy expressions
  694. The reduced expressions with all of the replacements above.
  695. Examples
  696. ========
  697. >>> from sympy.simplify.cse_main import cse
  698. >>> from sympy import cos, Tuple, Matrix
  699. >>> from sympy.abc import x
  700. >>> output = lambda x: type(cse(x, list=False)[1])
  701. >>> output(1)
  702. <class 'sympy.core.numbers.One'>
  703. >>> output('cos(x)')
  704. <class 'str'>
  705. >>> output(cos(x))
  706. cos
  707. >>> output(Tuple(1, x))
  708. <class 'sympy.core.containers.Tuple'>
  709. >>> output(Matrix([[1,0], [0,1]]))
  710. <class 'sympy.matrices.dense.MutableDenseMatrix'>
  711. >>> output([1, x])
  712. <class 'list'>
  713. >>> output((1, x))
  714. <class 'tuple'>
  715. >>> output({1, x})
  716. <class 'set'>
  717. """
  718. if isinstance(exprs, str):
  719. replacements, reduced_exprs = _cse_homogeneous(
  720. sympify(exprs), **kwargs)
  721. return replacements, repr(reduced_exprs)
  722. if isinstance(exprs, (list, tuple, set)):
  723. replacements, reduced_exprs = cse(exprs, **kwargs)
  724. return replacements, type(exprs)(reduced_exprs)
  725. if isinstance(exprs, dict):
  726. keys = list(exprs.keys()) # In order to guarantee the order of the elements.
  727. replacements, values = cse([exprs[k] for k in keys], **kwargs)
  728. reduced_exprs = dict(zip(keys, values))
  729. return replacements, reduced_exprs
  730. try:
  731. replacements, (reduced_exprs,) = cse(exprs, **kwargs)
  732. except TypeError: # For example 'mpf' objects
  733. return [], exprs
  734. else:
  735. return replacements, reduced_exprs