matrixutils.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. """Utilities to deal with sympy.Matrix, numpy and scipy.sparse."""
  2. from sympy.core.expr import Expr
  3. from sympy.core.numbers import I
  4. from sympy.core.singleton import S
  5. from sympy.matrices.matrices import MatrixBase
  6. from sympy.matrices import eye, zeros
  7. from sympy.external import import_module
  8. __all__ = [
  9. 'numpy_ndarray',
  10. 'scipy_sparse_matrix',
  11. 'sympy_to_numpy',
  12. 'sympy_to_scipy_sparse',
  13. 'numpy_to_sympy',
  14. 'scipy_sparse_to_sympy',
  15. 'flatten_scalar',
  16. 'matrix_dagger',
  17. 'to_sympy',
  18. 'to_numpy',
  19. 'to_scipy_sparse',
  20. 'matrix_tensor_product',
  21. 'matrix_zeros'
  22. ]
  23. # Conditionally define the base classes for numpy and scipy.sparse arrays
  24. # for use in isinstance tests.
  25. np = import_module('numpy')
  26. if not np:
  27. class numpy_ndarray:
  28. pass
  29. else:
  30. numpy_ndarray = np.ndarray # type: ignore
  31. scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
  32. if not scipy:
  33. class scipy_sparse_matrix:
  34. pass
  35. sparse = None
  36. else:
  37. sparse = scipy.sparse
  38. scipy_sparse_matrix = sparse.spmatrix # type: ignore
  39. def sympy_to_numpy(m, **options):
  40. """Convert a SymPy Matrix/complex number to a numpy matrix or scalar."""
  41. if not np:
  42. raise ImportError
  43. dtype = options.get('dtype', 'complex')
  44. if isinstance(m, MatrixBase):
  45. return np.matrix(m.tolist(), dtype=dtype)
  46. elif isinstance(m, Expr):
  47. if m.is_Number or m.is_NumberSymbol or m == I:
  48. return complex(m)
  49. raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m)
  50. def sympy_to_scipy_sparse(m, **options):
  51. """Convert a SymPy Matrix/complex number to a numpy matrix or scalar."""
  52. if not np or not sparse:
  53. raise ImportError
  54. dtype = options.get('dtype', 'complex')
  55. if isinstance(m, MatrixBase):
  56. return sparse.csr_matrix(np.matrix(m.tolist(), dtype=dtype))
  57. elif isinstance(m, Expr):
  58. if m.is_Number or m.is_NumberSymbol or m == I:
  59. return complex(m)
  60. raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m)
  61. def scipy_sparse_to_sympy(m, **options):
  62. """Convert a scipy.sparse matrix to a SymPy matrix."""
  63. return MatrixBase(m.todense())
  64. def numpy_to_sympy(m, **options):
  65. """Convert a numpy matrix to a SymPy matrix."""
  66. return MatrixBase(m)
  67. def to_sympy(m, **options):
  68. """Convert a numpy/scipy.sparse matrix to a SymPy matrix."""
  69. if isinstance(m, MatrixBase):
  70. return m
  71. elif isinstance(m, numpy_ndarray):
  72. return numpy_to_sympy(m)
  73. elif isinstance(m, scipy_sparse_matrix):
  74. return scipy_sparse_to_sympy(m)
  75. elif isinstance(m, Expr):
  76. return m
  77. raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)
  78. def to_numpy(m, **options):
  79. """Convert a sympy/scipy.sparse matrix to a numpy matrix."""
  80. dtype = options.get('dtype', 'complex')
  81. if isinstance(m, (MatrixBase, Expr)):
  82. return sympy_to_numpy(m, dtype=dtype)
  83. elif isinstance(m, numpy_ndarray):
  84. return m
  85. elif isinstance(m, scipy_sparse_matrix):
  86. return m.todense()
  87. raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)
  88. def to_scipy_sparse(m, **options):
  89. """Convert a sympy/numpy matrix to a scipy.sparse matrix."""
  90. dtype = options.get('dtype', 'complex')
  91. if isinstance(m, (MatrixBase, Expr)):
  92. return sympy_to_scipy_sparse(m, dtype=dtype)
  93. elif isinstance(m, numpy_ndarray):
  94. if not sparse:
  95. raise ImportError
  96. return sparse.csr_matrix(m)
  97. elif isinstance(m, scipy_sparse_matrix):
  98. return m
  99. raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m)
  100. def flatten_scalar(e):
  101. """Flatten a 1x1 matrix to a scalar, return larger matrices unchanged."""
  102. if isinstance(e, MatrixBase):
  103. if e.shape == (1, 1):
  104. e = e[0]
  105. if isinstance(e, (numpy_ndarray, scipy_sparse_matrix)):
  106. if e.shape == (1, 1):
  107. e = complex(e[0, 0])
  108. return e
  109. def matrix_dagger(e):
  110. """Return the dagger of a sympy/numpy/scipy.sparse matrix."""
  111. if isinstance(e, MatrixBase):
  112. return e.H
  113. elif isinstance(e, (numpy_ndarray, scipy_sparse_matrix)):
  114. return e.conjugate().transpose()
  115. raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % e)
  116. # TODO: Move this into sympy.matricies.
  117. def _sympy_tensor_product(*matrices):
  118. """Compute the kronecker product of a sequence of SymPy Matrices.
  119. """
  120. from sympy.matrices.expressions.kronecker import matrix_kronecker_product
  121. return matrix_kronecker_product(*matrices)
  122. def _numpy_tensor_product(*product):
  123. """numpy version of tensor product of multiple arguments."""
  124. if not np:
  125. raise ImportError
  126. answer = product[0]
  127. for item in product[1:]:
  128. answer = np.kron(answer, item)
  129. return answer
  130. def _scipy_sparse_tensor_product(*product):
  131. """scipy.sparse version of tensor product of multiple arguments."""
  132. if not sparse:
  133. raise ImportError
  134. answer = product[0]
  135. for item in product[1:]:
  136. answer = sparse.kron(answer, item)
  137. # The final matrices will just be multiplied, so csr is a good final
  138. # sparse format.
  139. return sparse.csr_matrix(answer)
  140. def matrix_tensor_product(*product):
  141. """Compute the matrix tensor product of sympy/numpy/scipy.sparse matrices."""
  142. if isinstance(product[0], MatrixBase):
  143. return _sympy_tensor_product(*product)
  144. elif isinstance(product[0], numpy_ndarray):
  145. return _numpy_tensor_product(*product)
  146. elif isinstance(product[0], scipy_sparse_matrix):
  147. return _scipy_sparse_tensor_product(*product)
  148. def _numpy_eye(n):
  149. """numpy version of complex eye."""
  150. if not np:
  151. raise ImportError
  152. return np.matrix(np.eye(n, dtype='complex'))
  153. def _scipy_sparse_eye(n):
  154. """scipy.sparse version of complex eye."""
  155. if not sparse:
  156. raise ImportError
  157. return sparse.eye(n, n, dtype='complex')
  158. def matrix_eye(n, **options):
  159. """Get the version of eye and tensor_product for a given format."""
  160. format = options.get('format', 'sympy')
  161. if format == 'sympy':
  162. return eye(n)
  163. elif format == 'numpy':
  164. return _numpy_eye(n)
  165. elif format == 'scipy.sparse':
  166. return _scipy_sparse_eye(n)
  167. raise NotImplementedError('Invalid format: %r' % format)
  168. def _numpy_zeros(m, n, **options):
  169. """numpy version of zeros."""
  170. dtype = options.get('dtype', 'float64')
  171. if not np:
  172. raise ImportError
  173. return np.zeros((m, n), dtype=dtype)
  174. def _scipy_sparse_zeros(m, n, **options):
  175. """scipy.sparse version of zeros."""
  176. spmatrix = options.get('spmatrix', 'csr')
  177. dtype = options.get('dtype', 'float64')
  178. if not sparse:
  179. raise ImportError
  180. if spmatrix == 'lil':
  181. return sparse.lil_matrix((m, n), dtype=dtype)
  182. elif spmatrix == 'csr':
  183. return sparse.csr_matrix((m, n), dtype=dtype)
  184. def matrix_zeros(m, n, **options):
  185. """"Get a zeros matrix for a given format."""
  186. format = options.get('format', 'sympy')
  187. if format == 'sympy':
  188. return zeros(m, n)
  189. elif format == 'numpy':
  190. return _numpy_zeros(m, n, **options)
  191. elif format == 'scipy.sparse':
  192. return _scipy_sparse_zeros(m, n, **options)
  193. raise NotImplementedError('Invaild format: %r' % format)
  194. def _numpy_matrix_to_zero(e):
  195. """Convert a numpy zero matrix to the zero scalar."""
  196. if not np:
  197. raise ImportError
  198. test = np.zeros_like(e)
  199. if np.allclose(e, test):
  200. return 0.0
  201. else:
  202. return e
  203. def _scipy_sparse_matrix_to_zero(e):
  204. """Convert a scipy.sparse zero matrix to the zero scalar."""
  205. if not np:
  206. raise ImportError
  207. edense = e.todense()
  208. test = np.zeros_like(edense)
  209. if np.allclose(edense, test):
  210. return 0.0
  211. else:
  212. return e
  213. def matrix_to_zero(e):
  214. """Convert a zero matrix to the scalar zero."""
  215. if isinstance(e, MatrixBase):
  216. if zeros(*e.shape) == e:
  217. e = S.Zero
  218. elif isinstance(e, numpy_ndarray):
  219. e = _numpy_matrix_to_zero(e)
  220. elif isinstance(e, scipy_sparse_matrix):
  221. e = _scipy_sparse_matrix_to_zero(e)
  222. return e