solvers.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. """Low-level linear systems solver. """
  2. from sympy.utilities.exceptions import sympy_deprecation_warning
  3. from sympy.utilities.iterables import connected_components
  4. from sympy.core.sympify import sympify
  5. from sympy.core.numbers import Integer, Rational
  6. from sympy.matrices.dense import MutableDenseMatrix
  7. from sympy.polys.domains import ZZ, QQ
  8. from sympy.polys.domains import EX
  9. from sympy.polys.rings import sring
  10. from sympy.polys.polyerrors import NotInvertible
  11. from sympy.polys.domainmatrix import DomainMatrix
  12. class PolyNonlinearError(Exception):
  13. """Raised by solve_lin_sys for nonlinear equations"""
  14. pass
  15. class RawMatrix(MutableDenseMatrix):
  16. """
  17. .. deprecated:: 1.9
  18. This class fundamentally is broken by design. Use ``DomainMatrix`` if
  19. you want a matrix over the polys domains or ``Matrix`` for a matrix
  20. with ``Expr`` elements. The ``RawMatrix`` class will be removed/broken
  21. in future in order to reestablish the invariant that the elements of a
  22. Matrix should be of type ``Expr``.
  23. """
  24. _sympify = staticmethod(lambda x: x)
  25. def __init__(self, *args, **kwargs):
  26. sympy_deprecation_warning(
  27. """
  28. The RawMatrix class is deprecated. Use either DomainMatrix or
  29. Matrix instead.
  30. """,
  31. deprecated_since_version="1.9",
  32. active_deprecations_target="deprecated-rawmatrix",
  33. )
  34. domain = ZZ
  35. for i in range(self.rows):
  36. for j in range(self.cols):
  37. val = self[i,j]
  38. if getattr(val, 'is_Poly', False):
  39. K = val.domain[val.gens]
  40. val_sympy = val.as_expr()
  41. elif hasattr(val, 'parent'):
  42. K = val.parent()
  43. val_sympy = K.to_sympy(val)
  44. elif isinstance(val, (int, Integer)):
  45. K = ZZ
  46. val_sympy = sympify(val)
  47. elif isinstance(val, Rational):
  48. K = QQ
  49. val_sympy = val
  50. else:
  51. for K in ZZ, QQ:
  52. if K.of_type(val):
  53. val_sympy = K.to_sympy(val)
  54. break
  55. else:
  56. raise TypeError
  57. domain = domain.unify(K)
  58. self[i,j] = val_sympy
  59. self.ring = domain
  60. def eqs_to_matrix(eqs_coeffs, eqs_rhs, gens, domain):
  61. """Get matrix from linear equations in dict format.
  62. Explanation
  63. ===========
  64. Get the matrix representation of a system of linear equations represented
  65. as dicts with low-level DomainElement coefficients. This is an
  66. *internal* function that is used by solve_lin_sys.
  67. Parameters
  68. ==========
  69. eqs_coeffs: list[dict[Symbol, DomainElement]]
  70. The left hand sides of the equations as dicts mapping from symbols to
  71. coefficients where the coefficients are instances of
  72. DomainElement.
  73. eqs_rhs: list[DomainElements]
  74. The right hand sides of the equations as instances of
  75. DomainElement.
  76. gens: list[Symbol]
  77. The unknowns in the system of equations.
  78. domain: Domain
  79. The domain for coefficients of both lhs and rhs.
  80. Returns
  81. =======
  82. The augmented matrix representation of the system as a DomainMatrix.
  83. Examples
  84. ========
  85. >>> from sympy import symbols, ZZ
  86. >>> from sympy.polys.solvers import eqs_to_matrix
  87. >>> x, y = symbols('x, y')
  88. >>> eqs_coeff = [{x:ZZ(1), y:ZZ(1)}, {x:ZZ(1), y:ZZ(-1)}]
  89. >>> eqs_rhs = [ZZ(0), ZZ(-1)]
  90. >>> eqs_to_matrix(eqs_coeff, eqs_rhs, [x, y], ZZ)
  91. DomainMatrix([[1, 1, 0], [1, -1, 1]], (2, 3), ZZ)
  92. See also
  93. ========
  94. solve_lin_sys: Uses :func:`~eqs_to_matrix` internally
  95. """
  96. sym2index = {x: n for n, x in enumerate(gens)}
  97. nrows = len(eqs_coeffs)
  98. ncols = len(gens) + 1
  99. rows = [[domain.zero] * ncols for _ in range(nrows)]
  100. for row, eq_coeff, eq_rhs in zip(rows, eqs_coeffs, eqs_rhs):
  101. for sym, coeff in eq_coeff.items():
  102. row[sym2index[sym]] = domain.convert(coeff)
  103. row[-1] = -domain.convert(eq_rhs)
  104. return DomainMatrix(rows, (nrows, ncols), domain)
  105. def sympy_eqs_to_ring(eqs, symbols):
  106. """Convert a system of equations from Expr to a PolyRing
  107. Explanation
  108. ===========
  109. High-level functions like ``solve`` expect Expr as inputs but can use
  110. ``solve_lin_sys`` internally. This function converts equations from
  111. ``Expr`` to the low-level poly types used by the ``solve_lin_sys``
  112. function.
  113. Parameters
  114. ==========
  115. eqs: List of Expr
  116. A list of equations as Expr instances
  117. symbols: List of Symbol
  118. A list of the symbols that are the unknowns in the system of
  119. equations.
  120. Returns
  121. =======
  122. Tuple[List[PolyElement], Ring]: The equations as PolyElement instances
  123. and the ring of polynomials within which each equation is represented.
  124. Examples
  125. ========
  126. >>> from sympy import symbols
  127. >>> from sympy.polys.solvers import sympy_eqs_to_ring
  128. >>> a, x, y = symbols('a, x, y')
  129. >>> eqs = [x-y, x+a*y]
  130. >>> eqs_ring, ring = sympy_eqs_to_ring(eqs, [x, y])
  131. >>> eqs_ring
  132. [x - y, x + a*y]
  133. >>> type(eqs_ring[0])
  134. <class 'sympy.polys.rings.PolyElement'>
  135. >>> ring
  136. ZZ(a)[x,y]
  137. With the equations in this form they can be passed to ``solve_lin_sys``:
  138. >>> from sympy.polys.solvers import solve_lin_sys
  139. >>> solve_lin_sys(eqs_ring, ring)
  140. {y: 0, x: 0}
  141. """
  142. try:
  143. K, eqs_K = sring(eqs, symbols, field=True, extension=True)
  144. except NotInvertible:
  145. # https://github.com/sympy/sympy/issues/18874
  146. K, eqs_K = sring(eqs, symbols, domain=EX)
  147. return eqs_K, K.to_domain()
  148. def solve_lin_sys(eqs, ring, _raw=True):
  149. """Solve a system of linear equations from a PolynomialRing
  150. Explanation
  151. ===========
  152. Solves a system of linear equations given as PolyElement instances of a
  153. PolynomialRing. The basic arithmetic is carried out using instance of
  154. DomainElement which is more efficient than :class:`~sympy.core.expr.Expr`
  155. for the most common inputs.
  156. While this is a public function it is intended primarily for internal use
  157. so its interface is not necessarily convenient. Users are suggested to use
  158. the :func:`sympy.solvers.solveset.linsolve` function (which uses this
  159. function internally) instead.
  160. Parameters
  161. ==========
  162. eqs: list[PolyElement]
  163. The linear equations to be solved as elements of a
  164. PolynomialRing (assumed equal to zero).
  165. ring: PolynomialRing
  166. The polynomial ring from which eqs are drawn. The generators of this
  167. ring are the unkowns to be solved for and the domain of the ring is
  168. the domain of the coefficients of the system of equations.
  169. _raw: bool
  170. If *_raw* is False, the keys and values in the returned dictionary
  171. will be of type Expr (and the unit of the field will be removed from
  172. the keys) otherwise the low-level polys types will be returned, e.g.
  173. PolyElement: PythonRational.
  174. Returns
  175. =======
  176. ``None`` if the system has no solution.
  177. dict[Symbol, Expr] if _raw=False
  178. dict[Symbol, DomainElement] if _raw=True.
  179. Examples
  180. ========
  181. >>> from sympy import symbols
  182. >>> from sympy.polys.solvers import solve_lin_sys, sympy_eqs_to_ring
  183. >>> x, y = symbols('x, y')
  184. >>> eqs = [x - y, x + y - 2]
  185. >>> eqs_ring, ring = sympy_eqs_to_ring(eqs, [x, y])
  186. >>> solve_lin_sys(eqs_ring, ring)
  187. {y: 1, x: 1}
  188. Passing ``_raw=False`` returns the same result except that the keys are
  189. ``Expr`` rather than low-level poly types.
  190. >>> solve_lin_sys(eqs_ring, ring, _raw=False)
  191. {x: 1, y: 1}
  192. See also
  193. ========
  194. sympy_eqs_to_ring: prepares the inputs to ``solve_lin_sys``.
  195. linsolve: ``linsolve`` uses ``solve_lin_sys`` internally.
  196. sympy.solvers.solvers.solve: ``solve`` uses ``solve_lin_sys`` internally.
  197. """
  198. as_expr = not _raw
  199. assert ring.domain.is_Field
  200. eqs_dict = [dict(eq) for eq in eqs]
  201. one_monom = ring.one.monoms()[0]
  202. zero = ring.domain.zero
  203. eqs_rhs = []
  204. eqs_coeffs = []
  205. for eq_dict in eqs_dict:
  206. eq_rhs = eq_dict.pop(one_monom, zero)
  207. eq_coeffs = {}
  208. for monom, coeff in eq_dict.items():
  209. if sum(monom) != 1:
  210. msg = "Nonlinear term encountered in solve_lin_sys"
  211. raise PolyNonlinearError(msg)
  212. eq_coeffs[ring.gens[monom.index(1)]] = coeff
  213. if not eq_coeffs:
  214. if not eq_rhs:
  215. continue
  216. else:
  217. return None
  218. eqs_rhs.append(eq_rhs)
  219. eqs_coeffs.append(eq_coeffs)
  220. result = _solve_lin_sys(eqs_coeffs, eqs_rhs, ring)
  221. if result is not None and as_expr:
  222. def to_sympy(x):
  223. as_expr = getattr(x, 'as_expr', None)
  224. if as_expr:
  225. return as_expr()
  226. else:
  227. return ring.domain.to_sympy(x)
  228. tresult = {to_sympy(sym): to_sympy(val) for sym, val in result.items()}
  229. # Remove 1.0x
  230. result = {}
  231. for k, v in tresult.items():
  232. if k.is_Mul:
  233. c, s = k.as_coeff_Mul()
  234. result[s] = v/c
  235. else:
  236. result[k] = v
  237. return result
  238. def _solve_lin_sys(eqs_coeffs, eqs_rhs, ring):
  239. """Solve a linear system from dict of PolynomialRing coefficients
  240. Explanation
  241. ===========
  242. This is an **internal** function used by :func:`solve_lin_sys` after the
  243. equations have been preprocessed. The role of this function is to split
  244. the system into connected components and pass those to
  245. :func:`_solve_lin_sys_component`.
  246. Examples
  247. ========
  248. Setup a system for $x-y=0$ and $x+y=2$ and solve:
  249. >>> from sympy import symbols, sring
  250. >>> from sympy.polys.solvers import _solve_lin_sys
  251. >>> x, y = symbols('x, y')
  252. >>> R, (xr, yr) = sring([x, y], [x, y])
  253. >>> eqs = [{xr:R.one, yr:-R.one}, {xr:R.one, yr:R.one}]
  254. >>> eqs_rhs = [R.zero, -2*R.one]
  255. >>> _solve_lin_sys(eqs, eqs_rhs, R)
  256. {y: 1, x: 1}
  257. See also
  258. ========
  259. solve_lin_sys: This function is used internally by :func:`solve_lin_sys`.
  260. """
  261. V = ring.gens
  262. E = []
  263. for eq_coeffs in eqs_coeffs:
  264. syms = list(eq_coeffs)
  265. E.extend(zip(syms[:-1], syms[1:]))
  266. G = V, E
  267. components = connected_components(G)
  268. sym2comp = {}
  269. for n, component in enumerate(components):
  270. for sym in component:
  271. sym2comp[sym] = n
  272. subsystems = [([], []) for _ in range(len(components))]
  273. for eq_coeff, eq_rhs in zip(eqs_coeffs, eqs_rhs):
  274. sym = next(iter(eq_coeff), None)
  275. sub_coeff, sub_rhs = subsystems[sym2comp[sym]]
  276. sub_coeff.append(eq_coeff)
  277. sub_rhs.append(eq_rhs)
  278. sol = {}
  279. for subsystem in subsystems:
  280. subsol = _solve_lin_sys_component(subsystem[0], subsystem[1], ring)
  281. if subsol is None:
  282. return None
  283. sol.update(subsol)
  284. return sol
  285. def _solve_lin_sys_component(eqs_coeffs, eqs_rhs, ring):
  286. """Solve a linear system from dict of PolynomialRing coefficients
  287. Explanation
  288. ===========
  289. This is an **internal** function used by :func:`solve_lin_sys` after the
  290. equations have been preprocessed. After :func:`_solve_lin_sys` splits the
  291. system into connected components this function is called for each
  292. component. The system of equations is solved using Gauss-Jordan
  293. elimination with division followed by back-substitution.
  294. Examples
  295. ========
  296. Setup a system for $x-y=0$ and $x+y=2$ and solve:
  297. >>> from sympy import symbols, sring
  298. >>> from sympy.polys.solvers import _solve_lin_sys_component
  299. >>> x, y = symbols('x, y')
  300. >>> R, (xr, yr) = sring([x, y], [x, y])
  301. >>> eqs = [{xr:R.one, yr:-R.one}, {xr:R.one, yr:R.one}]
  302. >>> eqs_rhs = [R.zero, -2*R.one]
  303. >>> _solve_lin_sys_component(eqs, eqs_rhs, R)
  304. {y: 1, x: 1}
  305. See also
  306. ========
  307. solve_lin_sys: This function is used internally by :func:`solve_lin_sys`.
  308. """
  309. # transform from equations to matrix form
  310. matrix = eqs_to_matrix(eqs_coeffs, eqs_rhs, ring.gens, ring.domain)
  311. # convert to a field for rref
  312. if not matrix.domain.is_Field:
  313. matrix = matrix.to_field()
  314. # solve by row-reduction
  315. echelon, pivots = matrix.rref()
  316. # construct the returnable form of the solutions
  317. keys = ring.gens
  318. if pivots and pivots[-1] == len(keys):
  319. return None
  320. if len(pivots) == len(keys):
  321. sol = []
  322. for s in [row[-1] for row in echelon.rep.to_ddm()]:
  323. a = s
  324. sol.append(a)
  325. sols = dict(zip(keys, sol))
  326. else:
  327. sols = {}
  328. g = ring.gens
  329. # Extract ground domain coefficients and convert to the ring:
  330. if hasattr(ring, 'ring'):
  331. convert = ring.ring.ground_new
  332. else:
  333. convert = ring.ground_new
  334. echelon = echelon.rep.to_ddm()
  335. vals_set = {v for row in echelon for v in row}
  336. vals_map = {v: convert(v) for v in vals_set}
  337. echelon = [[vals_map[eij] for eij in ei] for ei in echelon]
  338. for i, p in enumerate(pivots):
  339. v = echelon[i][-1] - sum(echelon[i][j]*g[j] for j in range(p+1, len(g)) if echelon[i][j])
  340. sols[keys[p]] = v
  341. return sols