sdm.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232
  1. """
  2. Module for the SDM class.
  3. """
  4. from operator import add, neg, pos, sub, mul
  5. from collections import defaultdict
  6. from sympy.utilities.iterables import _strongly_connected_components
  7. from .exceptions import DMBadInputError, DMDomainError, DMShapeError
  8. from .ddm import DDM
  9. class SDM(dict):
  10. r"""Sparse matrix based on polys domain elements
  11. This is a dict subclass and is a wrapper for a dict of dicts that supports
  12. basic matrix arithmetic +, -, *, **.
  13. In order to create a new :py:class:`~.SDM`, a dict
  14. of dicts mapping non-zero elements to their
  15. corresponding row and column in the matrix is needed.
  16. We also need to specify the shape and :py:class:`~.Domain`
  17. of our :py:class:`~.SDM` object.
  18. We declare a 2x2 :py:class:`~.SDM` matrix belonging
  19. to QQ domain as shown below.
  20. The 2x2 Matrix in the example is
  21. .. math::
  22. A = \left[\begin{array}{ccc}
  23. 0 & \frac{1}{2} \\
  24. 0 & 0 \end{array} \right]
  25. >>> from sympy.polys.matrices.sdm import SDM
  26. >>> from sympy import QQ
  27. >>> elemsdict = {0:{1:QQ(1, 2)}}
  28. >>> A = SDM(elemsdict, (2, 2), QQ)
  29. >>> A
  30. {0: {1: 1/2}}
  31. We can manipulate :py:class:`~.SDM` the same way
  32. as a Matrix class
  33. >>> from sympy import ZZ
  34. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  35. >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
  36. >>> A + B
  37. {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}}
  38. Multiplication
  39. >>> A*B
  40. {0: {1: 8}, 1: {0: 3}}
  41. >>> A*ZZ(2)
  42. {0: {1: 4}, 1: {0: 2}}
  43. """
  44. fmt = 'sparse'
  45. def __init__(self, elemsdict, shape, domain):
  46. super().__init__(elemsdict)
  47. self.shape = self.rows, self.cols = m, n = shape
  48. self.domain = domain
  49. if not all(0 <= r < m for r in self):
  50. raise DMBadInputError("Row out of range")
  51. if not all(0 <= c < n for row in self.values() for c in row):
  52. raise DMBadInputError("Column out of range")
  53. def getitem(self, i, j):
  54. try:
  55. return self[i][j]
  56. except KeyError:
  57. m, n = self.shape
  58. if -m <= i < m and -n <= j < n:
  59. try:
  60. return self[i % m][j % n]
  61. except KeyError:
  62. return self.domain.zero
  63. else:
  64. raise IndexError("index out of range")
  65. def setitem(self, i, j, value):
  66. m, n = self.shape
  67. if not (-m <= i < m and -n <= j < n):
  68. raise IndexError("index out of range")
  69. i, j = i % m, j % n
  70. if value:
  71. try:
  72. self[i][j] = value
  73. except KeyError:
  74. self[i] = {j: value}
  75. else:
  76. rowi = self.get(i, None)
  77. if rowi is not None:
  78. try:
  79. del rowi[j]
  80. except KeyError:
  81. pass
  82. else:
  83. if not rowi:
  84. del self[i]
  85. def extract_slice(self, slice1, slice2):
  86. m, n = self.shape
  87. ri = range(m)[slice1]
  88. ci = range(n)[slice2]
  89. sdm = {}
  90. for i, row in self.items():
  91. if i in ri:
  92. row = {ci.index(j): e for j, e in row.items() if j in ci}
  93. if row:
  94. sdm[ri.index(i)] = row
  95. return self.new(sdm, (len(ri), len(ci)), self.domain)
  96. def extract(self, rows, cols):
  97. if not (self and rows and cols):
  98. return self.zeros((len(rows), len(cols)), self.domain)
  99. m, n = self.shape
  100. if not (-m <= min(rows) <= max(rows) < m):
  101. raise IndexError('Row index out of range')
  102. if not (-n <= min(cols) <= max(cols) < n):
  103. raise IndexError('Column index out of range')
  104. # rows and cols can contain duplicates e.g. M[[1, 2, 2], [0, 1]]
  105. # Build a map from row/col in self to list of rows/cols in output
  106. rowmap = defaultdict(list)
  107. colmap = defaultdict(list)
  108. for i2, i1 in enumerate(rows):
  109. rowmap[i1 % m].append(i2)
  110. for j2, j1 in enumerate(cols):
  111. colmap[j1 % n].append(j2)
  112. # Used to efficiently skip zero rows/cols
  113. rowset = set(rowmap)
  114. colset = set(colmap)
  115. sdm1 = self
  116. sdm2 = {}
  117. for i1 in rowset & set(sdm1):
  118. row1 = sdm1[i1]
  119. row2 = {}
  120. for j1 in colset & set(row1):
  121. row1_j1 = row1[j1]
  122. for j2 in colmap[j1]:
  123. row2[j2] = row1_j1
  124. if row2:
  125. for i2 in rowmap[i1]:
  126. sdm2[i2] = row2.copy()
  127. return self.new(sdm2, (len(rows), len(cols)), self.domain)
  128. def __str__(self):
  129. rowsstr = []
  130. for i, row in self.items():
  131. elemsstr = ', '.join('%s: %s' % (j, elem) for j, elem in row.items())
  132. rowsstr.append('%s: {%s}' % (i, elemsstr))
  133. return '{%s}' % ', '.join(rowsstr)
  134. def __repr__(self):
  135. cls = type(self).__name__
  136. rows = dict.__repr__(self)
  137. return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain)
  138. @classmethod
  139. def new(cls, sdm, shape, domain):
  140. """
  141. Parameters
  142. ==========
  143. sdm: A dict of dicts for non-zero elements in SDM
  144. shape: tuple representing dimension of SDM
  145. domain: Represents :py:class:`~.Domain` of SDM
  146. Returns
  147. =======
  148. An :py:class:`~.SDM` object
  149. Examples
  150. ========
  151. >>> from sympy.polys.matrices.sdm import SDM
  152. >>> from sympy import QQ
  153. >>> elemsdict = {0:{1: QQ(2)}}
  154. >>> A = SDM.new(elemsdict, (2, 2), QQ)
  155. >>> A
  156. {0: {1: 2}}
  157. """
  158. return cls(sdm, shape, domain)
  159. def copy(A):
  160. """
  161. Returns the copy of a :py:class:`~.SDM` object
  162. Examples
  163. ========
  164. >>> from sympy.polys.matrices.sdm import SDM
  165. >>> from sympy import QQ
  166. >>> elemsdict = {0:{1:QQ(2)}, 1:{}}
  167. >>> A = SDM(elemsdict, (2, 2), QQ)
  168. >>> B = A.copy()
  169. >>> B
  170. {0: {1: 2}, 1: {}}
  171. """
  172. Ac = {i: Ai.copy() for i, Ai in A.items()}
  173. return A.new(Ac, A.shape, A.domain)
  174. @classmethod
  175. def from_list(cls, ddm, shape, domain):
  176. """
  177. Parameters
  178. ==========
  179. ddm:
  180. list of lists containing domain elements
  181. shape:
  182. Dimensions of :py:class:`~.SDM` matrix
  183. domain:
  184. Represents :py:class:`~.Domain` of :py:class:`~.SDM` object
  185. Returns
  186. =======
  187. :py:class:`~.SDM` containing elements of ddm
  188. Examples
  189. ========
  190. >>> from sympy.polys.matrices.sdm import SDM
  191. >>> from sympy import QQ
  192. >>> ddm = [[QQ(1, 2), QQ(0)], [QQ(0), QQ(3, 4)]]
  193. >>> A = SDM.from_list(ddm, (2, 2), QQ)
  194. >>> A
  195. {0: {0: 1/2}, 1: {1: 3/4}}
  196. """
  197. m, n = shape
  198. if not (len(ddm) == m and all(len(row) == n for row in ddm)):
  199. raise DMBadInputError("Inconsistent row-list/shape")
  200. getrow = lambda i: {j:ddm[i][j] for j in range(n) if ddm[i][j]}
  201. irows = ((i, getrow(i)) for i in range(m))
  202. sdm = {i: row for i, row in irows if row}
  203. return cls(sdm, shape, domain)
  204. @classmethod
  205. def from_ddm(cls, ddm):
  206. """
  207. converts object of :py:class:`~.DDM` to
  208. :py:class:`~.SDM`
  209. Examples
  210. ========
  211. >>> from sympy.polys.matrices.ddm import DDM
  212. >>> from sympy.polys.matrices.sdm import SDM
  213. >>> from sympy import QQ
  214. >>> ddm = DDM( [[QQ(1, 2), 0], [0, QQ(3, 4)]], (2, 2), QQ)
  215. >>> A = SDM.from_ddm(ddm)
  216. >>> A
  217. {0: {0: 1/2}, 1: {1: 3/4}}
  218. """
  219. return cls.from_list(ddm, ddm.shape, ddm.domain)
  220. def to_list(M):
  221. """
  222. Converts a :py:class:`~.SDM` object to a list
  223. Examples
  224. ========
  225. >>> from sympy.polys.matrices.sdm import SDM
  226. >>> from sympy import QQ
  227. >>> elemsdict = {0:{1:QQ(2)}, 1:{}}
  228. >>> A = SDM(elemsdict, (2, 2), QQ)
  229. >>> A.to_list()
  230. [[0, 2], [0, 0]]
  231. """
  232. m, n = M.shape
  233. zero = M.domain.zero
  234. ddm = [[zero] * n for _ in range(m)]
  235. for i, row in M.items():
  236. for j, e in row.items():
  237. ddm[i][j] = e
  238. return ddm
  239. def to_list_flat(M):
  240. m, n = M.shape
  241. zero = M.domain.zero
  242. flat = [zero] * (m * n)
  243. for i, row in M.items():
  244. for j, e in row.items():
  245. flat[i*n + j] = e
  246. return flat
  247. def to_dok(M):
  248. return {(i, j): e for i, row in M.items() for j, e in row.items()}
  249. def to_ddm(M):
  250. """
  251. Convert a :py:class:`~.SDM` object to a :py:class:`~.DDM` object
  252. Examples
  253. ========
  254. >>> from sympy.polys.matrices.sdm import SDM
  255. >>> from sympy import QQ
  256. >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ)
  257. >>> A.to_ddm()
  258. [[0, 2], [0, 0]]
  259. """
  260. return DDM(M.to_list(), M.shape, M.domain)
  261. def to_sdm(M):
  262. return M
  263. @classmethod
  264. def zeros(cls, shape, domain):
  265. r"""
  266. Returns a :py:class:`~.SDM` of size shape,
  267. belonging to the specified domain
  268. In the example below we declare a matrix A where,
  269. .. math::
  270. A := \left[\begin{array}{ccc}
  271. 0 & 0 & 0 \\
  272. 0 & 0 & 0 \end{array} \right]
  273. >>> from sympy.polys.matrices.sdm import SDM
  274. >>> from sympy import QQ
  275. >>> A = SDM.zeros((2, 3), QQ)
  276. >>> A
  277. {}
  278. """
  279. return cls({}, shape, domain)
  280. @classmethod
  281. def ones(cls, shape, domain):
  282. one = domain.one
  283. m, n = shape
  284. row = dict(zip(range(n), [one]*n))
  285. sdm = {i: row.copy() for i in range(m)}
  286. return cls(sdm, shape, domain)
  287. @classmethod
  288. def eye(cls, shape, domain):
  289. """
  290. Returns a identity :py:class:`~.SDM` matrix of dimensions
  291. size x size, belonging to the specified domain
  292. Examples
  293. ========
  294. >>> from sympy.polys.matrices.sdm import SDM
  295. >>> from sympy import QQ
  296. >>> I = SDM.eye((2, 2), QQ)
  297. >>> I
  298. {0: {0: 1}, 1: {1: 1}}
  299. """
  300. rows, cols = shape
  301. one = domain.one
  302. sdm = {i: {i: one} for i in range(min(rows, cols))}
  303. return cls(sdm, shape, domain)
  304. @classmethod
  305. def diag(cls, diagonal, domain, shape):
  306. sdm = {i: {i: v} for i, v in enumerate(diagonal) if v}
  307. return cls(sdm, shape, domain)
  308. def transpose(M):
  309. """
  310. Returns the transpose of a :py:class:`~.SDM` matrix
  311. Examples
  312. ========
  313. >>> from sympy.polys.matrices.sdm import SDM
  314. >>> from sympy import QQ
  315. >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ)
  316. >>> A.transpose()
  317. {1: {0: 2}}
  318. """
  319. MT = sdm_transpose(M)
  320. return M.new(MT, M.shape[::-1], M.domain)
  321. def __add__(A, B):
  322. if not isinstance(B, SDM):
  323. return NotImplemented
  324. return A.add(B)
  325. def __sub__(A, B):
  326. if not isinstance(B, SDM):
  327. return NotImplemented
  328. return A.sub(B)
  329. def __neg__(A):
  330. return A.neg()
  331. def __mul__(A, B):
  332. """A * B"""
  333. if isinstance(B, SDM):
  334. return A.matmul(B)
  335. elif B in A.domain:
  336. return A.mul(B)
  337. else:
  338. return NotImplemented
  339. def __rmul__(a, b):
  340. if b in a.domain:
  341. return a.rmul(b)
  342. else:
  343. return NotImplemented
  344. def matmul(A, B):
  345. """
  346. Performs matrix multiplication of two SDM matrices
  347. Parameters
  348. ==========
  349. A, B: SDM to multiply
  350. Returns
  351. =======
  352. SDM
  353. SDM after multiplication
  354. Raises
  355. ======
  356. DomainError
  357. If domain of A does not match
  358. with that of B
  359. Examples
  360. ========
  361. >>> from sympy import ZZ
  362. >>> from sympy.polys.matrices.sdm import SDM
  363. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  364. >>> B = SDM({0:{0:ZZ(2), 1:ZZ(3)}, 1:{0:ZZ(4)}}, (2, 2), ZZ)
  365. >>> A.matmul(B)
  366. {0: {0: 8}, 1: {0: 2, 1: 3}}
  367. """
  368. if A.domain != B.domain:
  369. raise DMDomainError
  370. m, n = A.shape
  371. n2, o = B.shape
  372. if n != n2:
  373. raise DMShapeError
  374. C = sdm_matmul(A, B, A.domain, m, o)
  375. return A.new(C, (m, o), A.domain)
  376. def mul(A, b):
  377. """
  378. Multiplies each element of A with a scalar b
  379. Examples
  380. ========
  381. >>> from sympy import ZZ
  382. >>> from sympy.polys.matrices.sdm import SDM
  383. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  384. >>> A.mul(ZZ(3))
  385. {0: {1: 6}, 1: {0: 3}}
  386. """
  387. Csdm = unop_dict(A, lambda aij: aij*b)
  388. return A.new(Csdm, A.shape, A.domain)
  389. def rmul(A, b):
  390. Csdm = unop_dict(A, lambda aij: b*aij)
  391. return A.new(Csdm, A.shape, A.domain)
  392. def mul_elementwise(A, B):
  393. if A.domain != B.domain:
  394. raise DMDomainError
  395. if A.shape != B.shape:
  396. raise DMShapeError
  397. zero = A.domain.zero
  398. fzero = lambda e: zero
  399. Csdm = binop_dict(A, B, mul, fzero, fzero)
  400. return A.new(Csdm, A.shape, A.domain)
  401. def add(A, B):
  402. """
  403. Adds two :py:class:`~.SDM` matrices
  404. Examples
  405. ========
  406. >>> from sympy import ZZ
  407. >>> from sympy.polys.matrices.sdm import SDM
  408. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  409. >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
  410. >>> A.add(B)
  411. {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}}
  412. """
  413. Csdm = binop_dict(A, B, add, pos, pos)
  414. return A.new(Csdm, A.shape, A.domain)
  415. def sub(A, B):
  416. """
  417. Subtracts two :py:class:`~.SDM` matrices
  418. Examples
  419. ========
  420. >>> from sympy import ZZ
  421. >>> from sympy.polys.matrices.sdm import SDM
  422. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  423. >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
  424. >>> A.sub(B)
  425. {0: {0: -3, 1: 2}, 1: {0: 1, 1: -4}}
  426. """
  427. Csdm = binop_dict(A, B, sub, pos, neg)
  428. return A.new(Csdm, A.shape, A.domain)
  429. def neg(A):
  430. """
  431. Returns the negative of a :py:class:`~.SDM` matrix
  432. Examples
  433. ========
  434. >>> from sympy import ZZ
  435. >>> from sympy.polys.matrices.sdm import SDM
  436. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  437. >>> A.neg()
  438. {0: {1: -2}, 1: {0: -1}}
  439. """
  440. Csdm = unop_dict(A, neg)
  441. return A.new(Csdm, A.shape, A.domain)
  442. def convert_to(A, K):
  443. """
  444. Converts the :py:class:`~.Domain` of a :py:class:`~.SDM` matrix to K
  445. Examples
  446. ========
  447. >>> from sympy import ZZ, QQ
  448. >>> from sympy.polys.matrices.sdm import SDM
  449. >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
  450. >>> A.convert_to(QQ)
  451. {0: {1: 2}, 1: {0: 1}}
  452. """
  453. Kold = A.domain
  454. if K == Kold:
  455. return A.copy()
  456. Ak = unop_dict(A, lambda e: K.convert_from(e, Kold))
  457. return A.new(Ak, A.shape, K)
  458. def scc(A):
  459. """Strongly connected components of a square matrix *A*.
  460. Examples
  461. ========
  462. >>> from sympy import ZZ
  463. >>> from sympy.polys.matrices.sdm import SDM
  464. >>> A = SDM({0:{0: ZZ(2)}, 1:{1:ZZ(1)}}, (2, 2), ZZ)
  465. >>> A.scc()
  466. [[0], [1]]
  467. See also
  468. ========
  469. sympy.polys.matrices.domainmatrix.DomainMatrix.scc
  470. """
  471. rows, cols = A.shape
  472. assert rows == cols
  473. V = range(rows)
  474. Emap = {v: list(A.get(v, [])) for v in V}
  475. return _strongly_connected_components(V, Emap)
  476. def rref(A):
  477. """
  478. Returns reduced-row echelon form and list of pivots for the :py:class:`~.SDM`
  479. Examples
  480. ========
  481. >>> from sympy import QQ
  482. >>> from sympy.polys.matrices.sdm import SDM
  483. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ)
  484. >>> A.rref()
  485. ({0: {0: 1, 1: 2}}, [0])
  486. """
  487. B, pivots, _ = sdm_irref(A)
  488. return A.new(B, A.shape, A.domain), pivots
  489. def inv(A):
  490. """
  491. Returns inverse of a matrix A
  492. Examples
  493. ========
  494. >>> from sympy import QQ
  495. >>> from sympy.polys.matrices.sdm import SDM
  496. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
  497. >>> A.inv()
  498. {0: {0: -2, 1: 1}, 1: {0: 3/2, 1: -1/2}}
  499. """
  500. return A.from_ddm(A.to_ddm().inv())
  501. def det(A):
  502. """
  503. Returns determinant of A
  504. Examples
  505. ========
  506. >>> from sympy import QQ
  507. >>> from sympy.polys.matrices.sdm import SDM
  508. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
  509. >>> A.det()
  510. -2
  511. """
  512. return A.to_ddm().det()
  513. def lu(A):
  514. """
  515. Returns LU decomposition for a matrix A
  516. Examples
  517. ========
  518. >>> from sympy import QQ
  519. >>> from sympy.polys.matrices.sdm import SDM
  520. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
  521. >>> A.lu()
  522. ({0: {0: 1}, 1: {0: 3, 1: 1}}, {0: {0: 1, 1: 2}, 1: {1: -2}}, [])
  523. """
  524. L, U, swaps = A.to_ddm().lu()
  525. return A.from_ddm(L), A.from_ddm(U), swaps
  526. def lu_solve(A, b):
  527. """
  528. Uses LU decomposition to solve Ax = b,
  529. Examples
  530. ========
  531. >>> from sympy import QQ
  532. >>> from sympy.polys.matrices.sdm import SDM
  533. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
  534. >>> b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ)
  535. >>> A.lu_solve(b)
  536. {1: {0: 1/2}}
  537. """
  538. return A.from_ddm(A.to_ddm().lu_solve(b.to_ddm()))
  539. def nullspace(A):
  540. """
  541. Returns nullspace for a :py:class:`~.SDM` matrix A
  542. Examples
  543. ========
  544. >>> from sympy import QQ
  545. >>> from sympy.polys.matrices.sdm import SDM
  546. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ)
  547. >>> A.nullspace()
  548. ({0: {0: -2, 1: 1}}, [1])
  549. """
  550. ncols = A.shape[1]
  551. one = A.domain.one
  552. B, pivots, nzcols = sdm_irref(A)
  553. K, nonpivots = sdm_nullspace_from_rref(B, one, ncols, pivots, nzcols)
  554. K = dict(enumerate(K))
  555. shape = (len(K), ncols)
  556. return A.new(K, shape, A.domain), nonpivots
  557. def particular(A):
  558. ncols = A.shape[1]
  559. B, pivots, nzcols = sdm_irref(A)
  560. P = sdm_particular_from_rref(B, ncols, pivots)
  561. rep = {0:P} if P else {}
  562. return A.new(rep, (1, ncols-1), A.domain)
  563. def hstack(A, *B):
  564. """Horizontally stacks :py:class:`~.SDM` matrices.
  565. Examples
  566. ========
  567. >>> from sympy import ZZ
  568. >>> from sympy.polys.matrices.sdm import SDM
  569. >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ)
  570. >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ)
  571. >>> A.hstack(B)
  572. {0: {0: 1, 1: 2, 2: 5, 3: 6}, 1: {0: 3, 1: 4, 2: 7, 3: 8}}
  573. >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ)
  574. >>> A.hstack(B, C)
  575. {0: {0: 1, 1: 2, 2: 5, 3: 6, 4: 9, 5: 10}, 1: {0: 3, 1: 4, 2: 7, 3: 8, 4: 11, 5: 12}}
  576. """
  577. Anew = dict(A.copy())
  578. rows, cols = A.shape
  579. domain = A.domain
  580. for Bk in B:
  581. Bkrows, Bkcols = Bk.shape
  582. assert Bkrows == rows
  583. assert Bk.domain == domain
  584. for i, Bki in Bk.items():
  585. Ai = Anew.get(i, None)
  586. if Ai is None:
  587. Anew[i] = Ai = {}
  588. for j, Bkij in Bki.items():
  589. Ai[j + cols] = Bkij
  590. cols += Bkcols
  591. return A.new(Anew, (rows, cols), A.domain)
  592. def vstack(A, *B):
  593. """Vertically stacks :py:class:`~.SDM` matrices.
  594. Examples
  595. ========
  596. >>> from sympy import ZZ
  597. >>> from sympy.polys.matrices.sdm import SDM
  598. >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ)
  599. >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ)
  600. >>> A.vstack(B)
  601. {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}}
  602. >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ)
  603. >>> A.vstack(B, C)
  604. {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}, 4: {0: 9, 1: 10}, 5: {0: 11, 1: 12}}
  605. """
  606. Anew = dict(A.copy())
  607. rows, cols = A.shape
  608. domain = A.domain
  609. for Bk in B:
  610. Bkrows, Bkcols = Bk.shape
  611. assert Bkcols == cols
  612. assert Bk.domain == domain
  613. for i, Bki in Bk.items():
  614. Anew[i + rows] = Bki
  615. rows += Bkrows
  616. return A.new(Anew, (rows, cols), A.domain)
  617. def applyfunc(self, func, domain):
  618. sdm = {i: {j: func(e) for j, e in row.items()} for i, row in self.items()}
  619. return self.new(sdm, self.shape, domain)
  620. def charpoly(A):
  621. """
  622. Returns the coefficients of the characteristic polynomial
  623. of the :py:class:`~.SDM` matrix. These elements will be domain elements.
  624. The domain of the elements will be same as domain of the :py:class:`~.SDM`.
  625. Examples
  626. ========
  627. >>> from sympy import QQ, Symbol
  628. >>> from sympy.polys.matrices.sdm import SDM
  629. >>> from sympy.polys import Poly
  630. >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
  631. >>> A.charpoly()
  632. [1, -5, -2]
  633. We can create a polynomial using the
  634. coefficients using :py:class:`~.Poly`
  635. >>> x = Symbol('x')
  636. >>> p = Poly(A.charpoly(), x, domain=A.domain)
  637. >>> p
  638. Poly(x**2 - 5*x - 2, x, domain='QQ')
  639. """
  640. return A.to_ddm().charpoly()
  641. def is_zero_matrix(self):
  642. """
  643. Says whether this matrix has all zero entries.
  644. """
  645. return not self
  646. def is_upper(self):
  647. """
  648. Says whether this matrix is upper-triangular. True can be returned
  649. even if the matrix is not square.
  650. """
  651. return all(i <= j for i, row in self.items() for j in row)
  652. def is_lower(self):
  653. """
  654. Says whether this matrix is lower-triangular. True can be returned
  655. even if the matrix is not square.
  656. """
  657. return all(i >= j for i, row in self.items() for j in row)
  658. def binop_dict(A, B, fab, fa, fb):
  659. Anz, Bnz = set(A), set(B)
  660. C = {}
  661. for i in Anz & Bnz:
  662. Ai, Bi = A[i], B[i]
  663. Ci = {}
  664. Anzi, Bnzi = set(Ai), set(Bi)
  665. for j in Anzi & Bnzi:
  666. Cij = fab(Ai[j], Bi[j])
  667. if Cij:
  668. Ci[j] = Cij
  669. for j in Anzi - Bnzi:
  670. Cij = fa(Ai[j])
  671. if Cij:
  672. Ci[j] = Cij
  673. for j in Bnzi - Anzi:
  674. Cij = fb(Bi[j])
  675. if Cij:
  676. Ci[j] = Cij
  677. if Ci:
  678. C[i] = Ci
  679. for i in Anz - Bnz:
  680. Ai = A[i]
  681. Ci = {}
  682. for j, Aij in Ai.items():
  683. Cij = fa(Aij)
  684. if Cij:
  685. Ci[j] = Cij
  686. if Ci:
  687. C[i] = Ci
  688. for i in Bnz - Anz:
  689. Bi = B[i]
  690. Ci = {}
  691. for j, Bij in Bi.items():
  692. Cij = fb(Bij)
  693. if Cij:
  694. Ci[j] = Cij
  695. if Ci:
  696. C[i] = Ci
  697. return C
  698. def unop_dict(A, f):
  699. B = {}
  700. for i, Ai in A.items():
  701. Bi = {}
  702. for j, Aij in Ai.items():
  703. Bij = f(Aij)
  704. if Bij:
  705. Bi[j] = Bij
  706. if Bi:
  707. B[i] = Bi
  708. return B
  709. def sdm_transpose(M):
  710. MT = {}
  711. for i, Mi in M.items():
  712. for j, Mij in Mi.items():
  713. try:
  714. MT[j][i] = Mij
  715. except KeyError:
  716. MT[j] = {i: Mij}
  717. return MT
  718. def sdm_matmul(A, B, K, m, o):
  719. #
  720. # Should be fast if A and B are very sparse.
  721. # Consider e.g. A = B = eye(1000).
  722. #
  723. # The idea here is that we compute C = A*B in terms of the rows of C and
  724. # B since the dict of dicts representation naturally stores the matrix as
  725. # rows. The ith row of C (Ci) is equal to the sum of Aik * Bk where Bk is
  726. # the kth row of B. The algorithm below loops over each nonzero element
  727. # Aik of A and if the corresponding row Bj is nonzero then we do
  728. # Ci += Aik * Bk.
  729. # To make this more efficient we don't need to loop over all elements Aik.
  730. # Instead for each row Ai we compute the intersection of the nonzero
  731. # columns in Ai with the nonzero rows in B. That gives the k such that
  732. # Aik and Bk are both nonzero. In Python the intersection of two sets
  733. # of int can be computed very efficiently.
  734. #
  735. if K.is_EXRAW:
  736. return sdm_matmul_exraw(A, B, K, m, o)
  737. C = {}
  738. B_knz = set(B)
  739. for i, Ai in A.items():
  740. Ci = {}
  741. Ai_knz = set(Ai)
  742. for k in Ai_knz & B_knz:
  743. Aik = Ai[k]
  744. for j, Bkj in B[k].items():
  745. Cij = Ci.get(j, None)
  746. if Cij is not None:
  747. Cij = Cij + Aik * Bkj
  748. if Cij:
  749. Ci[j] = Cij
  750. else:
  751. Ci.pop(j)
  752. else:
  753. Cij = Aik * Bkj
  754. if Cij:
  755. Ci[j] = Cij
  756. if Ci:
  757. C[i] = Ci
  758. return C
def sdm_matmul_exraw(A, B, K, m, o):
    """Multiply dict-of-dicts matrices *A* and *B* over the domain *K*.

    This variant of :func:`sdm_matmul` is used when ``K.is_EXRAW``. Unlike the
    generic routine, it does not assume that multiplying by an implicit
    (unstored) zero gives zero. Any element ``x`` with ``zero * x != zero``
    (e.g. ``oo``, since ``0*oo`` is ``nan``) is also multiplied against the
    implicit zeros of the other operand.

    Parameters
    ==========

    A, B : dict of dicts
        Row index -> {column index -> nonzero element}.
    K : Domain
        Domain of the elements; ``K.zero`` and ``K.sum`` are used.
    m : int
        Number of rows of the product (the rows of *A*).
    o : int
        Number of columns of the product (the columns of *B*).

    Returns
    =======

    dict of dicts
        The product ``A*B``. Entries whose value is zero are not stored.

    Raises
    ======

    RuntimeError
        Defensive only: if a term involving an infinity in *B* sums to zero.
    """
    #
    # Like sdm_matmul above except that:
    #
    # - Handles cases like 0*oo -> nan (sdm_matmul skips multiplication by zero)
    # - Uses K.sum (Add(*items)) for efficient addition of Expr
    #
    zero = K.zero
    C = {}
    B_knz = set(B)
    for i, Ai in A.items():
        # Collect all terms for each output column, then add them once.
        Ci_list = defaultdict(list)
        Ai_knz = set(Ai)

        # Nonzero row/column pair
        for k in Ai_knz & B_knz:
            Aik = Ai[k]
            if zero * Aik == zero:
                # This is the main inner loop:
                for j, Bkj in B[k].items():
                    Ci_list[j].append(Aik * Bkj)
            else:
                # Aik behaves like an infinity: it must also multiply the
                # implicit zeros in row k of B, so visit every column.
                for j in range(o):
                    Ci_list[j].append(Aik * B[k].get(j, zero))

        # Zero row in B, check for infinities in A
        for k in Ai_knz - B_knz:
            zAik = zero * Ai[k]
            if zAik != zero:
                for j in range(o):
                    Ci_list[j].append(zAik)

        # Add terms using K.sum (Add(*terms)) for efficiency
        Ci = {}
        for j, Cij_list in Ci_list.items():
            Cij = K.sum(Cij_list)
            if Cij:
                Ci[j] = Cij
        if Ci:
            C[i] = Ci

    # Find all infinities in B
    for k, Bk in B.items():
        for j, Bkj in Bk.items():
            if zero * Bkj != zero:
                for i in range(m):
                    Aik = A.get(i, {}).get(k, zero)
                    # If Aik is not zero then this was handled above
                    if Aik == zero:
                        Ci = C.get(i, {})
                        Cij = Ci.get(j, zero) + Aik * Bkj
                        if Cij != zero:
                            Ci[j] = Cij
                        else:  # pragma: no cover
                            # Not sure how we could get here but let's raise an
                            # exception just in case.
                            raise RuntimeError
                        C[i] = Ci

    return C
  814. def sdm_irref(A):
  815. """RREF and pivots of a sparse matrix *A*.
  816. Compute the reduced row echelon form (RREF) of the matrix *A* and return a
  817. list of the pivot columns. This routine does not work in place and leaves
  818. the original matrix *A* unmodified.
  819. Examples
  820. ========
  821. This routine works with a dict of dicts sparse representation of a matrix:
  822. >>> from sympy import QQ
  823. >>> from sympy.polys.matrices.sdm import sdm_irref
  824. >>> A = {0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}
  825. >>> Arref, pivots, _ = sdm_irref(A)
  826. >>> Arref
  827. {0: {0: 1}, 1: {1: 1}}
  828. >>> pivots
  829. [0, 1]
  830. The analogous calculation with :py:class:`~.Matrix` would be
  831. >>> from sympy import Matrix
  832. >>> M = Matrix([[1, 2], [3, 4]])
  833. >>> Mrref, pivots = M.rref()
  834. >>> Mrref
  835. Matrix([
  836. [1, 0],
  837. [0, 1]])
  838. >>> pivots
  839. (0, 1)
  840. Notes
  841. =====
  842. The cost of this algorithm is determined purely by the nonzero elements of
  843. the matrix. No part of the cost of any step in this algorithm depends on
  844. the number of rows or columns in the matrix. No step depends even on the
  845. number of nonzero rows apart from the primary loop over those rows. The
  846. implementation is much faster than ddm_rref for sparse matrices. In fact
  847. at the time of writing it is also (slightly) faster than the dense
  848. implementation even if the input is a fully dense matrix so it seems to be
  849. faster in all cases.
  850. The elements of the matrix should support exact division with ``/``. For
  851. example elements of any domain that is a field (e.g. ``QQ``) should be
  852. fine. No attempt is made to handle inexact arithmetic.
  853. """
  854. #
  855. # Any zeros in the matrix are not stored at all so an element is zero if
  856. # its row dict has no index at that key. A row is entirely zero if its
  857. # row index is not in the outer dict. Since rref reorders the rows and
  858. # removes zero rows we can completely discard the row indices. The first
  859. # step then copies the row dicts into a list sorted by the index of the
  860. # first nonzero column in each row.
  861. #
  862. # The algorithm then processes each row Ai one at a time. Previously seen
  863. # rows are used to cancel their pivot columns from Ai. Then a pivot from
  864. # Ai is chosen and is cancelled from all previously seen rows. At this
  865. # point Ai joins the previously seen rows. Once all rows are seen all
  866. # elimination has occurred and the rows are sorted by pivot column index.
  867. #
  868. # The previously seen rows are stored in two separate groups. The reduced
  869. # group consists of all rows that have been reduced to a single nonzero
  870. # element (the pivot). There is no need to attempt any further reduction
  871. # with these. Rows that still have other nonzeros need to be considered
  872. # when Ai is cancelled from the previously seen rows.
  873. #
  874. # A dict nonzerocolumns is used to map from a column index to a set of
  875. # previously seen rows that still have a nonzero element in that column.
  876. # This means that we can cancel the pivot from Ai into the previously seen
  877. # rows without needing to loop over each row that might have a zero in
  878. # that column.
  879. #
  880. # Row dicts sorted by index of first nonzero column
  881. # (Maybe sorting is not needed/useful.)
  882. Arows = sorted((Ai.copy() for Ai in A.values()), key=min)
  883. # Each processed row has an associated pivot column.
  884. # pivot_row_map maps from the pivot column index to the row dict.
  885. # This means that we can represent a set of rows purely as a set of their
  886. # pivot indices.
  887. pivot_row_map = {}
  888. # Set of pivot indices for rows that are fully reduced to a single nonzero.
  889. reduced_pivots = set()
  890. # Set of pivot indices for rows not fully reduced
  891. nonreduced_pivots = set()
  892. # Map from column index to a set of pivot indices representing the rows
  893. # that have a nonzero at that column.
  894. nonzero_columns = defaultdict(set)
  895. while Arows:
  896. # Select pivot element and row
  897. Ai = Arows.pop()
  898. # Nonzero columns from fully reduced pivot rows can be removed
  899. Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced_pivots}
  900. # Others require full row cancellation
  901. for j in nonreduced_pivots & set(Ai):
  902. Aj = pivot_row_map[j]
  903. Aij = Ai[j]
  904. Ainz = set(Ai)
  905. Ajnz = set(Aj)
  906. for k in Ajnz - Ainz:
  907. Ai[k] = - Aij * Aj[k]
  908. Ai.pop(j)
  909. Ainz.remove(j)
  910. for k in Ajnz & Ainz:
  911. Aik = Ai[k] - Aij * Aj[k]
  912. if Aik:
  913. Ai[k] = Aik
  914. else:
  915. Ai.pop(k)
  916. # We have now cancelled previously seen pivots from Ai.
  917. # If it is zero then discard it.
  918. if not Ai:
  919. continue
  920. # Choose a pivot from Ai:
  921. j = min(Ai)
  922. Aij = Ai[j]
  923. pivot_row_map[j] = Ai
  924. Ainz = set(Ai)
  925. # Normalise the pivot row to make the pivot 1.
  926. #
  927. # This approach is slow for some domains. Cross cancellation might be
  928. # better for e.g. QQ(x) with division delayed to the final steps.
  929. Aijinv = Aij**-1
  930. for l in Ai:
  931. Ai[l] *= Aijinv
  932. # Use Aij to cancel column j from all previously seen rows
  933. for k in nonzero_columns.pop(j, ()):
  934. Ak = pivot_row_map[k]
  935. Akj = Ak[j]
  936. Aknz = set(Ak)
  937. for l in Ainz - Aknz:
  938. Ak[l] = - Akj * Ai[l]
  939. nonzero_columns[l].add(k)
  940. Ak.pop(j)
  941. Aknz.remove(j)
  942. for l in Ainz & Aknz:
  943. Akl = Ak[l] - Akj * Ai[l]
  944. if Akl:
  945. Ak[l] = Akl
  946. else:
  947. # Drop nonzero elements
  948. Ak.pop(l)
  949. if l != j:
  950. nonzero_columns[l].remove(k)
  951. if len(Ak) == 1:
  952. reduced_pivots.add(k)
  953. nonreduced_pivots.remove(k)
  954. if len(Ai) == 1:
  955. reduced_pivots.add(j)
  956. else:
  957. nonreduced_pivots.add(j)
  958. for l in Ai:
  959. if l != j:
  960. nonzero_columns[l].add(j)
  961. # All done!
  962. pivots = sorted(reduced_pivots | nonreduced_pivots)
  963. pivot2row = {p: n for n, p in enumerate(pivots)}
  964. nonzero_columns = {c: set(pivot2row[p] for p in s) for c, s in nonzero_columns.items()}
  965. rows = [pivot_row_map[i] for i in pivots]
  966. rref = dict(enumerate(rows))
  967. return rref, pivots, nonzero_columns
  968. def sdm_nullspace_from_rref(A, one, ncols, pivots, nonzero_cols):
  969. """Get nullspace from A which is in RREF"""
  970. nonpivots = sorted(set(range(ncols)) - set(pivots))
  971. K = []
  972. for j in nonpivots:
  973. Kj = {j:one}
  974. for i in nonzero_cols.get(j, ()):
  975. Kj[pivots[i]] = -A[i][j]
  976. K.append(Kj)
  977. return K, nonpivots
  978. def sdm_particular_from_rref(A, ncols, pivots):
  979. """Get a particular solution from A which is in RREF"""
  980. P = {}
  981. for i, j in enumerate(pivots):
  982. Ain = A[i].get(ncols-1, None)
  983. if Ain is not None:
  984. P[j] = Ain / A[i][j]
  985. return P