sparse.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. from collections.abc import Callable
  2. from sympy.core.containers import Dict
  3. from sympy.utilities.exceptions import sympy_deprecation_warning
  4. from sympy.utilities.iterables import is_sequence
  5. from sympy.utilities.misc import as_int
  6. from .matrices import MatrixBase
  7. from .repmatrix import MutableRepMatrix, RepMatrix
  8. from .utilities import _iszero
  9. from .decompositions import (
  10. _liupc, _row_structure_symbolic_cholesky, _cholesky_sparse,
  11. _LDLdecomposition_sparse)
  12. from .solvers import (
  13. _lower_triangular_solve_sparse, _upper_triangular_solve_sparse)
  14. class SparseRepMatrix(RepMatrix):
  15. """
  16. A sparse matrix (a matrix with a large number of zero elements).
  17. Examples
  18. ========
  19. >>> from sympy import SparseMatrix, ones
  20. >>> SparseMatrix(2, 2, range(4))
  21. Matrix([
  22. [0, 1],
  23. [2, 3]])
  24. >>> SparseMatrix(2, 2, {(1, 1): 2})
  25. Matrix([
  26. [0, 0],
  27. [0, 2]])
  28. A SparseMatrix can be instantiated from a ragged list of lists:
  29. >>> SparseMatrix([[1, 2, 3], [1, 2], [1]])
  30. Matrix([
  31. [1, 2, 3],
  32. [1, 2, 0],
  33. [1, 0, 0]])
  34. For safety, one may include the expected size and then an error
  35. will be raised if the indices of any element are out of range or
  36. (for a flat list) if the total number of elements does not match
  37. the expected shape:
  38. >>> SparseMatrix(2, 2, [1, 2])
  39. Traceback (most recent call last):
  40. ...
  41. ValueError: List length (2) != rows*columns (4)
  42. Here, an error is not raised because the list is not flat and no
  43. element is out of range:
  44. >>> SparseMatrix(2, 2, [[1, 2]])
  45. Matrix([
  46. [1, 2],
  47. [0, 0]])
  48. But adding another element to the first (and only) row will cause
  49. an error to be raised:
  50. >>> SparseMatrix(2, 2, [[1, 2, 3]])
  51. Traceback (most recent call last):
  52. ...
  53. ValueError: The location (0, 2) is out of designated range: (1, 1)
  54. To autosize the matrix, pass None for rows:
  55. >>> SparseMatrix(None, [[1, 2, 3]])
  56. Matrix([[1, 2, 3]])
  57. >>> SparseMatrix(None, {(1, 1): 1, (3, 3): 3})
  58. Matrix([
  59. [0, 0, 0, 0],
  60. [0, 1, 0, 0],
  61. [0, 0, 0, 0],
  62. [0, 0, 0, 3]])
  63. Values that are themselves a Matrix are automatically expanded:
  64. >>> SparseMatrix(4, 4, {(1, 1): ones(2)})
  65. Matrix([
  66. [0, 0, 0, 0],
  67. [0, 1, 1, 0],
  68. [0, 1, 1, 0],
  69. [0, 0, 0, 0]])
  70. A ValueError is raised if the expanding matrix tries to overwrite
  71. a different element already present:
  72. >>> SparseMatrix(3, 3, {(0, 0): ones(2), (1, 1): 2})
  73. Traceback (most recent call last):
  74. ...
  75. ValueError: collision at (1, 1)
  76. See Also
  77. ========
  78. DenseMatrix
  79. MutableSparseMatrix
  80. ImmutableSparseMatrix
  81. """
  82. @classmethod
  83. def _handle_creation_inputs(cls, *args, **kwargs):
  84. if len(args) == 1 and isinstance(args[0], MatrixBase):
  85. rows = args[0].rows
  86. cols = args[0].cols
  87. smat = args[0].todok()
  88. return rows, cols, smat
  89. smat = {}
  90. # autosizing
  91. if len(args) == 2 and args[0] is None:
  92. args = [None, None, args[1]]
  93. if len(args) == 3:
  94. r, c = args[:2]
  95. if r is c is None:
  96. rows = cols = None
  97. elif None in (r, c):
  98. raise ValueError(
  99. 'Pass rows=None and no cols for autosizing.')
  100. else:
  101. rows, cols = as_int(args[0]), as_int(args[1])
  102. if isinstance(args[2], Callable):
  103. op = args[2]
  104. if None in (rows, cols):
  105. raise ValueError(
  106. "{} and {} must be integers for this "
  107. "specification.".format(rows, cols))
  108. row_indices = [cls._sympify(i) for i in range(rows)]
  109. col_indices = [cls._sympify(j) for j in range(cols)]
  110. for i in row_indices:
  111. for j in col_indices:
  112. value = cls._sympify(op(i, j))
  113. if value != cls.zero:
  114. smat[i, j] = value
  115. return rows, cols, smat
  116. elif isinstance(args[2], (dict, Dict)):
  117. def update(i, j, v):
  118. # update smat and make sure there are no collisions
  119. if v:
  120. if (i, j) in smat and v != smat[i, j]:
  121. raise ValueError(
  122. "There is a collision at {} for {} and {}."
  123. .format((i, j), v, smat[i, j])
  124. )
  125. smat[i, j] = v
  126. # manual copy, copy.deepcopy() doesn't work
  127. for (r, c), v in args[2].items():
  128. if isinstance(v, MatrixBase):
  129. for (i, j), vv in v.todok().items():
  130. update(r + i, c + j, vv)
  131. elif isinstance(v, (list, tuple)):
  132. _, _, smat = cls._handle_creation_inputs(v, **kwargs)
  133. for i, j in smat:
  134. update(r + i, c + j, smat[i, j])
  135. else:
  136. v = cls._sympify(v)
  137. update(r, c, cls._sympify(v))
  138. elif is_sequence(args[2]):
  139. flat = not any(is_sequence(i) for i in args[2])
  140. if not flat:
  141. _, _, smat = \
  142. cls._handle_creation_inputs(args[2], **kwargs)
  143. else:
  144. flat_list = args[2]
  145. if len(flat_list) != rows * cols:
  146. raise ValueError(
  147. "The length of the flat list ({}) does not "
  148. "match the specified size ({} * {})."
  149. .format(len(flat_list), rows, cols)
  150. )
  151. for i in range(rows):
  152. for j in range(cols):
  153. value = flat_list[i*cols + j]
  154. value = cls._sympify(value)
  155. if value != cls.zero:
  156. smat[i, j] = value
  157. if rows is None: # autosizing
  158. keys = smat.keys()
  159. rows = max([r for r, _ in keys]) + 1 if keys else 0
  160. cols = max([c for _, c in keys]) + 1 if keys else 0
  161. else:
  162. for i, j in smat.keys():
  163. if i and i >= rows or j and j >= cols:
  164. raise ValueError(
  165. "The location {} is out of the designated range"
  166. "[{}, {}]x[{}, {}]"
  167. .format((i, j), 0, rows - 1, 0, cols - 1)
  168. )
  169. return rows, cols, smat
  170. elif len(args) == 1 and isinstance(args[0], (list, tuple)):
  171. # list of values or lists
  172. v = args[0]
  173. c = 0
  174. for i, row in enumerate(v):
  175. if not isinstance(row, (list, tuple)):
  176. row = [row]
  177. for j, vv in enumerate(row):
  178. if vv != cls.zero:
  179. smat[i, j] = cls._sympify(vv)
  180. c = max(c, len(row))
  181. rows = len(v) if c else 0
  182. cols = c
  183. return rows, cols, smat
  184. else:
  185. # handle full matrix forms with _handle_creation_inputs
  186. rows, cols, mat = super()._handle_creation_inputs(*args)
  187. for i in range(rows):
  188. for j in range(cols):
  189. value = mat[cols*i + j]
  190. if value != cls.zero:
  191. smat[i, j] = value
  192. return rows, cols, smat
  193. @property
  194. def _smat(self):
  195. sympy_deprecation_warning(
  196. """
  197. The private _smat attribute of SparseMatrix is deprecated. Use the
  198. .todok() method instead.
  199. """,
  200. deprecated_since_version="1.9",
  201. active_deprecations_target="deprecated-private-matrix-attributes"
  202. )
  203. return self.todok()
  204. def _eval_inverse(self, **kwargs):
  205. return self.inv(method=kwargs.get('method', 'LDL'),
  206. iszerofunc=kwargs.get('iszerofunc', _iszero),
  207. try_block_diag=kwargs.get('try_block_diag', False))
  208. def applyfunc(self, f):
  209. """Apply a function to each element of the matrix.
  210. Examples
  211. ========
  212. >>> from sympy import SparseMatrix
  213. >>> m = SparseMatrix(2, 2, lambda i, j: i*2+j)
  214. >>> m
  215. Matrix([
  216. [0, 1],
  217. [2, 3]])
  218. >>> m.applyfunc(lambda i: 2*i)
  219. Matrix([
  220. [0, 2],
  221. [4, 6]])
  222. """
  223. if not callable(f):
  224. raise TypeError("`f` must be callable.")
  225. # XXX: This only applies the function to the nonzero elements of the
  226. # matrix so is inconsistent with DenseMatrix.applyfunc e.g.
  227. # zeros(2, 2).applyfunc(lambda x: x + 1)
  228. dok = {}
  229. for k, v in self.todok().items():
  230. fv = f(v)
  231. if fv != 0:
  232. dok[k] = fv
  233. return self._new(self.rows, self.cols, dok)
  234. def as_immutable(self):
  235. """Returns an Immutable version of this Matrix."""
  236. from .immutable import ImmutableSparseMatrix
  237. return ImmutableSparseMatrix(self)
  238. def as_mutable(self):
  239. """Returns a mutable version of this matrix.
  240. Examples
  241. ========
  242. >>> from sympy import ImmutableMatrix
  243. >>> X = ImmutableMatrix([[1, 2], [3, 4]])
  244. >>> Y = X.as_mutable()
  245. >>> Y[1, 1] = 5 # Can set values in Y
  246. >>> Y
  247. Matrix([
  248. [1, 2],
  249. [3, 5]])
  250. """
  251. return MutableSparseMatrix(self)
  252. def col_list(self):
  253. """Returns a column-sorted list of non-zero elements of the matrix.
  254. Examples
  255. ========
  256. >>> from sympy import SparseMatrix
  257. >>> a=SparseMatrix(((1, 2), (3, 4)))
  258. >>> a
  259. Matrix([
  260. [1, 2],
  261. [3, 4]])
  262. >>> a.CL
  263. [(0, 0, 1), (1, 0, 3), (0, 1, 2), (1, 1, 4)]
  264. See Also
  265. ========
  266. sympy.matrices.sparse.SparseMatrix.row_list
  267. """
  268. return [tuple(k + (self[k],)) for k in sorted(list(self.todok().keys()), key=lambda k: list(reversed(k)))]
  269. def nnz(self):
  270. """Returns the number of non-zero elements in Matrix."""
  271. return len(self.todok())
  272. def row_list(self):
  273. """Returns a row-sorted list of non-zero elements of the matrix.
  274. Examples
  275. ========
  276. >>> from sympy import SparseMatrix
  277. >>> a = SparseMatrix(((1, 2), (3, 4)))
  278. >>> a
  279. Matrix([
  280. [1, 2],
  281. [3, 4]])
  282. >>> a.RL
  283. [(0, 0, 1), (0, 1, 2), (1, 0, 3), (1, 1, 4)]
  284. See Also
  285. ========
  286. sympy.matrices.sparse.SparseMatrix.col_list
  287. """
  288. return [tuple(k + (self[k],)) for k in
  289. sorted(self.todok().keys(), key=list)]
  290. def scalar_multiply(self, scalar):
  291. "Scalar element-wise multiplication"
  292. return scalar * self
  293. def solve_least_squares(self, rhs, method='LDL'):
  294. """Return the least-square fit to the data.
  295. By default the cholesky_solve routine is used (method='CH'); other
  296. methods of matrix inversion can be used. To find out which are
  297. available, see the docstring of the .inv() method.
  298. Examples
  299. ========
  300. >>> from sympy import SparseMatrix, Matrix, ones
  301. >>> A = Matrix([1, 2, 3])
  302. >>> B = Matrix([2, 3, 4])
  303. >>> S = SparseMatrix(A.row_join(B))
  304. >>> S
  305. Matrix([
  306. [1, 2],
  307. [2, 3],
  308. [3, 4]])
  309. If each line of S represent coefficients of Ax + By
  310. and x and y are [2, 3] then S*xy is:
  311. >>> r = S*Matrix([2, 3]); r
  312. Matrix([
  313. [ 8],
  314. [13],
  315. [18]])
  316. But let's add 1 to the middle value and then solve for the
  317. least-squares value of xy:
  318. >>> xy = S.solve_least_squares(Matrix([8, 14, 18])); xy
  319. Matrix([
  320. [ 5/3],
  321. [10/3]])
  322. The error is given by S*xy - r:
  323. >>> S*xy - r
  324. Matrix([
  325. [1/3],
  326. [1/3],
  327. [1/3]])
  328. >>> _.norm().n(2)
  329. 0.58
  330. If a different xy is used, the norm will be higher:
  331. >>> xy += ones(2, 1)/10
  332. >>> (S*xy - r).norm().n(2)
  333. 1.5
  334. """
  335. t = self.T
  336. return (t*self).inv(method=method)*t*rhs
  337. def solve(self, rhs, method='LDL'):
  338. """Return solution to self*soln = rhs using given inversion method.
  339. For a list of possible inversion methods, see the .inv() docstring.
  340. """
  341. if not self.is_square:
  342. if self.rows < self.cols:
  343. raise ValueError('Under-determined system.')
  344. elif self.rows > self.cols:
  345. raise ValueError('For over-determined system, M, having '
  346. 'more rows than columns, try M.solve_least_squares(rhs).')
  347. else:
  348. return self.inv(method=method).multiply(rhs)
  349. RL = property(row_list, None, None, "Alternate faster representation")
  350. CL = property(col_list, None, None, "Alternate faster representation")
  351. def liupc(self):
  352. return _liupc(self)
  353. def row_structure_symbolic_cholesky(self):
  354. return _row_structure_symbolic_cholesky(self)
  355. def cholesky(self, hermitian=True):
  356. return _cholesky_sparse(self, hermitian=hermitian)
  357. def LDLdecomposition(self, hermitian=True):
  358. return _LDLdecomposition_sparse(self, hermitian=hermitian)
  359. def lower_triangular_solve(self, rhs):
  360. return _lower_triangular_solve_sparse(self, rhs)
  361. def upper_triangular_solve(self, rhs):
  362. return _upper_triangular_solve_sparse(self, rhs)
  363. liupc.__doc__ = _liupc.__doc__
  364. row_structure_symbolic_cholesky.__doc__ = _row_structure_symbolic_cholesky.__doc__
  365. cholesky.__doc__ = _cholesky_sparse.__doc__
  366. LDLdecomposition.__doc__ = _LDLdecomposition_sparse.__doc__
  367. lower_triangular_solve.__doc__ = lower_triangular_solve.__doc__
  368. upper_triangular_solve.__doc__ = upper_triangular_solve.__doc__
  369. class MutableSparseMatrix(SparseRepMatrix, MutableRepMatrix):
  370. @classmethod
  371. def _new(cls, *args, **kwargs):
  372. rows, cols, smat = cls._handle_creation_inputs(*args, **kwargs)
  373. rep = cls._smat_to_DomainMatrix(rows, cols, smat)
  374. return cls._fromrep(rep)
  375. SparseMatrix = MutableSparseMatrix