sparsetools.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. from sympy.core.containers import Dict
  2. from sympy.core.symbol import Dummy
  3. from sympy.utilities.iterables import is_sequence
  4. from sympy.utilities.misc import as_int, filldedent
  5. from .sparse import MutableSparseMatrix as SparseMatrix
  def _doktocsr(dok):
      """Converts a sparse matrix to Compressed Sparse Row (CSR) format.

      Parameters
      ==========

      dok : a SparseMatrix (anything providing ``row_list()``, ``rows``
          and ``cols``); ``row_list()`` is expected to give
          ``(row, col, value)`` triples sorted by (row, column).

      Returns
      =======

      A list ``[A, JA, IA, shape]`` where

      A : contains non-zero elements sorted by key (row, column)
      JA : JA[i] is the column corresponding to A[i]
      IA : IA[i] contains the index in A for the first non-zero element
          of row[i]. Thus IA[i+1] - IA[i] gives number of non-zero
          elements row[i]. The length of IA is always 1 more than the
          number of rows in the matrix.
      shape : ``[rows, cols]`` of the matrix

      A matrix without non-zero elements gives empty ``A`` and ``JA``
      and an ``IA`` consisting of ``rows + 1`` zeros.

      Examples
      ========

      >>> from sympy.matrices.sparsetools import _doktocsr
      >>> from sympy import SparseMatrix, diag
      >>> m = SparseMatrix(diag(1, 2, 3))
      >>> m[2, 0] = -1
      >>> _doktocsr(m)
      [[1, 2, -1, 3], [0, 1, 0, 2], [0, 1, 2, 4], [3, 3]]
      >>> _doktocsr(SparseMatrix(2, 3, {}))
      [[], [], [0, 0, 0], [2, 3]]

      """
      entries = list(dok.row_list())
      if entries:
          row, JA, A = [list(i) for i in zip(*entries)]
      else:
          # zip(*[]) yields nothing, so unpacking it into three lists
          # would fail; an all-zero matrix simply has no entries
          row, JA, A = [], [], []
      # every row before the first non-empty one starts at index 0
      IA = [0]*((row[0] if row else 0) + 1)
      for i, r in enumerate(row):
          # each jump in row number opens (r - previous row) new rows,
          # all of which start at entry i; for i = 0, row[i - 1] is the
          # last (largest) row so nothing is extended
          IA.extend([i]*(r - row[i - 1]))
      # trailing rows (including any empty ones) end at len(A)
      IA.extend([len(A)]*(dok.rows - len(IA) + 1))
      shape = [dok.rows, dok.cols]
      return [A, JA, IA, shape]
  def _csrtodok(csr):
      """Converts a CSR representation to DOK representation.

      ``csr`` is a 4-item sequence ``[A, JA, IA, shape]`` laid out as
      ``_doktocsr`` returns it: the stored values, their column indices,
      the row pointers and ``[rows, cols]``. Row ``i`` holds the values
      ``A[IA[i]:IA[i + 1]]`` in columns ``JA[IA[i]:IA[i + 1]]``. The
      result is a SparseMatrix of the given shape.

      Examples
      ========

      >>> from sympy.matrices.sparsetools import _csrtodok
      >>> _csrtodok([[5, 8, 3, 6], [0, 1, 2, 1], [0, 0, 2, 3, 4], [4, 3]])
      Matrix([
      [0, 0, 0],
      [5, 8, 0],
      [0, 0, 3],
      [0, 6, 0]])

      """
      values, columns, row_ptr, shape = csr
      entries = {}
      # consecutive row pointers delimit the slice belonging to each row
      row_bounds = zip(row_ptr[:-1], row_ptr[1:])
      for row, (start, stop) in enumerate(row_bounds):
          for col, val in zip(columns[start:stop], values[start:stop]):
              entries[row, col] = val
      return SparseMatrix(*shape, entries)
  def banded(*args, **kwargs):
      """Returns a SparseMatrix from the given dictionary describing
      the diagonals of the matrix. The keys are positive for upper
      diagonals and negative for those below the main diagonal. The
      values may be:

      * expressions or single-argument functions,

      * lists or tuples of values,

      * matrices

      The function is called as ``banded(d)``, ``banded(n, d)`` (an
      ``n x n`` matrix) or ``banded(rows, cols, d)``; any other input,
      or a key that is not an integer, raises TypeError.

      Unless dimensions are given, the size of the returned matrix will
      be large enough to contain the largest non-zero value provided.

      kwargs
      ======

      rows : rows of the resulting matrix; computed if
             not given. Only read when the size is not passed
             positionally.

      cols : columns of the resulting matrix; computed if
             not given. Only read when the size is not passed
             positionally.

      Examples
      ========

      >>> from sympy import banded, ones, Matrix
      >>> from sympy.abc import x

      If explicit values are given in tuples,
      the matrix will autosize to contain all values, otherwise
      a single value is filled onto the entire diagonal:

      >>> banded({1: (1, 2, 3), -1: (4, 5, 6), 0: x})
      Matrix([
      [x, 1, 0, 0],
      [4, x, 2, 0],
      [0, 5, x, 3],
      [0, 0, 6, x]])

      A function accepting a single argument can be used to fill the
      diagonal as a function of diagonal index (which starts at 0).
      The size (or shape) of the matrix must be given to obtain more
      than a 1x1 matrix:

      >>> s = lambda d: (1 + d)**2
      >>> banded(5, {0: s, 2: s, -2: 2})
      Matrix([
      [1, 0, 1,  0,  0],
      [0, 4, 0,  4,  0],
      [2, 0, 9,  0,  9],
      [0, 2, 0, 16,  0],
      [0, 0, 2,  0, 25]])

      A function whose value is 0 at the start of a diagonal leaves
      that entry empty:

      >>> banded(3, {0: lambda d: d})
      Matrix([
      [0, 0, 0],
      [0, 1, 0],
      [0, 0, 2]])

      The diagonal of matrices placed on a diagonal will coincide
      with the indicated diagonal:

      >>> vert = Matrix([1, 2, 3])
      >>> banded({0: vert}, cols=3)
      Matrix([
      [1, 0, 0],
      [2, 1, 0],
      [3, 2, 1],
      [0, 3, 2],
      [0, 0, 3]])

      >>> banded(4, {0: ones(2)})
      Matrix([
      [1, 1, 0, 0],
      [1, 1, 0, 0],
      [0, 0, 1, 1],
      [0, 0, 1, 1]])

      Errors are raised if the designated size will not hold
      all values an integral number of times. Here, the rows
      are designated as odd (but an even number is required to
      hold the off-diagonal 2x2 ones):

      >>> banded({0: 2, 1: ones(2)}, rows=5)
      Traceback (most recent call last):
      ...
      ValueError:
      sequence does not fit an integral number of times in the matrix

      And here, an even number of rows is given...but the square
      matrix has an even number of columns, too. As we saw
      in the previous example, an odd number is required:

      >>> banded(4, {0: 2, 1: ones(2)})  # trying to make 4x4 and cols must be odd
      Traceback (most recent call last):
      ...
      ValueError:
      sequence does not fit an integral number of times in the matrix

      A way around having to count rows is to enclose matrix elements
      in a tuple and indicate the desired number of them to the right:

      >>> banded({0: 2, 2: (ones(2),)*3})
      Matrix([
      [2, 0, 1, 1, 0, 0, 0, 0],
      [0, 2, 1, 1, 0, 0, 0, 0],
      [0, 0, 2, 0, 1, 1, 0, 0],
      [0, 0, 0, 2, 1, 1, 0, 0],
      [0, 0, 0, 0, 2, 0, 1, 1],
      [0, 0, 0, 0, 0, 2, 1, 1]])

      An error will be raised if more than one value
      is written to a given entry. Here, the ones overlap
      with the main diagonal if they are placed on the
      first diagonal:

      >>> banded({0: (2,)*5, 1: (ones(2),)*3})
      Traceback (most recent call last):
      ...
      ValueError: collision at (1, 1)

      By placing a 0 at the bottom left of the 2x2 matrix of
      ones, the collision is avoided:

      >>> u2 = Matrix([
      ... [1, 1],
      ... [0, 1]])
      >>> banded({0: [2]*5, 1: [u2]*3})
      Matrix([
      [2, 1, 1, 0, 0, 0, 0],
      [0, 2, 1, 0, 0, 0, 0],
      [0, 0, 2, 1, 1, 0, 0],
      [0, 0, 0, 2, 1, 0, 0],
      [0, 0, 0, 0, 2, 1, 1],
      [0, 0, 0, 0, 0, 0, 1]])

      """
      try:
          if len(args) not in (1, 2, 3):
              raise TypeError
          if not isinstance(args[-1], (dict, Dict)):
              raise TypeError
          if len(args) == 1:
              rows = kwargs.get('rows', None)
              cols = kwargs.get('cols', None)
              if rows is not None:
                  rows = as_int(rows)
              if cols is not None:
                  cols = as_int(cols)
          elif len(args) == 2:
              rows = cols = as_int(args[0])
          else:
              rows, cols = map(as_int, args[:2])
          # fails with ValueError if any keys are not ints; every key is
          # checked explicitly because all() would stop at a key of 0
          # (as_int(0) is falsy) and leave later keys unvalidated
          for k in args[-1]:
              as_int(k)
      except (ValueError, TypeError):
          raise TypeError(filldedent(
              '''unrecognized input to banded:
              expecting [[row,] col,] {int: value}'''))

      def rc(d):
          # return row,col coord of diagonal start
          r = -d if d < 0 else 0
          c = 0 if r else d
          return r, c

      smat = {}
      undone = []
      # placeholder marking the start of a diagonal whose values are
      # filled in only after the final size is known
      tba = Dummy()
      # first handle objects with size
      for d, v in args[-1].items():
          r, c = rc(d)
          # note: only list and tuple are recognized since this
          # will allow other Basic objects like Tuple
          # into the matrix if so desired
          if isinstance(v, (list, tuple)):
              # a matrix element occupies min(shape) diagonal positions,
              # so later elements are shifted by the extra positions used
              extra = 0
              for i, vi in enumerate(v):
                  i += extra
                  if is_sequence(vi):
                      vi = SparseMatrix(vi)
                      smat[r + i, c + i] = vi
                      extra += min(vi.shape) - 1
                  else:
                      smat[r + i, c + i] = vi
          elif is_sequence(v):
              # a bare matrix is repeated along the diagonal as many times
              # as the designated size allows
              v = SparseMatrix(v)
              rv, cv = v.shape
              if rows and cols:
                  nr, xr = divmod(rows - r, rv)
                  nc, xc = divmod(cols - c, cv)
                  x = xr or xc
                  do = min(nr, nc)
              elif rows:
                  do, x = divmod(rows - r, rv)
              elif cols:
                  do, x = divmod(cols - c, cv)
              else:
                  do = 1
                  x = 0
              if x:
                  raise ValueError(filldedent('''
                      sequence does not fit an integral number of times
                      in the matrix'''))
              j = min(v.shape)
              for _ in range(do):
                  smat[r, c] = v
                  r += j
                  c += j
          elif v:
              # scalar or function: reserve the start so the autosized
              # matrix is big enough, fill the diagonal later
              smat[r, c] = tba
              undone.append((d, v))
      s = SparseMatrix(None, smat)  # to expand matrices
      smat = s.todok()
      # check for dim errors here
      if rows is not None and rows < s.rows:
          raise ValueError('Designated rows %s < needed %s' % (rows, s.rows))
      if cols is not None and cols < s.cols:
          raise ValueError('Designated cols %s < needed %s' % (cols, s.cols))
      if rows is cols is None:
          rows = s.rows
          cols = s.cols
      elif rows is not None and cols is None:
          cols = max(rows, s.cols)
      elif cols is not None and rows is None:
          rows = max(cols, s.rows)

      def update(i, j, v):
          # update smat and make sure there are
          # no collisions
          if v:
              if (i, j) in smat and smat[i, j] not in (tba, v):
                  raise ValueError('collision at %s' % ((i, j),))
              smat[i, j] = v
          elif (i, j) in smat and smat[i, j] == tba:
              # a zero value must not leave the placeholder behind
              del smat[i, j]

      if undone:
          for d, vi in undone:
              r, c = rc(d)
              # the constant lambda is consumed within this iteration,
              # so capturing the loop variable vi is safe
              v = vi if callable(vi) else lambda _: vi
              i = 0
              while r + i < rows and c + i < cols:
                  update(r + i, c + i, v(i))
                  i += 1
      return SparseMatrix(rows, cols, smat)