permutation.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. from sympy.core import S
  2. from sympy.core.sympify import _sympify
  3. from sympy.functions import KroneckerDelta
  4. from .matexpr import MatrixExpr
  5. from .special import ZeroMatrix, Identity, OneMatrix
  6. class PermutationMatrix(MatrixExpr):
  7. """A Permutation Matrix
  8. Parameters
  9. ==========
  10. perm : Permutation
  11. The permutation the matrix uses.
  12. The size of the permutation determines the matrix size.
  13. See the documentation of
  14. :class:`sympy.combinatorics.permutations.Permutation` for
  15. the further information of how to create a permutation object.
  16. Examples
  17. ========
  18. >>> from sympy import Matrix, PermutationMatrix
  19. >>> from sympy.combinatorics import Permutation
  20. Creating a permutation matrix:
  21. >>> p = Permutation(1, 2, 0)
  22. >>> P = PermutationMatrix(p)
  23. >>> P = P.as_explicit()
  24. >>> P
  25. Matrix([
  26. [0, 1, 0],
  27. [0, 0, 1],
  28. [1, 0, 0]])
  29. Permuting a matrix row and column:
  30. >>> M = Matrix([0, 1, 2])
  31. >>> Matrix(P*M)
  32. Matrix([
  33. [1],
  34. [2],
  35. [0]])
  36. >>> Matrix(M.T*P)
  37. Matrix([[2, 0, 1]])
  38. See Also
  39. ========
  40. sympy.combinatorics.permutations.Permutation
  41. """
  42. def __new__(cls, perm):
  43. from sympy.combinatorics.permutations import Permutation
  44. perm = _sympify(perm)
  45. if not isinstance(perm, Permutation):
  46. raise ValueError(
  47. "{} must be a SymPy Permutation instance.".format(perm))
  48. return super().__new__(cls, perm)
  49. @property
  50. def shape(self):
  51. size = self.args[0].size
  52. return (size, size)
  53. @property
  54. def is_Identity(self):
  55. return self.args[0].is_Identity
  56. def doit(self):
  57. if self.is_Identity:
  58. return Identity(self.rows)
  59. return self
  60. def _entry(self, i, j, **kwargs):
  61. perm = self.args[0]
  62. return KroneckerDelta(perm.apply(i), j)
  63. def _eval_power(self, exp):
  64. return PermutationMatrix(self.args[0] ** exp).doit()
  65. def _eval_inverse(self):
  66. return PermutationMatrix(self.args[0] ** -1)
  67. _eval_transpose = _eval_adjoint = _eval_inverse
  68. def _eval_determinant(self):
  69. sign = self.args[0].signature()
  70. if sign == 1:
  71. return S.One
  72. elif sign == -1:
  73. return S.NegativeOne
  74. raise NotImplementedError
  75. def _eval_rewrite_as_BlockDiagMatrix(self, *args, **kwargs):
  76. from sympy.combinatorics.permutations import Permutation
  77. from .blockmatrix import BlockDiagMatrix
  78. perm = self.args[0]
  79. full_cyclic_form = perm.full_cyclic_form
  80. cycles_picks = []
  81. # Stage 1. Decompose the cycles into the blockable form.
  82. a, b, c = 0, 0, 0
  83. flag = False
  84. for cycle in full_cyclic_form:
  85. l = len(cycle)
  86. m = max(cycle)
  87. if not flag:
  88. if m + 1 > a + l:
  89. flag = True
  90. temp = [cycle]
  91. b = m
  92. c = l
  93. else:
  94. cycles_picks.append([cycle])
  95. a += l
  96. else:
  97. if m > b:
  98. if m + 1 == a + c + l:
  99. temp.append(cycle)
  100. cycles_picks.append(temp)
  101. flag = False
  102. a = m+1
  103. else:
  104. b = m
  105. temp.append(cycle)
  106. c += l
  107. else:
  108. if b + 1 == a + c + l:
  109. temp.append(cycle)
  110. cycles_picks.append(temp)
  111. flag = False
  112. a = b+1
  113. else:
  114. temp.append(cycle)
  115. c += l
  116. # Stage 2. Normalize each decomposed cycles and build matrix.
  117. p = 0
  118. args = []
  119. for pick in cycles_picks:
  120. new_cycles = []
  121. l = 0
  122. for cycle in pick:
  123. new_cycle = [i - p for i in cycle]
  124. new_cycles.append(new_cycle)
  125. l += len(cycle)
  126. p += l
  127. perm = Permutation(new_cycles)
  128. mat = PermutationMatrix(perm)
  129. args.append(mat)
  130. return BlockDiagMatrix(*args)
  131. class MatrixPermute(MatrixExpr):
  132. r"""Symbolic representation for permuting matrix rows or columns.
  133. Parameters
  134. ==========
  135. perm : Permutation, PermutationMatrix
  136. The permutation to use for permuting the matrix.
  137. The permutation can be resized to the suitable one,
  138. axis : 0 or 1
  139. The axis to permute alongside.
  140. If `0`, it will permute the matrix rows.
  141. If `1`, it will permute the matrix columns.
  142. Notes
  143. =====
  144. This follows the same notation used in
  145. :meth:`sympy.matrices.common.MatrixCommon.permute`.
  146. Examples
  147. ========
  148. >>> from sympy import Matrix, MatrixPermute
  149. >>> from sympy.combinatorics import Permutation
  150. Permuting the matrix rows:
  151. >>> p = Permutation(1, 2, 0)
  152. >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  153. >>> B = MatrixPermute(A, p, axis=0)
  154. >>> B.as_explicit()
  155. Matrix([
  156. [4, 5, 6],
  157. [7, 8, 9],
  158. [1, 2, 3]])
  159. Permuting the matrix columns:
  160. >>> B = MatrixPermute(A, p, axis=1)
  161. >>> B.as_explicit()
  162. Matrix([
  163. [2, 3, 1],
  164. [5, 6, 4],
  165. [8, 9, 7]])
  166. See Also
  167. ========
  168. sympy.matrices.common.MatrixCommon.permute
  169. """
  170. def __new__(cls, mat, perm, axis=S.Zero):
  171. from sympy.combinatorics.permutations import Permutation
  172. mat = _sympify(mat)
  173. if not mat.is_Matrix:
  174. raise ValueError(
  175. "{} must be a SymPy matrix instance.".format(perm))
  176. perm = _sympify(perm)
  177. if isinstance(perm, PermutationMatrix):
  178. perm = perm.args[0]
  179. if not isinstance(perm, Permutation):
  180. raise ValueError(
  181. "{} must be a SymPy Permutation or a PermutationMatrix " \
  182. "instance".format(perm))
  183. axis = _sympify(axis)
  184. if axis not in (0, 1):
  185. raise ValueError("The axis must be 0 or 1.")
  186. mat_size = mat.shape[axis]
  187. if mat_size != perm.size:
  188. try:
  189. perm = perm.resize(mat_size)
  190. except ValueError:
  191. raise ValueError(
  192. "Size does not match between the permutation {} "
  193. "and the matrix {} threaded over the axis {} "
  194. "and cannot be converted."
  195. .format(perm, mat, axis))
  196. return super().__new__(cls, mat, perm, axis)
  197. def doit(self, deep=True):
  198. mat, perm, axis = self.args
  199. if deep:
  200. mat = mat.doit(deep=deep)
  201. perm = perm.doit(deep=deep)
  202. if perm.is_Identity:
  203. return mat
  204. if mat.is_Identity:
  205. if axis is S.Zero:
  206. return PermutationMatrix(perm)
  207. elif axis is S.One:
  208. return PermutationMatrix(perm**-1)
  209. if isinstance(mat, (ZeroMatrix, OneMatrix)):
  210. return mat
  211. if isinstance(mat, MatrixPermute) and mat.args[2] == axis:
  212. return MatrixPermute(mat.args[0], perm * mat.args[1], axis)
  213. return self
  214. @property
  215. def shape(self):
  216. return self.args[0].shape
  217. def _entry(self, i, j, **kwargs):
  218. mat, perm, axis = self.args
  219. if axis == 0:
  220. return mat[perm.apply(i), j]
  221. elif axis == 1:
  222. return mat[i, perm.apply(j)]
  223. def _eval_rewrite_as_MatMul(self, *args, **kwargs):
  224. from .matmul import MatMul
  225. mat, perm, axis = self.args
  226. deep = kwargs.get("deep", True)
  227. if deep:
  228. mat = mat.rewrite(MatMul)
  229. if axis == 0:
  230. return MatMul(PermutationMatrix(perm), mat)
  231. elif axis == 1:
  232. return MatMul(mat, PermutationMatrix(perm**-1))