12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206 |
- from collections import defaultdict
- from sympy import SYMPY_DEBUG
- from sympy.core import sympify, S, Mul, Derivative, Pow
- from sympy.core.add import _unevaluated_Add, Add
- from sympy.core.assumptions import assumptions
- from sympy.core.exprtools import Factors, gcd_terms
- from sympy.core.function import _mexpand, expand_mul, expand_power_base
- from sympy.core.mul import _keep_coeff, _unevaluated_Mul, _mulsort
- from sympy.core.numbers import Rational, zoo, nan
- from sympy.core.parameters import global_parameters
- from sympy.core.sorting import ordered, default_sort_key
- from sympy.core.symbol import Dummy, Wild, symbols
- from sympy.functions import exp, sqrt, log
- from sympy.functions.elementary.complexes import Abs
- from sympy.polys import gcd
- from sympy.simplify.sqrtdenest import sqrtdenest
- from sympy.utilities.iterables import iterable, sift
def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True):
    """
    Collect additive terms of an expression.

    Explanation
    ===========

    This function collects additive terms of an expression with respect
    to a list of expressions up to powers with rational exponents. By the
    term symbol here are meant arbitrary expressions, which can contain
    powers, products, sums etc. In other words symbol is a pattern which
    will be searched for in the expression's terms.

    The input expression is not expanded by :func:`collect`, so user is
    expected to provide an expression in an appropriate form. This makes
    :func:`collect` more predictable as there is no magic happening behind the
    scenes. However, it is important to note, that powers of products are
    converted to products of powers using the :func:`~.expand_power_base`
    function.

    There are two possible types of output. First, if ``evaluate`` flag is
    set, this function will return an expression with collected terms or
    else it will return a dictionary with expressions up to rational powers
    as keys and collected coefficients as values.

    Parameters
    ==========

    expr : Expr
        The expression whose additive terms are collected (sympified first).
    syms : Expr or iterable of Expr
        The pattern(s) to collect on. Patterns are tried in the given order
        and each term is filed under the first pattern that matches it.
    func : callable, optional
        If given, it is applied to every collected coefficient.
    evaluate : bool, optional
        If true, return ``Add(*[key*coeff ...])``; otherwise return a dict
        mapping each collected key to its coefficient. ``None`` means
        ``global_parameters.evaluate``.
    exact : bool
        If true, a term only matches when its rational exponent and its
        derivative order are the same as the pattern's.
    distribute_order_term : bool
        If true and ``expr`` has an order term (``O(...)``) that does not
        contain any of ``syms``, that order term is removed before
        collecting and added to every collected coefficient.

    Raises
    ======

    AttributeError
        If a noncommutative pattern matches a term.
    NotImplementedError
        If a Derivative with respect to more than one distinct variable has
        to be parsed.

    Examples
    ========

    >>> from sympy import S, collect, expand, factor, Wild
    >>> from sympy.abc import a, b, c, x, y

    This function can collect symbolic coefficients in polynomials or
    rational expressions. It will manage to find all integer or rational
    powers of collection variable::

        >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x)
        c + x**2*(a + b) + x*(a - b)

    The same result can be achieved in dictionary form::

        >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False)
        >>> d[x**2]
        a + b
        >>> d[x]
        a - b
        >>> d[S.One]
        c

    You can also work with multivariate polynomials. However, remember that
    this function is greedy so it will care only about a single symbol at time,
    in specification order::

        >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y])
        x**2*(y + 1) + x*y + y*(a + 1)

    Also more complicated expressions can be used as patterns::

        >>> from sympy import sin, log
        >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x))
        (a + b)*sin(2*x)

        >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x))
        x*(a + b)*log(x)

    You can use wildcards in the pattern::

        >>> w = Wild('w1')
        >>> collect(a*x**y - b*x**y, w**y)
        x**y*(a - b)

    It is also possible to work with symbolic powers, although it has more
    complicated behavior, because in this case power's base and symbolic part
    of the exponent are treated as a single symbol::

        >>> collect(a*x**c + b*x**c, x)
        a*x**c + b*x**c
        >>> collect(a*x**c + b*x**c, x**c)
        x**c*(a + b)

    However if you incorporate rationals to the exponents, then you will get
    well known behavior::

        >>> collect(a*x**(2*c) + b*x**(2*c), x**c)
        x**(2*c)*(a + b)

    Note also that all previously stated facts about :func:`collect` function
    apply to the exponential function, so you can get::

        >>> from sympy import exp
        >>> collect(a*exp(2*x) + b*exp(2*x), exp(x))
        (a + b)*exp(2*x)

    If you are interested only in collecting specific powers of some symbols
    then set ``exact`` flag in arguments::

        >>> collect(a*x**7 + b*x**7, x, exact=True)
        a*x**7 + b*x**7
        >>> collect(a*x**7 + b*x**7, x**7, exact=True)
        x**7*(a + b)

    You can also apply this function to differential equations, where
    derivatives of arbitrary order can be collected. Note that if you
    collect with respect to a function or a derivative of a function, all
    derivatives of that function will also be collected. Use
    ``exact=True`` to prevent this from happening::

        >>> from sympy import Derivative as D, collect, Function
        >>> f = Function('f') (x)

        >>> collect(a*D(f,x) + b*D(f,x), D(f,x))
        (a + b)*Derivative(f(x), x)

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f)
        (a + b)*Derivative(f(x), (x, 2))

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True)
        a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2))

        >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f)
        (a + b)*f(x) + (a + b)*Derivative(f(x), x)

    Or you can even match both derivative order and exponent at the same time::

        >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x))
        (a + b)*Derivative(f(x), (x, 2))**2

    Finally, you can apply a function to each of the collected coefficients.
    For example you can factorize symbolic coefficients of polynomial::

        >>> f = expand((x + a + 1)**3)

        >>> collect(f, x, factor)
        x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3

    .. note:: Arguments are expected to be in expanded form, so you might have
              to call :func:`~.expand` prior to calling this function.

    See Also
    ========

    collect_const, collect_sqrt, rcollect
    """
    expr = sympify(expr)
    syms = [sympify(i) for i in (syms if iterable(syms) else [syms])]

    # replace syms[i] if it is not x, -x or has Wild symbols
    cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool(
        x.atoms(Wild))
    _, nonsyms = sift(syms, cond, binary=True)
    if nonsyms:
        # Composite patterns are swapped for Dummy symbols that carry the
        # same assumptions. The collection is redone on the substituted
        # expression, and the mapping is then inverted on the result
        # (on both the keys and the values when a dict is returned).
        reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms]))
        syms = [reps.get(s, s) for s in syms]
        rv = collect(expr.subs(reps), syms,
            func=func, evaluate=evaluate, exact=exact,
            distribute_order_term=distribute_order_term)
        urep = {v: k for k, v in reps.items()}
        if not isinstance(rv, dict):
            return rv.xreplace(urep)
        else:
            return {urep.get(k, k).xreplace(urep): v.xreplace(urep)
                for k, v in rv.items()}

    if evaluate is None:
        evaluate = global_parameters.evaluate

    def make_expression(terms):
        # Rebuild a product from parsed (term, rat, sym, deriv) tuples, as
        # produced by parse_term: re-apply any derivative order, then the
        # rational (and optional symbolic) exponent.
        product = []

        for term, rat, sym, deriv in terms:
            if deriv is not None:
                var, order = deriv

                while order > 0:
                    term, order = Derivative(term, var), order - 1

            if sym is None:
                if rat is S.One:
                    product.append(term)
                else:
                    product.append(Pow(term, rat))
            else:
                product.append(Pow(term, rat*sym))

        return Mul(*product)

    def parse_derivative(deriv):
        # scan derivatives tower in the input expression and return
        # underlying function and maximal differentiation order
        # Returns (base_expr, (var, Rational(order))). Differentiation with
        # respect to mixed variables within one Derivative raises
        # NotImplementedError; the descent stops at the first nested
        # Derivative taken with respect to a different variable.
        expr, sym, order = deriv.expr, deriv.variables[0], 1

        for s in deriv.variables[1:]:
            if s == sym:
                order += 1
            else:
                raise NotImplementedError(
                    'Improve MV Derivative support in collect')

        while isinstance(expr, Derivative):
            s0 = expr.variables[0]

            for s in expr.variables:
                if s != s0:
                    raise NotImplementedError(
                        'Improve MV Derivative support in collect')

            if s0 == sym:
                expr, order = expr.expr, order + len(expr.variables)
            else:
                break

        return expr, (sym, Rational(order))

    def parse_term(expr):
        """Parses expression expr and outputs tuple (sexpr, rat_expo,
        sym_expo, deriv)
        where:
         - sexpr is the base expression
         - rat_expo is the rational exponent that sexpr is raised to
         - sym_expo is the symbolic exponent that sexpr is raised to
         - deriv contains the derivatives of the expression

         For example, the output of x would be (x, 1, None, None)
         the output of 2**x would be (2, 1, x, None).
        """
        rat_expo, sym_expo = S.One, None
        sexpr, deriv = expr, None

        if expr.is_Pow:
            if isinstance(expr.base, Derivative):
                sexpr, deriv = parse_derivative(expr.base)
            else:
                sexpr = expr.base

            if expr.base == S.Exp1:
                # E**arg: split a rational coefficient off the exponent so
                # that e.g. E**(2*x) is seen as exp(x) squared
                arg = expr.exp
                if arg.is_Rational:
                    sexpr, rat_expo = S.Exp1, arg
                elif arg.is_Mul:
                    coeff, tail = arg.as_coeff_Mul(rational=True)
                    sexpr, rat_expo = exp(tail), coeff

            elif expr.exp.is_Number:
                rat_expo = expr.exp
            else:
                # numeric coefficient of the exponent -> rat_expo,
                # remaining factor -> sym_expo
                coeff, tail = expr.exp.as_coeff_Mul()

                if coeff.is_Number:
                    rat_expo, sym_expo = coeff, tail
                else:
                    sym_expo = expr.exp
        elif isinstance(expr, exp):
            arg = expr.exp
            if arg.is_Rational:
                sexpr, rat_expo = S.Exp1, arg
            elif arg.is_Mul:
                coeff, tail = arg.as_coeff_Mul(rational=True)
                sexpr, rat_expo = exp(tail), coeff
        elif isinstance(expr, Derivative):
            sexpr, deriv = parse_derivative(expr)

        return sexpr, rat_expo, sym_expo, deriv

    def parse_expression(terms, pattern):
        """Parse terms searching for a pattern.
        Terms is a list of tuples as returned by parse_term;
        Pattern is an expression treated as a product of factors.

        Returns None when some factor of the pattern is not found in
        terms; otherwise returns (unmatched_terms, matched_terms,
        common_expo, has_deriv). The caller in this function unpacks
        common_expo but does not use it.
        """
        pattern = Mul.make_args(pattern)

        if len(terms) < len(pattern):
            # pattern is longer than matched product
            # so no chance for positive parsing result
            return None
        else:
            pattern = [parse_term(elem) for elem in pattern]

        terms = terms[:]  # need a copy
        elems, common_expo, has_deriv = [], None, False

        for elem, e_rat, e_sym, e_ord in pattern:

            if elem.is_Number and e_rat == 1 and e_sym is None:
                # a constant is a match for everything
                continue

            for j in range(len(terms)):
                if terms[j] is None:
                    continue

                term, t_rat, t_sym, t_ord = terms[j]

                # keeping track of whether one of the terms had
                # a derivative or not as this will require rebuilding
                # the expression later
                if t_ord is not None:
                    has_deriv = True

                if (term.match(elem) is not None and
                        (t_sym == e_sym or t_sym is not None and
                        e_sym is not None and
                        t_sym.match(e_sym) is not None)):
                    if exact is False:
                        # we don't have to be exact so find common exponent
                        # for both expression's term and pattern's element
                        expo = t_rat / e_rat

                        if common_expo is None:
                            # first time
                            common_expo = expo
                        else:
                            # common exponent was negotiated before so
                            # there is no chance for a pattern match unless
                            # common and current exponents are equal
                            if common_expo != expo:
                                common_expo = 1
                    else:
                        # we ought to be exact so all fields of
                        # interest must match in every details
                        if e_rat != t_rat or e_ord != t_ord:
                            continue

                    # found common term so remove it from the expression
                    # and try to match next element in the pattern
                    elems.append(terms[j])
                    terms[j] = None

                    break

            else:
                # pattern element not found
                return None

        return [_f for _f in terms if _f], elems, common_expo, has_deriv

    if evaluate:
        # Recurse structurally. Each non-order argument of an Add is
        # collected first and the Add is then processed below. Mul and Pow
        # nodes are rebuilt from their collected parts and returned
        # immediately.
        if expr.is_Add:
            o = expr.getO() or 0
            expr = expr.func(*[
                    collect(a, syms, func, True, exact, distribute_order_term)
                    for a in expr.args if a != o]) + o
        elif expr.is_Mul:
            return expr.func(*[
                collect(term, syms, func, True, exact, distribute_order_term)
                for term in expr.args])
        elif expr.is_Pow:
            b = collect(
                expr.base, syms, func, True, exact, distribute_order_term)
            return Pow(b, expr.exp)

    syms = [expand_power_base(i, deep=False) for i in syms]

    order_term = None

    if distribute_order_term:
        order_term = expr.getO()

        if order_term is not None:
            # an order term that involves a pattern is left in place;
            # otherwise it is stripped here and re-added to each
            # coefficient at the end
            if order_term.has(*syms):
                order_term = None
            else:
                expr = expr.removeO()

    summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)]

    # collected: key (matched pattern part) -> list of cofactors;
    # disliked: sum of the terms that matched no pattern
    collected, disliked = defaultdict(list), S.Zero
    for product in summa:
        c, nc = product.args_cnc(split_1=False)
        args = list(ordered(c)) + nc
        terms = [parse_term(i) for i in args]
        small_first = True

        for symbol in syms:
            if SYMPY_DEBUG:
                print("DEBUG: parsing of expression %s with symbol %s " % (
                    str(terms), str(symbol))
                )

            # For Derivative patterns the factor list is reversed once per
            # product; small_first then stays False for the remaining
            # patterns of this product.
            if isinstance(symbol, Derivative) and small_first:
                terms = list(reversed(terms))
                small_first = not small_first
            result = parse_expression(terms, symbol)

            if SYMPY_DEBUG:
                print("DEBUG: returned %s" % str(result))

            if result is not None:
                if not symbol.is_commutative:
                    raise AttributeError("Can not collect noncommutative symbol")

                terms, elems, common_expo, has_deriv = result

                # when there was derivative in current pattern we
                # will need to rebuild its expression from scratch
                if not has_deriv:
                    margs = []
                    for elem in elems:
                        if elem[2] is None:
                            e = elem[1]
                        else:
                            e = elem[1]*elem[2]
                        margs.append(Pow(elem[0], e))
                    index = Mul(*margs)
                else:
                    index = make_expression(elems)
                terms = expand_power_base(make_expression(terms), deep=False)
                index = expand_power_base(index, deep=False)
                collected[index].append(terms)
                # greedy: the first matching pattern claims this term
                break
        else:
            # none of the patterns matched
            disliked += product
    # add terms now for each key
    collected = {k: Add(*v) for k, v in collected.items()}

    if disliked is not S.Zero:
        collected[S.One] = disliked

    if order_term is not None:
        for key, val in collected.items():
            collected[key] = val + order_term

    if func is not None:
        collected = {
            key: func(val) for key, val in collected.items()}

    if evaluate:
        return Add(*[key*val for key, val in collected.items()])
    else:
        return collected
def rcollect(expr, *vars):
    """
    Recursively collect sums in an expression.

    Any sub-expression that involves at least one of ``vars`` is rebuilt
    from its recursively processed arguments. When the rebuilt node is an
    Add, :func:`collect` is applied to it with respect to ``vars``. Atoms,
    and sub-expressions that do not contain any of ``vars``, are returned
    as they are.

    Examples
    ========

    >>> from sympy.simplify import rcollect
    >>> from sympy.abc import x, y

    >>> expr = (x**2*y + x*y + x + y)/(x + y)

    >>> rcollect(expr, y)
    (x + y*(x**2 + x + 1))/(x + y)

    See Also
    ========

    collect, collect_const, collect_sqrt
    """
    # Nothing to do for leaves or for subtrees free of the targets.
    if expr.is_Atom or not expr.has(*vars):
        return expr

    node_type = expr.__class__
    rebuilt = node_type(*[rcollect(child, *vars) for child in expr.args])

    if not rebuilt.is_Add:
        return rebuilt
    return collect(rebuilt, vars)
def collect_sqrt(expr, evaluate=None):
    """Return expr with terms having common square roots collected together.

    If ``evaluate`` is True, the expression with any collected terms is
    returned. If ``evaluate`` is False, a pair ``(terms, count)`` is
    returned instead. ``count`` is the number of sqrt-containing terms.
    ``terms`` is the sorted tuple of terms of the result, or a 1-tuple
    holding the whole expression when nothing was collected and no
    sqrt-containing term was found.

    Note: since I = sqrt(-1), it is collected, too.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b

    >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]]
    >>> collect_sqrt(a*r2 + b*r2)
    sqrt(2)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3)
    sqrt(2)*(a + b) + sqrt(3)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5)
    sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b)

    If evaluate is False then the arguments will be sorted and
    returned as a tuple and a count of the number of sqrt-containing
    terms will be returned:

    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False)
    ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3)
    >>> collect_sqrt(a*sqrt(2) + b, evaluate=False)
    ((b, sqrt(2)*a), 1)
    >>> collect_sqrt(a + b, evaluate=False)
    ((a + b,), 0)

    See Also
    ========

    collect, collect_const, rcollect
    """
    if evaluate is None:
        evaluate = global_parameters.evaluate

    def _is_radical(factor):
        # a power whose exponent is a rational with denominator 2,
        # or the imaginary unit itself
        return (factor.is_Pow and factor.exp.is_Rational
                and factor.exp.q == 2) or factor is S.ImaginaryUnit

    # pulling out the content first standardizes complex sqrt arguments
    coeff, expr = expr.as_content_primitive()

    radicals = {
        factor
        for term in Add.make_args(expr)
        for factor in term.args_cnc()[0]
        if factor.is_number and _is_radical(factor)}

    # only radicals are targeted here, so Number handling is switched off
    collected = collect_const(expr, *radicals, Numbers=False)

    if evaluate:
        return coeff*collected

    hit = expr != collected
    # canonical ordering of the resulting terms
    pieces = list(ordered(Add.make_args(collected)))
    # XXX should the radical test be restricted to numbers as above?
    nrad = sum(
        1 for piece in pieces
        if any(_is_radical(ci) for ci in piece.args_cnc()[0]))
    pieces = [piece*coeff for piece in pieces]

    if not (hit or nrad):
        pieces = [Add(*pieces)]
    return tuple(pieces), nrad
def collect_abs(expr):
    """Return ``expr`` with arguments of multiple Abs in a term collected
    under a single instance.

    Every Mul (and then every Pow) in ``expr`` is inspected. Its
    commutative Abs factors, and real powers of Abs, are merged into one
    Abs of the product of their arguments. This happens when there are at
    least two of them, or when one of the gathered arguments is a power
    with a negative exponent.

    Examples
    ========

    >>> from sympy.simplify.radsimp import collect_abs
    >>> from sympy.abc import x
    >>> collect_abs(abs(x + 1)/abs(x**2 - 1))
    Abs((x + 1)/(x**2 - 1))
    >>> collect_abs(abs(1/x))
    Abs(1/x)
    """
    def _merge(term):
        commutative, noncommutative = term.args_cnc()
        inside = []   # arguments that will go under the single Abs
        outside = []  # every other commutative factor
        for factor in commutative:
            if isinstance(factor, Abs):
                inside.append(factor.args[0])
            elif (isinstance(factor, Pow) and isinstance(factor.base, Abs)
                    and factor.exp.is_real):
                inside.append(factor.base.args[0]**factor.exp)
            else:
                outside.append(factor)

        if len(inside) < 2 and not any(
                f.exp.is_negative for f in inside if isinstance(f, Pow)):
            return term

        merged_arg = Mul(*inside)
        merged = Abs(merged_arg)

        # If Abs evaluated away completely, an ordinary Mul is enough.
        if not merged.has(Abs):
            return Mul(merged, *outside, *noncommutative)

        # Otherwise keep a single, unevaluated Abs in front.
        if not isinstance(merged, Abs):
            merged = Abs(merged_arg, evaluate=False)
        factors = [merged, *outside]
        _mulsort(factors)
        factors.extend(noncommutative)  # noncommutative factors stay last
        return Mul._from_args(factors, is_commutative=not noncommutative)

    is_mul = lambda node: isinstance(node, Mul)
    is_pow = lambda node: isinstance(node, Pow)
    return expr.replace(is_mul, lambda node: _merge(node)).replace(
        is_pow, lambda node: _merge(node))
def collect_const(expr, *vars, Numbers=True):
    """A non-greedy collection of terms with similar number coefficients in
    an Add expr. If ``vars`` is given then only those constants will be
    targeted. Although any Number can also be targeted, if this is not
    desired set ``Numbers=False`` and no Float or Rational will be collected.

    Parameters
    ==========

    expr : SymPy expression
        This parameter defines the expression from which terms with
        similar coefficients are to be collected. A non-Add expression is
        returned as it is.

    vars : variable length collection of Numbers, optional
        Specifies the constants to target for collection. Can be multiple in
        number.

    Numbers : bool
        Specifies to target all instances of the
        :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then
        no Float or Rational will be collected.

    Returns
    =======

    expr : Expr
        Returns an expression with similar coefficient terms collected.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.abc import s, x, y, z
    >>> from sympy.simplify.radsimp import collect_const
    >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2)))
    sqrt(3)*(sqrt(2) + 2)
    >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7))
    (sqrt(3) + sqrt(7))*(s + 1)
    >>> s = sqrt(2) + 2
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7))
    (sqrt(2) + 3)*(sqrt(3) + sqrt(7))
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3))
    sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2)

    The collection is sign-sensitive, giving higher precedence to the
    unsigned values:

    >>> collect_const(x - y - z)
    x - (y + z)
    >>> collect_const(-y - z)
    -(y + z)
    >>> collect_const(2*x - 2*y - 2*z, 2)
    2*(x - y - z)
    >>> collect_const(2*x - 2*y - 2*z, -2)
    2*x - 2*(y + z)

    See Also
    ========

    collect, collect_sqrt, rcollect
    """
    if not expr.is_Add:
        return expr

    recurse = False

    if not vars:
        # No explicit targets: every numeric factor of every term becomes a
        # candidate, and (via ``recurse``) newly formed coefficient sums may
        # be appended as further candidates below.
        recurse = True
        vars = set()
        for a in expr.args:
            for m in Mul.make_args(a):
                if m.is_number:
                    vars.add(m)
    else:
        vars = sympify(vars)
    if not Numbers:
        vars = [v for v in vars if not v.is_Number]

    vars = list(ordered(vars))
    # NOTE: ``vars`` can grow inside this loop (see ``recurse`` below);
    # iterating over a list also visits items appended during iteration.
    for v in vars:
        # terms[v] gets the cofactors of terms divisible by v;
        # terms[S.One] gets the remaining terms unchanged
        terms = defaultdict(list)
        Fv = Factors(v)
        for m in Add.make_args(expr):
            f = Factors(m)
            q, r = f.div(Fv)
            if r.is_one:
                # only accept this as a true factor if
                # it didn't change an exponent from an Integer
                # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2)
                # -- we aren't looking for this sort of change
                fwas = f.factors.copy()
                fnow = q.factors
                if not any(k in fwas and fwas[k].is_Integer and not
                        fnow[k].is_Integer for k in fnow):
                    terms[v].append(q.as_expr())
                    continue
            terms[S.One].append(m)

        args = []
        hit = False
        uneval = False
        for k in ordered(terms):
            # ``v`` is rebound here to the cofactor list of key ``k``; the
            # outer target is no longer needed (``Fv`` was computed above).
            v = terms[k]
            if k is S.One:
                args.extend(v)
                continue

            if len(v) > 1:
                v = Add(*v)
                hit = True
                if recurse and v != expr:
                    vars.append(v)
            else:
                v = v[0]

            # be careful not to let uneval become True unless
            # it must be because it's going to be more expensive
            # to rebuild the expression as an unevaluated one
            if Numbers and k.is_Number and v.is_Add:
                args.append(_keep_coeff(k, v, sign=True))
                uneval = True
            else:
                args.append(k*v)

        if hit:
            if uneval:
                expr = _unevaluated_Add(*args)
            else:
                expr = Add(*args)
            # once the result is no longer a sum there is nothing left
            # to collect
            if not expr.is_Add:
                break

    return expr
def radsimp(expr, symbolic=True, max_terms=4):
    r"""
    Rationalize the denominator by removing square roots.

    Explanation
    ===========

    The expression returned from radsimp must be used with caution
    since if the denominator contains symbols, it will be possible to make
    substitutions that violate the assumptions of the simplification process:
    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
    there are no symbols, this assumption is made valid by collecting terms
    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
    you do not want the simplification to occur for symbolic denominators, set
    ``symbolic`` to False.

    If there are more than ``max_terms`` radical terms then the expression is
    returned unchanged.

    Parameters
    ==========

    expr : Expr
        The expression whose denominator is rationalized. Its additive
        numeric coefficient (from ``as_coeff_Add``) is split off first and
        added back unchanged.
    symbolic : bool
        If False, reciprocals whose denominator has free symbols are left
        alone. It also controls whether symbolic exponents with
        denominator 2 count as square roots, and is passed as ``force`` to
        ``powdenest``.
    max_terms : int
        Maximum number of distinct radical groups in a denominator that
        will be processed.

    Examples
    ========

    >>> from sympy import radsimp, sqrt, Symbol, pprint
    >>> from sympy import factor_terms, fraction, signsimp
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b, c

    >>> radsimp(1/(2 + sqrt(2)))
    (2 - sqrt(2))/2
    >>> x,y = map(Symbol, 'xy')
    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
    >>> radsimp(e)
    sqrt(2)*(x + y)

    No simplification beyond removal of the gcd is done. One might
    want to polish the result a little, however, by collecting
    square root terms:

    >>> r2 = sqrt(2)
    >>> r5 = sqrt(5)
    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans)
        ___       ___       ___       ___
      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    >>> n, d = fraction(ans)
    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True))
            ___             ___
          \/ 5 *(a + b) - \/ 2 *(x + y)
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    If radicals in the denominator cannot be removed or there is no denominator,
    the original expression will be returned.

    >>> radsimp(sqrt(2)*x + sqrt(2))
    sqrt(2)*x + sqrt(2)

    Results with symbols will not always be valid for all substitutions:

    >>> eq = 1/(a + b*sqrt(c))
    >>> eq.subs(a, b*sqrt(c))
    1/(2*b*sqrt(c))
    >>> radsimp(eq).subs(a, b*sqrt(c))
    nan

    If ``symbolic=False``, symbolic denominators will not be transformed (but
    numeric denominators will still be processed):

    >>> radsimp(eq, symbolic=False)
    1/(a + b*sqrt(c))

    """
    from sympy.simplify.simplify import signsimp

    # placeholder symbols for the conjugate-multiplier templates in _num
    syms = symbols("a:d A:D")
    def _num(rterms):
        # return the multiplier that will simplify the expression described
        # by rterms [(sqrt arg, coeff), ... ]
        # Templates exist for 1 to 4 groups; the flattened (radicand, coeff)
        # pairs are substituted for (A, a), (B, b), ... in order.
        a, b, c, d, A, B, C, D = syms
        if len(rterms) == 2:
            reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i])))
            return (
            sqrt(A)*a - sqrt(B)*b).xreplace(reps)
        if len(rterms) == 3:
            reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i])))
            return (
            (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 -
            B*b**2 + C*c**2)).xreplace(reps)
        elif len(rterms) == 4:
            reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i])))
            return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b
                - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 +
                D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 -
                2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 -
                2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 +
                D**2*d**4)).xreplace(reps)
        elif len(rterms) == 1:
            return sqrt(rterms[0][0])
        else:
            raise NotImplementedError

    def ispow2(d, log2=False):
        # True if d is a Pow whose exponent has denominator 2 (rational, or
        # symbolic when ``symbolic`` is set). With log2=True, an exponent
        # whose denominator is any power of two > 1 is accepted as well.
        if not d.is_Pow:
            return False
        e = d.exp
        if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2:
            return True
        if log2:
            q = 1
            if e.is_Rational:
                q = e.q
            elif symbolic:
                d = denom(e)
                if d.is_Integer:
                    q = d
            if q != 1 and log(q, 2).is_Integer:
                return True
        return False

    def handle(expr):
        # Handle first reduces to the case
        # expr = 1/d, where d is an add, or d is base**p/2.
        # We do this by recursively calling handle on each piece.
        from sympy.simplify.simplify import nsimplify

        n, d = fraction(expr)

        if expr.is_Atom or (d.is_Atom and n.is_Atom):
            return expr
        elif not n.is_Atom:
            n = n.func(*[handle(a) for a in n.args])
            return _unevaluated_Mul(n, handle(1/d))
        elif n is not S.One:
            return _unevaluated_Mul(n, handle(1/d))
        elif d.is_Mul:
            # the comprehension variable ``d`` shadows the outer ``d``;
            # ``d.args`` is evaluated on the outer one
            return _unevaluated_Mul(*[handle(1/d) for d in d.args])

        # By this step, expr is 1/d, and d is not a mul.
        if not symbolic and d.free_symbols:
            return expr

        if ispow2(d):
            d2 = sqrtdenest(sqrt(d.base))**numer(d.exp)
            if d2 != d:
                return handle(1/d2)
        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
            # (1/d**i) = (1/d)**i
            return handle(1/d.base)**d.exp

        if not (d.is_Add or ispow2(d)):
            return 1/d.func(*[handle(a) for a in d.args])

        # handle 1/d treating d as an Add (though it may not be)

        keep = True  # keep changes that are made

        # flatten it and collect radicals after checking for special
        # conditions
        d = _mexpand(d)

        # did it change?
        if d.is_Atom:
            return 1/d

        # is it a number that might be handled easily?
        if d.is_number:
            _d = nsimplify(d)
            if _d.is_Number and _d.equals(d):
                return 1/_d

        while True:
            # collect similar terms
            # key: sorted tuple of radicands of the term's square-root-like
            # factors (I contributes -1); value: list of remaining cofactors
            collected = defaultdict(list)
            for m in Add.make_args(d):  # d might have become non-Add
                p2 = []
                other = []
                for i in Mul.make_args(m):
                    if ispow2(i, log2=True):
                        p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp))
                    elif i is S.ImaginaryUnit:
                        p2.append(S.NegativeOne)
                    else:
                        other.append(i)
                collected[tuple(ordered(p2))].append(Mul(*other))
            rterms = list(ordered(list(collected.items())))
            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
            # a group whose radicand product is 1 (no radicals) does not
            # count as a radical term
            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
            if nrad < 1:
                break
            elif nrad > max_terms:
                # there may have been invalid operations leading to this point
                # so don't keep changes, e.g. this expression is troublesome
                # in collecting terms so as not to raise the issue of 2834:
                # r = sqrt(sqrt(5) + 5)
                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
                keep = False
                break
            if len(rterms) > 4:
                # in general, only 4 terms can be removed with repeated squaring
                # but other considerations can guide selection of radical terms
                # so that radicals are removed
                if all(x.is_Integer and (y**2).is_Rational for x, y in rterms):
                    nd, d = rad_rationalize(S.One, Add._from_args(
                        [sqrt(x)*y for x, y in rterms]))
                    n *= nd
                else:
                    # is there anything else that might be attempted?
                    keep = False
                # The loop is left in both cases: _num below only has
                # templates for up to 4 groups.
                break
            from sympy.simplify.powsimp import powsimp, powdenest

            # multiply numerator and denominator by the conjugate-style
            # factor, then re-expand/denest the denominator and repeat
            num = powsimp(_num(rterms))
            n *= num
            d *= num
            d = powdenest(_mexpand(d), force=symbolic)
            if d.has(S.Zero, nan, zoo):
                return expr
            if d.is_Atom:
                break

        if not keep:
            return expr
        return _unevaluated_Mul(n, 1/d)

    coeff, expr = expr.as_coeff_Add()
    expr = expr.normal()
    old = fraction(expr)
    n, d = fraction(handle(expr))
    if old != (n, d):
        if not d.is_Atom:
            was = (n, d)
            n = signsimp(n, evaluate=False)
            d = signsimp(d, evaluate=False)
            u = Factors(_unevaluated_Mul(n, 1/d))
            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
            n, d = fraction(u)
            # if the sign/factor normalization just recovered the input,
            # keep the result from handle() instead
            if old == (n, d):
                n, d = was
        n = expand_mul(n)
        if d.is_Number or d.is_Add:
            # accept the gcd-extracted form only if its denominator is a
            # Number or is not more complex (by count_ops) than d
            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d)))
            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
                n, d = [signsimp(i) for i in (n2, d2)]
                if n.is_Mul and n.args[0].is_Number:
                    # rebuild the Mul from its args so it is re-evaluated
                    n = n.func(*n.args)

    return coeff + _unevaluated_Mul(n, 1/d)
def rad_rationalize(num, den):
    """
    Rationalize ``num/den`` by removing square roots in the denominator;
    num and den are sum of terms whose squares are positive rationals.

    Returns the pair ``(num, den)`` after rationalization. A ``den`` that
    is not a sum is returned untouched together with ``num``.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import rad_rationalize
    >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3)
    (-sqrt(3) + sqrt(6)/3, -7/9)
    """
    # Each pass splits the sum with split_surds, which is defined elsewhere
    # and presumably gives den == a*sqrt(g) + b (confirm). It multiplies
    # num by (a*sqrt(g) - b) and replaces den by (a*sqrt(g))**2 - b**2.
    # Passes repeat until den is no longer a sum.
    while den.is_Add:
        g, a, b = split_surds(den)
        surd_part = a*sqrt(g)
        num = _mexpand((surd_part - b)*num)
        den = _mexpand(surd_part**2 - b**2)
    return num, den
def fraction(expr, exact=False):
    """Return a pair with the expression's numerator and denominator.

    If the given expression is not a fraction, this function returns
    the tuple (expr, 1).

    This function makes no attempt to simplify nested fractions or to
    do any term rewriting at all.

    If only one of the numerator/denominator pair is needed, use the
    numer(expr) or denom(expr) functions respectively.

    Examples
    ========

    >>> from sympy import fraction, Rational, Symbol
    >>> from sympy.abc import x, y

    >>> fraction(x/y)
    (x, y)
    >>> fraction(x)
    (x, 1)

    >>> fraction(1/y**2)
    (1, y**2)

    >>> fraction(x*y/2)
    (x*y, 2)
    >>> fraction(Rational(1, 2))
    (1, 2)

    This function also works with assumptions:

    >>> k = Symbol('k', negative=True)
    >>> fraction(x * y**k)
    (x, y**(-k))

    If nothing is known about the sign of an exponent and the ``exact``
    flag is unset, then this exponent's structure is analyzed and a
    pretty fraction is returned:

    >>> from sympy import exp, Mul
    >>> fraction(2*x**(-y))
    (2, x**y)

    >>> fraction(exp(-x))
    (1, exp(x))

    >>> fraction(exp(-x), exact=True)
    (exp(-x), 1)

    The ``exact`` flag also keeps unevaluated Muls from being evaluated:

    >>> u = Mul(2, x + 1, evaluate=False)
    >>> fraction(u)
    (2*x + 2, 1)
    >>> fraction(u, exact=True)
    (2*(x + 1), 1)
    """
    expr = sympify(expr)

    top_parts, bottom_parts = [], []

    for factor in Mul.make_args(expr):
        if not (factor.is_commutative and
                (factor.is_Pow or isinstance(factor, exp))):
            # Not a commutative power: only a non-integer Rational splits.
            if factor.is_Rational and not factor.is_Integer:
                if factor.p != 1:
                    top_parts.append(factor.p)
                bottom_parts.append(factor.q)
            else:
                top_parts.append(factor)
            continue

        base, power = factor.as_base_exp()
        if power.is_negative:
            if power is S.NegativeOne:
                bottom_parts.append(base)
            elif exact and not power.is_constant():
                # exact mode only moves constant negative exponents down
                top_parts.append(factor)
            else:
                bottom_parts.append(Pow(base, -power))
        elif power.is_positive:
            top_parts.append(factor)
        elif not exact and power.is_Mul:
            # sign of exponent unknown: let as_numer_denom inspect its form
            top, bottom = factor.as_numer_denom()
            if top != 1:
                top_parts.append(top)
            bottom_parts.append(bottom)
        else:
            top_parts.append(factor)

    return (Mul(*top_parts, evaluate=not exact),
            Mul(*bottom_parts, evaluate=not exact))
def numer(expr):
    """Return the numerator of ``expr`` as split by ``fraction``."""
    numerator, _ = fraction(expr)
    return numerator
def denom(expr):
    """Return the denominator of ``expr`` as split by ``fraction``."""
    _, denominator = fraction(expr)
    return denominator
def fraction_expand(expr, **hints):
    """Expand ``expr`` with the ``frac`` hint enabled.

    ``expr`` is not sympified here, so it must already provide an
    ``expand`` method. Extra keyword ``hints`` are forwarded unchanged.
    Passing ``frac`` explicitly raises ``TypeError``, because it is
    already supplied.
    """
    expanded = expr.expand(frac=True, **hints)
    return expanded
def numer_expand(expr, **hints):
    """Expand only the numerator of ``expr``.

    ``expr`` is split with ``fraction``. The numerator is expanded with
    ``numer=True`` plus any extra ``hints``, and the result is divided by
    the untouched denominator.
    """
    top, bottom = fraction(expr)
    expanded_top = top.expand(numer=True, **hints)
    return expanded_top / bottom
def denom_expand(expr, **hints):
    """Expand only the denominator of ``expr``.

    ``expr`` is split with ``fraction``. The denominator is expanded with
    ``denom=True`` plus any extra ``hints``, and the untouched numerator
    is divided by it.
    """
    top, bottom = fraction(expr)
    expanded_bottom = bottom.expand(denom=True, **hints)
    return top / expanded_bottom
# Alternative names with the ``expand_`` prefix for the helpers above.
expand_fraction = fraction_expand
expand_numer = numer_expand
expand_denom = denom_expand
def split_surds(expr):
    """
    Split an expression with terms whose squares are positive rationals
    into a sum of terms whose surds squared have gcd equal to g
    and a sum of terms with surds squared prime with g.

    Returns ``(g, a, b)``, where ``a*sqrt(g) + b`` is mathematically equal
    to ``expr``. Each square-root term ``c*sqrt(s)`` whose radicand ``s``
    was grouped by ``_split_gcd`` contributes ``c*sqrt(s/g)`` to ``a``.
    Every other term goes to ``b`` unchanged.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import split_surds
    >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15))
    (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10))
    """
    terms = sorted(expr.args, key=default_sort_key)
    coeff_rad = [term.as_coeff_Mul() for term in terms]
    squares = sorted(
        (rad**2 for _, rad in coeff_rad if rad.is_Pow), key=default_sort_key)
    g, group, coprime = _split_gcd(*squares)
    scale = g
    if not coprime and len(group) >= 2:
        # only a common factor has been factored; split again
        quotients = [q for q in (x/g for x in group) if q != 1]
        scale = g*_split_gcd(*quotients)[0]
        # NOTE(review): only the refined gcd is used below. Membership is
        # still tested against the first-pass ``group``, so the second
        # split does not change which terms land in ``a``.
    inner, outer = [], []
    for c, rad in coeff_rad:
        if rad.is_Pow and rad.exp == S.Half and rad.base in group:
            inner.append(c*sqrt(rad.base/scale))
        else:
            outer.append(c*rad)
    return scale, Add(*inner), Add(*outer)
- def _split_gcd(*a):
- """
- Split the list of integers ``a`` into a list of integers, ``a1`` having
- ``g = gcd(a1)``, and a list ``a2`` whose elements are not divisible by
- ``g``. Returns ``g, a1, a2``.
- Examples
- ========
- >>> from sympy.simplify.radsimp import _split_gcd
- >>> _split_gcd(55, 35, 22, 14, 77, 10)
- (5, [55, 35, 10], [22, 14, 77])
- """
- g = a[0]
- b1 = [g]
- b2 = []
- for x in a[1:]:
- g1 = gcd(g, x)
- if g1 == 1:
- b2.append(x)
- else:
- g = g1
- b1.append(x)
- return g, b1, b2
|