# # sympy.polys.matrices.linsolve module # # This module defines the _linsolve function which is the internal workhorse # used by linsolve. This computes the solution of a system of linear equations # using the SDM sparse matrix implementation in sympy.polys.matrices.sdm. This # is a replacement for solve_lin_sys in sympy.polys.solvers which is # inefficient for large sparse systems due to the use of a PolyRing with many # generators: # # https://github.com/sympy/sympy/issues/20857 # # The implementation of _linsolve here handles: # # - Extracting the coefficients from the Expr/Eq input equations. # - Constructing a domain and converting the coefficients to # that domain. # - Using the SDM.rref, SDM.nullspace etc methods to generate the full # solution working with arithmetic only in the domain of the coefficients. # # The routines here are particularly designed to be efficient for large sparse # systems of linear equations although as well as dense systems. It is # possible that for some small dense systems solve_lin_sys which uses the # dense matrix implementation DDM will be more efficient. With smaller systems # though the bulk of the time is spent just preprocessing the inputs and the # relative time spent in rref is too small to be noticeable. # from collections import defaultdict from sympy.core.add import Add from sympy.core.mul import Mul from sympy.core.singleton import S from sympy.polys.constructor import construct_domain from sympy.polys.solvers import PolyNonlinearError from .sdm import ( SDM, sdm_irref, sdm_particular_from_rref, sdm_nullspace_from_rref ) def _linsolve(eqs, syms): """Solve a linear system of equations. Examples ======== Solve a linear system with a unique solution: >>> from sympy import symbols, Eq >>> from sympy.polys.matrices.linsolve import _linsolve >>> x, y = symbols('x, y') >>> eqs = [Eq(x + y, 1), Eq(x - y, 2)] >>> _linsolve(eqs, [x, y]) {x: 3/2, y: -1/2} In the case of underdetermined systems the solution will be expressed in terms of the unknown symbols that are unconstrained: >>> _linsolve([Eq(x + y, 0)], [x, y]) {x: -y, y: y} """ # Number of unknowns (columns in the non-augmented matrix) nsyms = len(syms) # Convert to sparse augmented matrix (len(eqs) x (nsyms+1)) eqsdict, rhs = _linear_eq_to_dict(eqs, syms) Aaug = sympy_dict_to_dm(eqsdict, rhs, syms) K = Aaug.domain # sdm_irref has issues with float matrices. This uses the ddm_rref() # function. When sdm_rref() can handle float matrices reasonably this # should be removed... if K.is_RealField or K.is_ComplexField: Aaug = Aaug.to_ddm().rref()[0].to_sdm() # Compute reduced-row echelon form (RREF) Arref, pivots, nzcols = sdm_irref(Aaug) # No solution: if pivots and pivots[-1] == nsyms: return None # Particular solution for non-homogeneous system: P = sdm_particular_from_rref(Arref, nsyms+1, pivots) # Nullspace - general solution to homogeneous system # Note: using nsyms not nsyms+1 to ignore last column V, nonpivots = sdm_nullspace_from_rref(Arref, K.one, nsyms, pivots, nzcols) # Collect together terms from particular and nullspace: sol = defaultdict(list) for i, v in P.items(): sol[syms[i]].append(K.to_sympy(v)) for npi, Vi in zip(nonpivots, V): sym = syms[npi] for i, v in Vi.items(): sol[syms[i]].append(sym * K.to_sympy(v)) # Use a single call to Add for each term: sol = {s: Add(*terms) for s, terms in sol.items()} # Fill in the zeros: zero = S.Zero for s in set(syms) - set(sol): sol[s] = zero # All done! return sol def sympy_dict_to_dm(eqs_coeffs, eqs_rhs, syms): """Convert a system of dict equations to a sparse augmented matrix""" elems = set(eqs_rhs).union(*(e.values() for e in eqs_coeffs)) K, elems_K = construct_domain(elems, field=True, extension=True) elem_map = dict(zip(elems, elems_K)) neqs = len(eqs_coeffs) nsyms = len(syms) sym2index = dict(zip(syms, range(nsyms))) eqsdict = [] for eq, rhs in zip(eqs_coeffs, eqs_rhs): eqdict = {sym2index[s]: elem_map[c] for s, c in eq.items()} if rhs: eqdict[nsyms] = - elem_map[rhs] if eqdict: eqsdict.append(eqdict) sdm_aug = SDM(enumerate(eqsdict), (neqs, nsyms+1), K) return sdm_aug def _expand_eqs_deprecated(eqs): """Use expand to cancel nonlinear terms. This approach matches previous behaviour of linsolve but should be deprecated. """ def expand_eq(eq): if eq.is_Equality: eq = eq.lhs - eq.rhs return eq.expand() return [expand_eq(eq) for eq in eqs] def _linear_eq_to_dict(eqs, syms): """Convert a system Expr/Eq equations into dict form""" try: return _linear_eq_to_dict_inner(eqs, syms) except PolyNonlinearError: # XXX: This should be deprecated: eqs = _expand_eqs_deprecated(eqs) return _linear_eq_to_dict_inner(eqs, syms) def _linear_eq_to_dict_inner(eqs, syms): """Convert a system Expr/Eq equations into dict form""" syms = set(syms) eqsdict, eqs_rhs = [], [] for eq in eqs: rhs, eqdict = _lin_eq2dict(eq, syms) eqsdict.append(eqdict) eqs_rhs.append(rhs) return eqsdict, eqs_rhs def _lin_eq2dict(a, symset): """Efficiently convert a linear equation to a dict of coefficients""" if a in symset: return S.Zero, {a: S.One} elif a.is_Add: terms_list = defaultdict(list) coeff_list = [] for ai in a.args: ci, ti = _lin_eq2dict(ai, symset) coeff_list.append(ci) for mij, cij in ti.items(): terms_list[mij].append(cij) coeff = Add(*coeff_list) terms = {sym: Add(*coeffs) for sym, coeffs in terms_list.items()} return coeff, terms elif a.is_Mul: terms = terms_coeff = None coeff_list = [] for ai in a.args: ci, ti = _lin_eq2dict(ai, symset) if not ti: coeff_list.append(ci) elif terms is None: terms = ti terms_coeff = ci else: raise PolyNonlinearError coeff = Mul(*coeff_list) if terms is None: return coeff, {} else: terms = {sym: coeff * c for sym, c in terms.items()} return coeff * terms_coeff, terms elif a.is_Equality: return _lin_eq2dict(a.lhs - a.rhs, symset) elif not a.has_free(*symset): return a, {} else: raise PolyNonlinearError