polymatrix.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. from sympy.core.expr import Expr
  2. from sympy.core.symbol import Dummy
  3. from sympy.core.sympify import _sympify
  4. from sympy.polys.polyerrors import CoercionFailed
  5. from sympy.polys.polytools import Poly, parallel_poly_from_expr
  6. from sympy.polys.domains import QQ
  7. from sympy.polys.matrices import DomainMatrix
  8. from sympy.polys.matrices.domainscalar import DomainScalar
  9. class MutablePolyDenseMatrix:
  10. """
  11. A mutable matrix of objects from poly module or to operate with them.
  12. Examples
  13. ========
  14. >>> from sympy.polys.polymatrix import PolyMatrix
  15. >>> from sympy import Symbol, Poly
  16. >>> x = Symbol('x')
  17. >>> pm1 = PolyMatrix([[Poly(x**2, x), Poly(-x, x)], [Poly(x**3, x), Poly(-1 + x, x)]])
  18. >>> v1 = PolyMatrix([[1, 0], [-1, 0]], x)
  19. >>> pm1*v1
  20. PolyMatrix([
  21. [ x**2 + x, 0],
  22. [x**3 - x + 1, 0]], ring=QQ[x])
  23. >>> pm1.ring
  24. ZZ[x]
  25. >>> v1*pm1
  26. PolyMatrix([
  27. [ x**2, -x],
  28. [-x**2, x]], ring=QQ[x])
  29. >>> pm2 = PolyMatrix([[Poly(x**2, x, domain='QQ'), Poly(0, x, domain='QQ'), Poly(1, x, domain='QQ'), \
  30. Poly(x**3, x, domain='QQ'), Poly(0, x, domain='QQ'), Poly(-x**3, x, domain='QQ')]])
  31. >>> v2 = PolyMatrix([1, 0, 0, 0, 0, 0], x)
  32. >>> v2.ring
  33. QQ[x]
  34. >>> pm2*v2
  35. PolyMatrix([[x**2]], ring=QQ[x])
  36. """
  37. def __new__(cls, *args, ring=None):
  38. if not args:
  39. # PolyMatrix(ring=QQ[x])
  40. if ring is None:
  41. raise TypeError("The ring needs to be specified for an empty PolyMatrix")
  42. rows, cols, items, gens = 0, 0, [], ()
  43. elif isinstance(args[0], list):
  44. elements, gens = args[0], args[1:]
  45. if not elements:
  46. # PolyMatrix([])
  47. rows, cols, items = 0, 0, []
  48. elif isinstance(elements[0], (list, tuple)):
  49. # PolyMatrix([[1, 2]], x)
  50. rows, cols = len(elements), len(elements[0])
  51. items = [e for row in elements for e in row]
  52. else:
  53. # PolyMatrix([1, 2], x)
  54. rows, cols = len(elements), 1
  55. items = elements
  56. elif [type(a) for a in args[:3]] == [int, int, list]:
  57. # PolyMatrix(2, 2, [1, 2, 3, 4], x)
  58. rows, cols, items, gens = args[0], args[1], args[2], args[3:]
  59. elif [type(a) for a in args[:3]] == [int, int, type(lambda: 0)]:
  60. # PolyMatrix(2, 2, lambda i, j: i+j, x)
  61. rows, cols, func, gens = args[0], args[1], args[2], args[3:]
  62. items = [func(i, j) for i in range(rows) for j in range(cols)]
  63. else:
  64. raise TypeError("Invalid arguments")
  65. # PolyMatrix([[1]], x, y) vs PolyMatrix([[1]], (x, y))
  66. if len(gens) == 1 and isinstance(gens[0], tuple):
  67. gens = gens[0]
  68. # gens is now a tuple (x, y)
  69. return cls.from_list(rows, cols, items, gens, ring)
  70. @classmethod
  71. def from_list(cls, rows, cols, items, gens, ring):
  72. # items can be Expr, Poly, or a mix of Expr and Poly
  73. items = [_sympify(item) for item in items]
  74. if items and all(isinstance(item, Poly) for item in items):
  75. polys = True
  76. else:
  77. polys = False
  78. # Identify the ring for the polys
  79. if ring is not None:
  80. # Parse a domain string like 'QQ[x]'
  81. if isinstance(ring, str):
  82. ring = Poly(0, Dummy(), domain=ring).domain
  83. elif polys:
  84. p = items[0]
  85. for p2 in items[1:]:
  86. p, _ = p.unify(p2)
  87. ring = p.domain[p.gens]
  88. else:
  89. items, info = parallel_poly_from_expr(items, gens, field=True)
  90. ring = info['domain'][info['gens']]
  91. polys = True
  92. # Efficiently convert when all elements are Poly
  93. if polys:
  94. p_ring = Poly(0, ring.symbols, domain=ring.domain)
  95. to_ring = ring.ring.from_list
  96. convert_poly = lambda p: to_ring(p.unify(p_ring)[0].rep.rep)
  97. elements = [convert_poly(p) for p in items]
  98. else:
  99. convert_expr = ring.from_sympy
  100. elements = [convert_expr(e.as_expr()) for e in items]
  101. # Convert to domain elements and construct DomainMatrix
  102. elements_lol = [[elements[i*cols + j] for j in range(cols)] for i in range(rows)]
  103. dm = DomainMatrix(elements_lol, (rows, cols), ring)
  104. return cls.from_dm(dm)
  105. @classmethod
  106. def from_dm(cls, dm):
  107. obj = super().__new__(cls)
  108. dm = dm.to_sparse()
  109. R = dm.domain
  110. obj._dm = dm
  111. obj.ring = R
  112. obj.domain = R.domain
  113. obj.gens = R.symbols
  114. return obj
  115. def to_Matrix(self):
  116. return self._dm.to_Matrix()
  117. @classmethod
  118. def from_Matrix(cls, other, *gens, ring=None):
  119. return cls(*other.shape, other.flat(), *gens, ring=ring)
  120. def set_gens(self, gens):
  121. return self.from_Matrix(self.to_Matrix(), gens)
  122. def __repr__(self):
  123. if self.rows * self.cols:
  124. return 'Poly' + repr(self.to_Matrix())[:-1] + f', ring={self.ring})'
  125. else:
  126. return f'PolyMatrix({self.rows}, {self.cols}, [], ring={self.ring})'
  127. @property
  128. def shape(self):
  129. return self._dm.shape
  130. @property
  131. def rows(self):
  132. return self.shape[0]
  133. @property
  134. def cols(self):
  135. return self.shape[1]
  136. def __len__(self):
  137. return self.rows * self.cols
  138. def __getitem__(self, key):
  139. def to_poly(v):
  140. ground = self._dm.domain.domain
  141. gens = self._dm.domain.symbols
  142. return Poly(v.to_dict(), gens, domain=ground)
  143. dm = self._dm
  144. if isinstance(key, slice):
  145. items = dm.flat()[key]
  146. return [to_poly(item) for item in items]
  147. elif isinstance(key, int):
  148. i, j = divmod(key, self.cols)
  149. e = dm[i,j]
  150. return to_poly(e.element)
  151. i, j = key
  152. if isinstance(i, int) and isinstance(j, int):
  153. return to_poly(dm[i, j].element)
  154. else:
  155. return self.from_dm(dm[i, j])
  156. def __eq__(self, other):
  157. if not isinstance(self, type(other)):
  158. return NotImplemented
  159. return self._dm == other._dm
  160. def __add__(self, other):
  161. if isinstance(other, type(self)):
  162. return self.from_dm(self._dm + other._dm)
  163. return NotImplemented
  164. def __sub__(self, other):
  165. if isinstance(other, type(self)):
  166. return self.from_dm(self._dm - other._dm)
  167. return NotImplemented
  168. def __mul__(self, other):
  169. if isinstance(other, type(self)):
  170. return self.from_dm(self._dm * other._dm)
  171. elif isinstance(other, int):
  172. other = _sympify(other)
  173. if isinstance(other, Expr):
  174. Kx = self.ring
  175. try:
  176. other_ds = DomainScalar(Kx.from_sympy(other), Kx)
  177. except (CoercionFailed, ValueError):
  178. other_ds = DomainScalar.from_sympy(other)
  179. return self.from_dm(self._dm * other_ds)
  180. return NotImplemented
  181. def __rmul__(self, other):
  182. if isinstance(other, int):
  183. other = _sympify(other)
  184. if isinstance(other, Expr):
  185. other_ds = DomainScalar.from_sympy(other)
  186. return self.from_dm(other_ds * self._dm)
  187. return NotImplemented
  188. def __truediv__(self, other):
  189. if isinstance(other, Poly):
  190. other = other.as_expr()
  191. elif isinstance(other, int):
  192. other = _sympify(other)
  193. if not isinstance(other, Expr):
  194. return NotImplemented
  195. other = self.domain.from_sympy(other)
  196. inverse = self.ring.convert_from(1/other, self.domain)
  197. inverse = DomainScalar(inverse, self.ring)
  198. dm = self._dm * inverse
  199. return self.from_dm(dm)
  200. def __neg__(self):
  201. return self.from_dm(-self._dm)
  202. def transpose(self):
  203. return self.from_dm(self._dm.transpose())
  204. def row_join(self, other):
  205. dm = DomainMatrix.hstack(self._dm, other._dm)
  206. return self.from_dm(dm)
  207. def col_join(self, other):
  208. dm = DomainMatrix.vstack(self._dm, other._dm)
  209. return self.from_dm(dm)
  210. def applyfunc(self, func):
  211. M = self.to_Matrix().applyfunc(func)
  212. return self.from_Matrix(M, self.gens)
  213. @classmethod
  214. def eye(cls, n, gens):
  215. return cls.from_dm(DomainMatrix.eye(n, QQ[gens]))
  216. @classmethod
  217. def zeros(cls, m, n, gens):
  218. return cls.from_dm(DomainMatrix.zeros((m, n), QQ[gens]))
  219. def rref(self, simplify='ignore', normalize_last='ignore'):
  220. # If this is K[x] then computes RREF in ground field K.
  221. if not (self.domain.is_Field and all(p.is_ground for p in self)):
  222. raise ValueError("PolyMatrix rref is only for ground field elements")
  223. dm = self._dm
  224. dm_ground = dm.convert_to(dm.domain.domain)
  225. dm_rref, pivots = dm_ground.rref()
  226. dm_rref = dm_rref.convert_to(dm.domain)
  227. return self.from_dm(dm_rref), pivots
  228. def nullspace(self):
  229. # If this is K[x] then computes nullspace in ground field K.
  230. if not (self.domain.is_Field and all(p.is_ground for p in self)):
  231. raise ValueError("PolyMatrix nullspace is only for ground field elements")
  232. dm = self._dm
  233. K, Kx = self.domain, self.ring
  234. dm_null_rows = dm.convert_to(K).nullspace().convert_to(Kx)
  235. dm_null = dm_null_rows.transpose()
  236. dm_basis = [dm_null[:,i] for i in range(dm_null.shape[1])]
  237. return [self.from_dm(dmvec) for dmvec in dm_basis]
  238. def rank(self):
  239. return self.cols - len(self.nullspace())
  240. MutablePolyMatrix = PolyMatrix = MutablePolyDenseMatrix