reductions.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from types import FunctionType
  2. from sympy.simplify.simplify import (
  3. simplify as _simplify, dotprodsimp as _dotprodsimp)
  4. from .utilities import _get_intermediate_simp, _iszero
  5. from .determinant import _find_reasonable_pivot
  6. def _row_reduce_list(mat, rows, cols, one, iszerofunc, simpfunc,
  7. normalize_last=True, normalize=True, zero_above=True):
  8. """Row reduce a flat list representation of a matrix and return a tuple
  9. (rref_matrix, pivot_cols, swaps) where ``rref_matrix`` is a flat list,
  10. ``pivot_cols`` are the pivot columns and ``swaps`` are any row swaps that
  11. were used in the process of row reduction.
  12. Parameters
  13. ==========
  14. mat : list
  15. list of matrix elements, must be ``rows`` * ``cols`` in length
  16. rows, cols : integer
  17. number of rows and columns in flat list representation
  18. one : SymPy object
  19. represents the value one, from ``Matrix.one``
  20. iszerofunc : determines if an entry can be used as a pivot
  21. simpfunc : used to simplify elements and test if they are
  22. zero if ``iszerofunc`` returns `None`
  23. normalize_last : indicates where all row reduction should
  24. happen in a fraction-free manner and then the rows are
  25. normalized (so that the pivots are 1), or whether
  26. rows should be normalized along the way (like the naive
  27. row reduction algorithm)
  28. normalize : whether pivot rows should be normalized so that
  29. the pivot value is 1
  30. zero_above : whether entries above the pivot should be zeroed.
  31. If ``zero_above=False``, an echelon matrix will be returned.
  32. """
  33. def get_col(i):
  34. return mat[i::cols]
  35. def row_swap(i, j):
  36. mat[i*cols:(i + 1)*cols], mat[j*cols:(j + 1)*cols] = \
  37. mat[j*cols:(j + 1)*cols], mat[i*cols:(i + 1)*cols]
  38. def cross_cancel(a, i, b, j):
  39. """Does the row op row[i] = a*row[i] - b*row[j]"""
  40. q = (j - i)*cols
  41. for p in range(i*cols, (i + 1)*cols):
  42. mat[p] = isimp(a*mat[p] - b*mat[p + q])
  43. isimp = _get_intermediate_simp(_dotprodsimp)
  44. piv_row, piv_col = 0, 0
  45. pivot_cols = []
  46. swaps = []
  47. # use a fraction free method to zero above and below each pivot
  48. while piv_col < cols and piv_row < rows:
  49. pivot_offset, pivot_val, \
  50. assumed_nonzero, newly_determined = _find_reasonable_pivot(
  51. get_col(piv_col)[piv_row:], iszerofunc, simpfunc)
  52. # _find_reasonable_pivot may have simplified some things
  53. # in the process. Let's not let them go to waste
  54. for (offset, val) in newly_determined:
  55. offset += piv_row
  56. mat[offset*cols + piv_col] = val
  57. if pivot_offset is None:
  58. piv_col += 1
  59. continue
  60. pivot_cols.append(piv_col)
  61. if pivot_offset != 0:
  62. row_swap(piv_row, pivot_offset + piv_row)
  63. swaps.append((piv_row, pivot_offset + piv_row))
  64. # if we aren't normalizing last, we normalize
  65. # before we zero the other rows
  66. if normalize_last is False:
  67. i, j = piv_row, piv_col
  68. mat[i*cols + j] = one
  69. for p in range(i*cols + j + 1, (i + 1)*cols):
  70. mat[p] = isimp(mat[p] / pivot_val)
  71. # after normalizing, the pivot value is 1
  72. pivot_val = one
  73. # zero above and below the pivot
  74. for row in range(rows):
  75. # don't zero our current row
  76. if row == piv_row:
  77. continue
  78. # don't zero above the pivot unless we're told.
  79. if zero_above is False and row < piv_row:
  80. continue
  81. # if we're already a zero, don't do anything
  82. val = mat[row*cols + piv_col]
  83. if iszerofunc(val):
  84. continue
  85. cross_cancel(pivot_val, row, val, piv_row)
  86. piv_row += 1
  87. # normalize each row
  88. if normalize_last is True and normalize is True:
  89. for piv_i, piv_j in enumerate(pivot_cols):
  90. pivot_val = mat[piv_i*cols + piv_j]
  91. mat[piv_i*cols + piv_j] = one
  92. for p in range(piv_i*cols + piv_j + 1, (piv_i + 1)*cols):
  93. mat[p] = isimp(mat[p] / pivot_val)
  94. return mat, tuple(pivot_cols), tuple(swaps)
  95. # This functions is a candidate for caching if it gets implemented for matrices.
  96. def _row_reduce(M, iszerofunc, simpfunc, normalize_last=True,
  97. normalize=True, zero_above=True):
  98. mat, pivot_cols, swaps = _row_reduce_list(list(M), M.rows, M.cols, M.one,
  99. iszerofunc, simpfunc, normalize_last=normalize_last,
  100. normalize=normalize, zero_above=zero_above)
  101. return M._new(M.rows, M.cols, mat), pivot_cols, swaps
  102. def _is_echelon(M, iszerofunc=_iszero):
  103. """Returns `True` if the matrix is in echelon form. That is, all rows of
  104. zeros are at the bottom, and below each leading non-zero in a row are
  105. exclusively zeros."""
  106. if M.rows <= 0 or M.cols <= 0:
  107. return True
  108. zeros_below = all(iszerofunc(t) for t in M[1:, 0])
  109. if iszerofunc(M[0, 0]):
  110. return zeros_below and _is_echelon(M[:, 1:], iszerofunc)
  111. return zeros_below and _is_echelon(M[1:, 1:], iszerofunc)
  112. def _echelon_form(M, iszerofunc=_iszero, simplify=False, with_pivots=False):
  113. """Returns a matrix row-equivalent to ``M`` that is in echelon form. Note
  114. that echelon form of a matrix is *not* unique, however, properties like the
  115. row space and the null space are preserved.
  116. Examples
  117. ========
  118. >>> from sympy import Matrix
  119. >>> M = Matrix([[1, 2], [3, 4]])
  120. >>> M.echelon_form()
  121. Matrix([
  122. [1, 2],
  123. [0, -2]])
  124. """
  125. simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify
  126. mat, pivots, _ = _row_reduce(M, iszerofunc, simpfunc,
  127. normalize_last=True, normalize=False, zero_above=False)
  128. if with_pivots:
  129. return mat, pivots
  130. return mat
  131. # This functions is a candidate for caching if it gets implemented for matrices.
  132. def _rank(M, iszerofunc=_iszero, simplify=False):
  133. """Returns the rank of a matrix.
  134. Examples
  135. ========
  136. >>> from sympy import Matrix
  137. >>> from sympy.abc import x
  138. >>> m = Matrix([[1, 2], [x, 1 - 1/x]])
  139. >>> m.rank()
  140. 2
  141. >>> n = Matrix(3, 3, range(1, 10))
  142. >>> n.rank()
  143. 2
  144. """
  145. def _permute_complexity_right(M, iszerofunc):
  146. """Permute columns with complicated elements as
  147. far right as they can go. Since the ``sympy`` row reduction
  148. algorithms start on the left, having complexity right-shifted
  149. speeds things up.
  150. Returns a tuple (mat, perm) where perm is a permutation
  151. of the columns to perform to shift the complex columns right, and mat
  152. is the permuted matrix."""
  153. def complexity(i):
  154. # the complexity of a column will be judged by how many
  155. # element's zero-ness cannot be determined
  156. return sum(1 if iszerofunc(e) is None else 0 for e in M[:, i])
  157. complex = [(complexity(i), i) for i in range(M.cols)]
  158. perm = [j for (i, j) in sorted(complex)]
  159. return (M.permute(perm, orientation='cols'), perm)
  160. simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify
  161. # for small matrices, we compute the rank explicitly
  162. # if is_zero on elements doesn't answer the question
  163. # for small matrices, we fall back to the full routine.
  164. if M.rows <= 0 or M.cols <= 0:
  165. return 0
  166. if M.rows <= 1 or M.cols <= 1:
  167. zeros = [iszerofunc(x) for x in M]
  168. if False in zeros:
  169. return 1
  170. if M.rows == 2 and M.cols == 2:
  171. zeros = [iszerofunc(x) for x in M]
  172. if False not in zeros and None not in zeros:
  173. return 0
  174. d = M.det()
  175. if iszerofunc(d) and False in zeros:
  176. return 1
  177. if iszerofunc(d) is False:
  178. return 2
  179. mat, _ = _permute_complexity_right(M, iszerofunc=iszerofunc)
  180. _, pivots, _ = _row_reduce(mat, iszerofunc, simpfunc, normalize_last=True,
  181. normalize=False, zero_above=False)
  182. return len(pivots)
  183. def _rref(M, iszerofunc=_iszero, simplify=False, pivots=True,
  184. normalize_last=True):
  185. """Return reduced row-echelon form of matrix and indices of pivot vars.
  186. Parameters
  187. ==========
  188. iszerofunc : Function
  189. A function used for detecting whether an element can
  190. act as a pivot. ``lambda x: x.is_zero`` is used by default.
  191. simplify : Function
  192. A function used to simplify elements when looking for a pivot.
  193. By default SymPy's ``simplify`` is used.
  194. pivots : True or False
  195. If ``True``, a tuple containing the row-reduced matrix and a tuple
  196. of pivot columns is returned. If ``False`` just the row-reduced
  197. matrix is returned.
  198. normalize_last : True or False
  199. If ``True``, no pivots are normalized to `1` until after all
  200. entries above and below each pivot are zeroed. This means the row
  201. reduction algorithm is fraction free until the very last step.
  202. If ``False``, the naive row reduction procedure is used where
  203. each pivot is normalized to be `1` before row operations are
  204. used to zero above and below the pivot.
  205. Examples
  206. ========
  207. >>> from sympy import Matrix
  208. >>> from sympy.abc import x
  209. >>> m = Matrix([[1, 2], [x, 1 - 1/x]])
  210. >>> m.rref()
  211. (Matrix([
  212. [1, 0],
  213. [0, 1]]), (0, 1))
  214. >>> rref_matrix, rref_pivots = m.rref()
  215. >>> rref_matrix
  216. Matrix([
  217. [1, 0],
  218. [0, 1]])
  219. >>> rref_pivots
  220. (0, 1)
  221. Notes
  222. =====
  223. The default value of ``normalize_last=True`` can provide significant
  224. speedup to row reduction, especially on matrices with symbols. However,
  225. if you depend on the form row reduction algorithm leaves entries
  226. of the matrix, set ``noramlize_last=False``
  227. """
  228. simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify
  229. mat, pivot_cols, _ = _row_reduce(M, iszerofunc, simpfunc,
  230. normalize_last, normalize=True, zero_above=True)
  231. if pivots:
  232. mat = (mat, pivot_cols)
  233. return mat