linsolve.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #
  2. # sympy.polys.matrices.linsolve module
  3. #
  4. # This module defines the _linsolve function which is the internal workhorse
  5. # used by linsolve. This computes the solution of a system of linear equations
  6. # using the SDM sparse matrix implementation in sympy.polys.matrices.sdm. This
  7. # is a replacement for solve_lin_sys in sympy.polys.solvers which is
  8. # inefficient for large sparse systems due to the use of a PolyRing with many
  9. # generators:
  10. #
  11. # https://github.com/sympy/sympy/issues/20857
  12. #
  13. # The implementation of _linsolve here handles:
  14. #
  15. # - Extracting the coefficients from the Expr/Eq input equations.
  16. # - Constructing a domain and converting the coefficients to
  17. # that domain.
  18. # - Using the SDM.rref, SDM.nullspace etc methods to generate the full
  19. # solution working with arithmetic only in the domain of the coefficients.
  20. #
# The routines here are designed to be particularly efficient for large sparse
# systems of linear equations, although they should also work well for dense
# systems. It is possible that for some small dense systems solve_lin_sys,
# which uses the dense matrix implementation DDM, will be more efficient. For
# smaller systems, though, the bulk of the time is spent just preprocessing the
# inputs, and the relative time spent in rref is too small to be noticeable.
  27. #
  28. from collections import defaultdict
  29. from sympy.core.add import Add
  30. from sympy.core.mul import Mul
  31. from sympy.core.singleton import S
  32. from sympy.polys.constructor import construct_domain
  33. from sympy.polys.solvers import PolyNonlinearError
  34. from .sdm import (
  35. SDM,
  36. sdm_irref,
  37. sdm_particular_from_rref,
  38. sdm_nullspace_from_rref
  39. )
  40. def _linsolve(eqs, syms):
  41. """Solve a linear system of equations.
  42. Examples
  43. ========
  44. Solve a linear system with a unique solution:
  45. >>> from sympy import symbols, Eq
  46. >>> from sympy.polys.matrices.linsolve import _linsolve
  47. >>> x, y = symbols('x, y')
  48. >>> eqs = [Eq(x + y, 1), Eq(x - y, 2)]
  49. >>> _linsolve(eqs, [x, y])
  50. {x: 3/2, y: -1/2}
  51. In the case of underdetermined systems the solution will be expressed in
  52. terms of the unknown symbols that are unconstrained:
  53. >>> _linsolve([Eq(x + y, 0)], [x, y])
  54. {x: -y, y: y}
  55. """
  56. # Number of unknowns (columns in the non-augmented matrix)
  57. nsyms = len(syms)
  58. # Convert to sparse augmented matrix (len(eqs) x (nsyms+1))
  59. eqsdict, rhs = _linear_eq_to_dict(eqs, syms)
  60. Aaug = sympy_dict_to_dm(eqsdict, rhs, syms)
  61. K = Aaug.domain
  62. # sdm_irref has issues with float matrices. This uses the ddm_rref()
  63. # function. When sdm_rref() can handle float matrices reasonably this
  64. # should be removed...
  65. if K.is_RealField or K.is_ComplexField:
  66. Aaug = Aaug.to_ddm().rref()[0].to_sdm()
  67. # Compute reduced-row echelon form (RREF)
  68. Arref, pivots, nzcols = sdm_irref(Aaug)
  69. # No solution:
  70. if pivots and pivots[-1] == nsyms:
  71. return None
  72. # Particular solution for non-homogeneous system:
  73. P = sdm_particular_from_rref(Arref, nsyms+1, pivots)
  74. # Nullspace - general solution to homogeneous system
  75. # Note: using nsyms not nsyms+1 to ignore last column
  76. V, nonpivots = sdm_nullspace_from_rref(Arref, K.one, nsyms, pivots, nzcols)
  77. # Collect together terms from particular and nullspace:
  78. sol = defaultdict(list)
  79. for i, v in P.items():
  80. sol[syms[i]].append(K.to_sympy(v))
  81. for npi, Vi in zip(nonpivots, V):
  82. sym = syms[npi]
  83. for i, v in Vi.items():
  84. sol[syms[i]].append(sym * K.to_sympy(v))
  85. # Use a single call to Add for each term:
  86. sol = {s: Add(*terms) for s, terms in sol.items()}
  87. # Fill in the zeros:
  88. zero = S.Zero
  89. for s in set(syms) - set(sol):
  90. sol[s] = zero
  91. # All done!
  92. return sol
  93. def sympy_dict_to_dm(eqs_coeffs, eqs_rhs, syms):
  94. """Convert a system of dict equations to a sparse augmented matrix"""
  95. elems = set(eqs_rhs).union(*(e.values() for e in eqs_coeffs))
  96. K, elems_K = construct_domain(elems, field=True, extension=True)
  97. elem_map = dict(zip(elems, elems_K))
  98. neqs = len(eqs_coeffs)
  99. nsyms = len(syms)
  100. sym2index = dict(zip(syms, range(nsyms)))
  101. eqsdict = []
  102. for eq, rhs in zip(eqs_coeffs, eqs_rhs):
  103. eqdict = {sym2index[s]: elem_map[c] for s, c in eq.items()}
  104. if rhs:
  105. eqdict[nsyms] = - elem_map[rhs]
  106. if eqdict:
  107. eqsdict.append(eqdict)
  108. sdm_aug = SDM(enumerate(eqsdict), (neqs, nsyms+1), K)
  109. return sdm_aug
  110. def _expand_eqs_deprecated(eqs):
  111. """Use expand to cancel nonlinear terms.
  112. This approach matches previous behaviour of linsolve but should be
  113. deprecated.
  114. """
  115. def expand_eq(eq):
  116. if eq.is_Equality:
  117. eq = eq.lhs - eq.rhs
  118. return eq.expand()
  119. return [expand_eq(eq) for eq in eqs]
  120. def _linear_eq_to_dict(eqs, syms):
  121. """Convert a system Expr/Eq equations into dict form"""
  122. try:
  123. return _linear_eq_to_dict_inner(eqs, syms)
  124. except PolyNonlinearError:
  125. # XXX: This should be deprecated:
  126. eqs = _expand_eqs_deprecated(eqs)
  127. return _linear_eq_to_dict_inner(eqs, syms)
  128. def _linear_eq_to_dict_inner(eqs, syms):
  129. """Convert a system Expr/Eq equations into dict form"""
  130. syms = set(syms)
  131. eqsdict, eqs_rhs = [], []
  132. for eq in eqs:
  133. rhs, eqdict = _lin_eq2dict(eq, syms)
  134. eqsdict.append(eqdict)
  135. eqs_rhs.append(rhs)
  136. return eqsdict, eqs_rhs
  137. def _lin_eq2dict(a, symset):
  138. """Efficiently convert a linear equation to a dict of coefficients"""
  139. if a in symset:
  140. return S.Zero, {a: S.One}
  141. elif a.is_Add:
  142. terms_list = defaultdict(list)
  143. coeff_list = []
  144. for ai in a.args:
  145. ci, ti = _lin_eq2dict(ai, symset)
  146. coeff_list.append(ci)
  147. for mij, cij in ti.items():
  148. terms_list[mij].append(cij)
  149. coeff = Add(*coeff_list)
  150. terms = {sym: Add(*coeffs) for sym, coeffs in terms_list.items()}
  151. return coeff, terms
  152. elif a.is_Mul:
  153. terms = terms_coeff = None
  154. coeff_list = []
  155. for ai in a.args:
  156. ci, ti = _lin_eq2dict(ai, symset)
  157. if not ti:
  158. coeff_list.append(ci)
  159. elif terms is None:
  160. terms = ti
  161. terms_coeff = ci
  162. else:
  163. raise PolyNonlinearError
  164. coeff = Mul(*coeff_list)
  165. if terms is None:
  166. return coeff, {}
  167. else:
  168. terms = {sym: coeff * c for sym, c in terms.items()}
  169. return coeff * terms_coeff, terms
  170. elif a.is_Equality:
  171. return _lin_eq2dict(a.lhs - a.rhs, symset)
  172. elif not a.has_free(*symset):
  173. return a, {}
  174. else:
  175. raise PolyNonlinearError