dense.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. """
  2. Module for the ddm_* routines for operating on a matrix in list of lists
  3. matrix representation.
  4. These routines are used internally by the DDM class which also provides a
  5. friendlier interface for them. The idea here is to implement core matrix
  6. routines in a way that can be applied to any simple list representation
  7. without the need to use any particular matrix class. For example we can
  8. compute the RREF of a matrix like:
  9. >>> from sympy.polys.matrices.dense import ddm_irref
  10. >>> M = [[1, 2, 3], [4, 5, 6]]
  11. >>> pivots = ddm_irref(M)
  12. >>> M
  13. [[1.0, 0.0, -1.0], [0, 1.0, 2.0]]
  14. These are lower-level routines that work mostly in place. The routines at this
  15. level should not need to know what the domain of the elements is but should
  16. ideally document what operations they will use and what functions they need to
  17. be provided with.
  18. The next level up is the DDM class which uses these routines but wraps them up
  19. with an interface that handles copying etc and keeps track of the Domain of
  20. the elements of the matrix:
  21. >>> from sympy.polys.domains import QQ
  22. >>> from sympy.polys.matrices.ddm import DDM
  23. >>> M = DDM([[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]], (2, 3), QQ)
  24. >>> M
  25. [[1, 2, 3], [4, 5, 6]]
  26. >>> Mrref, pivots = M.rref()
  27. >>> Mrref
  28. [[1, 0, -1], [0, 1, 2]]
  29. """
  30. from operator import mul
  31. from .exceptions import (
  32. DMShapeError,
  33. DMNonInvertibleMatrixError,
  34. DMNonSquareMatrixError,
  35. )
  36. def ddm_transpose(a):
  37. """matrix transpose"""
  38. aT = list(map(list, zip(*a)))
  39. return aT
  40. def ddm_iadd(a, b):
  41. """a += b"""
  42. for ai, bi in zip(a, b):
  43. for j, bij in enumerate(bi):
  44. ai[j] += bij
  45. def ddm_isub(a, b):
  46. """a -= b"""
  47. for ai, bi in zip(a, b):
  48. for j, bij in enumerate(bi):
  49. ai[j] -= bij
  50. def ddm_ineg(a):
  51. """a <-- -a"""
  52. for ai in a:
  53. for j, aij in enumerate(ai):
  54. ai[j] = -aij
  55. def ddm_imul(a, b):
  56. for ai in a:
  57. for j, aij in enumerate(ai):
  58. ai[j] = aij * b
  59. def ddm_irmul(a, b):
  60. for ai in a:
  61. for j, aij in enumerate(ai):
  62. ai[j] = b * aij
  63. def ddm_imatmul(a, b, c):
  64. """a += b @ c"""
  65. cT = list(zip(*c))
  66. for bi, ai in zip(b, a):
  67. for j, cTj in enumerate(cT):
  68. ai[j] = sum(map(mul, bi, cTj), ai[j])
  69. def ddm_irref(a, _partial_pivot=False):
  70. """a <-- rref(a)"""
  71. # a is (m x n)
  72. m = len(a)
  73. if not m:
  74. return []
  75. n = len(a[0])
  76. i = 0
  77. pivots = []
  78. for j in range(n):
  79. # Proper pivoting should be used for all domains for performance
  80. # reasons but it is only strictly needed for RR and CC (and possibly
  81. # other domains like RR(x)). This path is used by DDM.rref() if the
  82. # domain is RR or CC. It uses partial (row) pivoting based on the
  83. # absolute value of the pivot candidates.
  84. if _partial_pivot:
  85. ip = max(range(i, m), key=lambda ip: abs(a[ip][j]))
  86. a[i], a[ip] = a[ip], a[i]
  87. # pivot
  88. aij = a[i][j]
  89. # zero-pivot
  90. if not aij:
  91. for ip in range(i+1, m):
  92. aij = a[ip][j]
  93. # row-swap
  94. if aij:
  95. a[i], a[ip] = a[ip], a[i]
  96. break
  97. else:
  98. # next column
  99. continue
  100. # normalise row
  101. ai = a[i]
  102. aijinv = aij**-1
  103. for l in range(j, n):
  104. ai[l] *= aijinv # ai[j] = one
  105. # eliminate above and below to the right
  106. for k, ak in enumerate(a):
  107. if k == i or not ak[j]:
  108. continue
  109. akj = ak[j]
  110. ak[j] -= akj # ak[j] = zero
  111. for l in range(j+1, n):
  112. ak[l] -= akj * ai[l]
  113. # next row
  114. pivots.append(j)
  115. i += 1
  116. # no more rows?
  117. if i >= m:
  118. break
  119. return pivots
  120. def ddm_idet(a, K):
  121. """a <-- echelon(a); return det"""
  122. # Bareiss algorithm
  123. # https://www.math.usm.edu/perry/Research/Thesis_DRL.pdf
  124. # a is (m x n)
  125. m = len(a)
  126. if not m:
  127. return K.one
  128. n = len(a[0])
  129. exquo = K.exquo
  130. # uf keeps track of the sign change from row swaps
  131. uf = K.one
  132. for k in range(n-1):
  133. if not a[k][k]:
  134. for i in range(k+1, n):
  135. if a[i][k]:
  136. a[k], a[i] = a[i], a[k]
  137. uf = -uf
  138. break
  139. else:
  140. return K.zero
  141. akkm1 = a[k-1][k-1] if k else K.one
  142. for i in range(k+1, n):
  143. for j in range(k+1, n):
  144. a[i][j] = exquo(a[i][j]*a[k][k] - a[i][k]*a[k][j], akkm1)
  145. return uf * a[-1][-1]
  146. def ddm_iinv(ainv, a, K):
  147. if not K.is_Field:
  148. raise ValueError('Not a field')
  149. # a is (m x n)
  150. m = len(a)
  151. if not m:
  152. return
  153. n = len(a[0])
  154. if m != n:
  155. raise DMNonSquareMatrixError
  156. eye = [[K.one if i==j else K.zero for j in range(n)] for i in range(n)]
  157. Aaug = [row + eyerow for row, eyerow in zip(a, eye)]
  158. pivots = ddm_irref(Aaug)
  159. if pivots != list(range(n)):
  160. raise DMNonInvertibleMatrixError('Matrix det == 0; not invertible.')
  161. ainv[:] = [row[n:] for row in Aaug]
  162. def ddm_ilu_split(L, U, K):
  163. """L, U <-- LU(U)"""
  164. m = len(U)
  165. if not m:
  166. return []
  167. n = len(U[0])
  168. swaps = ddm_ilu(U)
  169. zeros = [K.zero] * min(m, n)
  170. for i in range(1, m):
  171. j = min(i, n)
  172. L[i][:j] = U[i][:j]
  173. U[i][:j] = zeros[:j]
  174. return swaps
  175. def ddm_ilu(a):
  176. """a <-- LU(a)"""
  177. m = len(a)
  178. if not m:
  179. return []
  180. n = len(a[0])
  181. swaps = []
  182. for i in range(min(m, n)):
  183. if not a[i][i]:
  184. for ip in range(i+1, m):
  185. if a[ip][i]:
  186. swaps.append((i, ip))
  187. a[i], a[ip] = a[ip], a[i]
  188. break
  189. else:
  190. # M = Matrix([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]])
  191. continue
  192. for j in range(i+1, m):
  193. l_ji = a[j][i] / a[i][i]
  194. a[j][i] = l_ji
  195. for k in range(i+1, n):
  196. a[j][k] -= l_ji * a[i][k]
  197. return swaps
  198. def ddm_ilu_solve(x, L, U, swaps, b):
  199. """x <-- solve(L*U*x = swaps(b))"""
  200. m = len(U)
  201. if not m:
  202. return
  203. n = len(U[0])
  204. m2 = len(b)
  205. if not m2:
  206. raise DMShapeError("Shape mismtch")
  207. o = len(b[0])
  208. if m != m2:
  209. raise DMShapeError("Shape mismtch")
  210. if m < n:
  211. raise NotImplementedError("Underdetermined")
  212. if swaps:
  213. b = [row[:] for row in b]
  214. for i1, i2 in swaps:
  215. b[i1], b[i2] = b[i2], b[i1]
  216. # solve Ly = b
  217. y = [[None] * o for _ in range(m)]
  218. for k in range(o):
  219. for i in range(m):
  220. rhs = b[i][k]
  221. for j in range(i):
  222. rhs -= L[i][j] * y[j][k]
  223. y[i][k] = rhs
  224. if m > n:
  225. for i in range(n, m):
  226. for j in range(o):
  227. if y[i][j]:
  228. raise DMNonInvertibleMatrixError
  229. # Solve Ux = y
  230. for k in range(o):
  231. for i in reversed(range(n)):
  232. if not U[i][i]:
  233. raise DMNonInvertibleMatrixError
  234. rhs = y[i][k]
  235. for j in range(i+1, n):
  236. rhs -= U[i][j] * x[j][k]
  237. x[i][k] = rhs / U[i][i]
  238. def ddm_berk(M, K):
  239. m = len(M)
  240. if not m:
  241. return [[K.one]]
  242. n = len(M[0])
  243. if m != n:
  244. raise DMShapeError("Not square")
  245. if n == 1:
  246. return [[K.one], [-M[0][0]]]
  247. a = M[0][0]
  248. R = [M[0][1:]]
  249. C = [[row[0]] for row in M[1:]]
  250. A = [row[1:] for row in M[1:]]
  251. q = ddm_berk(A, K)
  252. T = [[K.zero] * n for _ in range(n+1)]
  253. for i in range(n):
  254. T[i][i] = K.one
  255. T[i+1][i] = -a
  256. for i in range(2, n+1):
  257. if i == 2:
  258. AnC = C
  259. else:
  260. C = AnC
  261. AnC = [[K.zero] for row in C]
  262. ddm_imatmul(AnC, A, C)
  263. RAnC = [[K.zero]]
  264. ddm_imatmul(RAnC, R, AnC)
  265. for j in range(0, n+1-i):
  266. T[i+j][j] = -RAnC[0][0]
  267. qout = [[K.zero] for _ in range(n+1)]
  268. ddm_imatmul(qout, T, q)
  269. return qout