ddm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
"""
Module for the DDM class.

The DDM class is an internal representation used by DomainMatrix. The letters
DDM stand for Dense Domain Matrix. A DDM instance represents a matrix using
elements from a polynomial Domain (e.g. ZZ, QQ, ...) in a dense-matrix
representation.

Basic usage:

    >>> from sympy import ZZ, QQ
    >>> from sympy.polys.matrices.ddm import DDM
    >>> A = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ)
    >>> A.shape
    (2, 2)
    >>> A
    [[0, 1], [-1, 0]]
    >>> type(A)
    <class 'sympy.polys.matrices.ddm.DDM'>
    >>> A @ A
    [[-1, 0], [0, -1]]

The ddm_* functions are designed to operate on DDM as well as on an ordinary
list of lists:

    >>> from sympy.polys.matrices.dense import ddm_idet
    >>> ddm_idet(A, QQ)
    1
    >>> ddm_idet([[0, 1], [-1, 0]], QQ)
    1
    >>> A
    [[-1, 0], [0, -1]]

Note that ddm_idet modifies the input matrix in-place. It is recommended to
use the DDM.det method as a friendlier interface to this instead which takes
care of copying the matrix:

    >>> B = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ)
    >>> B.det()
    1

Normally DDM would not be used directly and is just part of the internal
representation of DomainMatrix which adds further functionality including e.g.
unifying domains.

The dense format used by DDM is a list of lists of elements e.g. the 2x2
identity matrix is like [[1, 0], [0, 1]]. The DDM class itself is a subclass
of list and its list items are plain lists. Elements are accessed as e.g.
ddm[i][j] where ddm[i] gives the ith row and ddm[i][j] gets the element in the
jth column of that row. Subclassing list makes e.g. iteration and indexing
very efficient. We do not override __getitem__ because it would lose that
benefit.

The core routines are implemented by the ddm_* functions defined in dense.py.
Those functions are intended to be able to operate on a raw list-of-lists
representation of matrices with most functions operating in-place. The DDM
class takes care of copying etc. and also stores a Domain object associated
with its elements. This makes it possible to implement things like A + B with
domain checking and also shape checking so that the list of lists
representation is friendlier.
"""
  52. from itertools import chain
  53. from .exceptions import DMBadInputError, DMShapeError, DMDomainError
  54. from .dense import (
  55. ddm_transpose,
  56. ddm_iadd,
  57. ddm_isub,
  58. ddm_ineg,
  59. ddm_imul,
  60. ddm_irmul,
  61. ddm_imatmul,
  62. ddm_irref,
  63. ddm_idet,
  64. ddm_iinv,
  65. ddm_ilu_split,
  66. ddm_ilu_solve,
  67. ddm_berk,
  68. )
  69. class DDM(list):
  70. """Dense matrix based on polys domain elements
  71. This is a list subclass and is a wrapper for a list of lists that supports
  72. basic matrix arithmetic +, -, *, **.
  73. """
  74. fmt = 'dense'
  75. def __init__(self, rowslist, shape, domain):
  76. super().__init__(rowslist)
  77. self.shape = self.rows, self.cols = m, n = shape
  78. self.domain = domain
  79. if not (len(self) == m and all(len(row) == n for row in self)):
  80. raise DMBadInputError("Inconsistent row-list/shape")
  81. def getitem(self, i, j):
  82. return self[i][j]
  83. def setitem(self, i, j, value):
  84. self[i][j] = value
  85. def extract_slice(self, slice1, slice2):
  86. ddm = [row[slice2] for row in self[slice1]]
  87. rows = len(ddm)
  88. cols = len(ddm[0]) if ddm else len(range(self.shape[1])[slice2])
  89. return DDM(ddm, (rows, cols), self.domain)
  90. def extract(self, rows, cols):
  91. ddm = []
  92. for i in rows:
  93. rowi = self[i]
  94. ddm.append([rowi[j] for j in cols])
  95. return DDM(ddm, (len(rows), len(cols)), self.domain)
  96. def to_list(self):
  97. return list(self)
  98. def to_list_flat(self):
  99. flat = []
  100. for row in self:
  101. flat.extend(row)
  102. return flat
  103. def flatiter(self):
  104. return chain.from_iterable(self)
  105. def flat(self):
  106. items = []
  107. for row in self:
  108. items.extend(row)
  109. return items
  110. def to_dok(self):
  111. return {(i, j): e for i, row in enumerate(self) for j, e in enumerate(row)}
  112. def to_ddm(self):
  113. return self
  114. def to_sdm(self):
  115. return SDM.from_list(self, self.shape, self.domain)
  116. def convert_to(self, K):
  117. Kold = self.domain
  118. if K == Kold:
  119. return self.copy()
  120. rows = ([K.convert_from(e, Kold) for e in row] for row in self)
  121. return DDM(rows, self.shape, K)
  122. def __str__(self):
  123. rowsstr = ['[%s]' % ', '.join(map(str, row)) for row in self]
  124. return '[%s]' % ', '.join(rowsstr)
  125. def __repr__(self):
  126. cls = type(self).__name__
  127. rows = list.__repr__(self)
  128. return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain)
  129. def __eq__(self, other):
  130. if not isinstance(other, DDM):
  131. return False
  132. return (super().__eq__(other) and self.domain == other.domain)
  133. def __ne__(self, other):
  134. return not self.__eq__(other)
  135. @classmethod
  136. def zeros(cls, shape, domain):
  137. z = domain.zero
  138. m, n = shape
  139. rowslist = ([z] * n for _ in range(m))
  140. return DDM(rowslist, shape, domain)
  141. @classmethod
  142. def ones(cls, shape, domain):
  143. one = domain.one
  144. m, n = shape
  145. rowlist = ([one] * n for _ in range(m))
  146. return DDM(rowlist, shape, domain)
  147. @classmethod
  148. def eye(cls, size, domain):
  149. one = domain.one
  150. ddm = cls.zeros((size, size), domain)
  151. for i in range(size):
  152. ddm[i][i] = one
  153. return ddm
  154. def copy(self):
  155. copyrows = (row[:] for row in self)
  156. return DDM(copyrows, self.shape, self.domain)
  157. def transpose(self):
  158. rows, cols = self.shape
  159. if rows:
  160. ddmT = ddm_transpose(self)
  161. else:
  162. ddmT = [[]] * cols
  163. return DDM(ddmT, (cols, rows), self.domain)
  164. def __add__(a, b):
  165. if not isinstance(b, DDM):
  166. return NotImplemented
  167. return a.add(b)
  168. def __sub__(a, b):
  169. if not isinstance(b, DDM):
  170. return NotImplemented
  171. return a.sub(b)
  172. def __neg__(a):
  173. return a.neg()
  174. def __mul__(a, b):
  175. if b in a.domain:
  176. return a.mul(b)
  177. else:
  178. return NotImplemented
  179. def __rmul__(a, b):
  180. if b in a.domain:
  181. return a.mul(b)
  182. else:
  183. return NotImplemented
  184. def __matmul__(a, b):
  185. if isinstance(b, DDM):
  186. return a.matmul(b)
  187. else:
  188. return NotImplemented
  189. @classmethod
  190. def _check(cls, a, op, b, ashape, bshape):
  191. if a.domain != b.domain:
  192. msg = "Domain mismatch: %s %s %s" % (a.domain, op, b.domain)
  193. raise DMDomainError(msg)
  194. if ashape != bshape:
  195. msg = "Shape mismatch: %s %s %s" % (a.shape, op, b.shape)
  196. raise DMShapeError(msg)
  197. def add(a, b):
  198. """a + b"""
  199. a._check(a, '+', b, a.shape, b.shape)
  200. c = a.copy()
  201. ddm_iadd(c, b)
  202. return c
  203. def sub(a, b):
  204. """a - b"""
  205. a._check(a, '-', b, a.shape, b.shape)
  206. c = a.copy()
  207. ddm_isub(c, b)
  208. return c
  209. def neg(a):
  210. """-a"""
  211. b = a.copy()
  212. ddm_ineg(b)
  213. return b
  214. def mul(a, b):
  215. c = a.copy()
  216. ddm_imul(c, b)
  217. return c
  218. def rmul(a, b):
  219. c = a.copy()
  220. ddm_irmul(c, b)
  221. return c
  222. def matmul(a, b):
  223. """a @ b (matrix product)"""
  224. m, o = a.shape
  225. o2, n = b.shape
  226. a._check(a, '*', b, o, o2)
  227. c = a.zeros((m, n), a.domain)
  228. ddm_imatmul(c, a, b)
  229. return c
  230. def mul_elementwise(a, b):
  231. assert a.shape == b.shape
  232. assert a.domain == b.domain
  233. c = [[aij * bij for aij, bij in zip(ai, bi)] for ai, bi in zip(a, b)]
  234. return DDM(c, a.shape, a.domain)
  235. def hstack(A, *B):
  236. """Horizontally stacks :py:class:`~.DDM` matrices.
  237. Examples
  238. ========
  239. >>> from sympy import ZZ
  240. >>> from sympy.polys.matrices.sdm import DDM
  241. >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ)
  242. >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ)
  243. >>> A.hstack(B)
  244. [[1, 2, 5, 6], [3, 4, 7, 8]]
  245. >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ)
  246. >>> A.hstack(B, C)
  247. [[1, 2, 5, 6, 9, 10], [3, 4, 7, 8, 11, 12]]
  248. """
  249. Anew = list(A.copy())
  250. rows, cols = A.shape
  251. domain = A.domain
  252. for Bk in B:
  253. Bkrows, Bkcols = Bk.shape
  254. assert Bkrows == rows
  255. assert Bk.domain == domain
  256. cols += Bkcols
  257. for i, Bki in enumerate(Bk):
  258. Anew[i].extend(Bki)
  259. return DDM(Anew, (rows, cols), A.domain)
  260. def vstack(A, *B):
  261. """Vertically stacks :py:class:`~.DDM` matrices.
  262. Examples
  263. ========
  264. >>> from sympy import ZZ
  265. >>> from sympy.polys.matrices.sdm import DDM
  266. >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ)
  267. >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ)
  268. >>> A.vstack(B)
  269. [[1, 2], [3, 4], [5, 6], [7, 8]]
  270. >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ)
  271. >>> A.vstack(B, C)
  272. [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
  273. """
  274. Anew = list(A.copy())
  275. rows, cols = A.shape
  276. domain = A.domain
  277. for Bk in B:
  278. Bkrows, Bkcols = Bk.shape
  279. assert Bkcols == cols
  280. assert Bk.domain == domain
  281. rows += Bkrows
  282. Anew.extend(Bk.copy())
  283. return DDM(Anew, (rows, cols), A.domain)
  284. def applyfunc(self, func, domain):
  285. elements = (list(map(func, row)) for row in self)
  286. return DDM(elements, self.shape, domain)
  287. def scc(a):
  288. """Strongly connected components of a square matrix *a*.
  289. Examples
  290. ========
  291. >>> from sympy import ZZ
  292. >>> from sympy.polys.matrices.sdm import DDM
  293. >>> A = DDM([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(1)]], (2, 2), ZZ)
  294. >>> A.scc()
  295. [[0], [1]]
  296. See also
  297. ========
  298. sympy.polys.matrices.domainmatrix.DomainMatrix.scc
  299. """
  300. return a.to_sdm().scc()
  301. def rref(a):
  302. """Reduced-row echelon form of a and list of pivots"""
  303. b = a.copy()
  304. K = a.domain
  305. partial_pivot = K.is_RealField or K.is_ComplexField
  306. pivots = ddm_irref(b, _partial_pivot=partial_pivot)
  307. return b, pivots
  308. def nullspace(a):
  309. rref, pivots = a.rref()
  310. rows, cols = a.shape
  311. domain = a.domain
  312. basis = []
  313. nonpivots = []
  314. for i in range(cols):
  315. if i in pivots:
  316. continue
  317. nonpivots.append(i)
  318. vec = [domain.one if i == j else domain.zero for j in range(cols)]
  319. for ii, jj in enumerate(pivots):
  320. vec[jj] -= rref[ii][i]
  321. basis.append(vec)
  322. return DDM(basis, (len(basis), cols), domain), nonpivots
  323. def particular(a):
  324. return a.to_sdm().particular().to_ddm()
  325. def det(a):
  326. """Determinant of a"""
  327. m, n = a.shape
  328. if m != n:
  329. raise DMShapeError("Determinant of non-square matrix")
  330. b = a.copy()
  331. K = b.domain
  332. deta = ddm_idet(b, K)
  333. return deta
  334. def inv(a):
  335. """Inverse of a"""
  336. m, n = a.shape
  337. if m != n:
  338. raise DMShapeError("Determinant of non-square matrix")
  339. ainv = a.copy()
  340. K = a.domain
  341. ddm_iinv(ainv, a, K)
  342. return ainv
  343. def lu(a):
  344. """L, U decomposition of a"""
  345. m, n = a.shape
  346. K = a.domain
  347. U = a.copy()
  348. L = a.eye(m, K)
  349. swaps = ddm_ilu_split(L, U, K)
  350. return L, U, swaps
  351. def lu_solve(a, b):
  352. """x where a*x = b"""
  353. m, n = a.shape
  354. m2, o = b.shape
  355. a._check(a, 'lu_solve', b, m, m2)
  356. L, U, swaps = a.lu()
  357. x = a.zeros((n, o), a.domain)
  358. ddm_ilu_solve(x, L, U, swaps, b)
  359. return x
  360. def charpoly(a):
  361. """Coefficients of characteristic polynomial of a"""
  362. K = a.domain
  363. m, n = a.shape
  364. if m != n:
  365. raise DMShapeError("Charpoly of non-square matrix")
  366. vec = ddm_berk(a, K)
  367. coeffs = [vec[i][0] for i in range(n+1)]
  368. return coeffs
  369. def is_zero_matrix(self):
  370. """
  371. Says whether this matrix has all zero entries.
  372. """
  373. zero = self.domain.zero
  374. return all(Mij == zero for Mij in self.flatiter())
  375. def is_upper(self):
  376. """
  377. Says whether this matrix is upper-triangular. True can be returned
  378. even if the matrix is not square.
  379. """
  380. zero = self.domain.zero
  381. return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[:i])
  382. def is_lower(self):
  383. """
  384. Says whether this matrix is lower-triangular. True can be returned
  385. even if the matrix is not square.
  386. """
  387. zero = self.domain.zero
  388. return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[i+1:])
  389. from .sdm import SDM