slice.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from sympy.matrices.expressions.matexpr import MatrixExpr
  2. from sympy.core.basic import Basic
  3. from sympy.core.containers import Tuple
  4. from sympy.functions.elementary.integers import floor
  5. def normalize(i, parentsize):
  6. if isinstance(i, slice):
  7. i = (i.start, i.stop, i.step)
  8. if not isinstance(i, (tuple, list, Tuple)):
  9. if (i < 0) == True:
  10. i += parentsize
  11. i = (i, i+1, 1)
  12. i = list(i)
  13. if len(i) == 2:
  14. i.append(1)
  15. start, stop, step = i
  16. start = start or 0
  17. if stop is None:
  18. stop = parentsize
  19. if (start < 0) == True:
  20. start += parentsize
  21. if (stop < 0) == True:
  22. stop += parentsize
  23. step = step or 1
  24. if ((stop - start) * step < 1) == True:
  25. raise IndexError()
  26. return (start, stop, step)
  27. class MatrixSlice(MatrixExpr):
  28. """ A MatrixSlice of a Matrix Expression
  29. Examples
  30. ========
  31. >>> from sympy import MatrixSlice, ImmutableMatrix
  32. >>> M = ImmutableMatrix(4, 4, range(16))
  33. >>> M
  34. Matrix([
  35. [ 0, 1, 2, 3],
  36. [ 4, 5, 6, 7],
  37. [ 8, 9, 10, 11],
  38. [12, 13, 14, 15]])
  39. >>> B = MatrixSlice(M, (0, 2), (2, 4))
  40. >>> ImmutableMatrix(B)
  41. Matrix([
  42. [2, 3],
  43. [6, 7]])
  44. """
  45. parent = property(lambda self: self.args[0])
  46. rowslice = property(lambda self: self.args[1])
  47. colslice = property(lambda self: self.args[2])
  48. def __new__(cls, parent, rowslice, colslice):
  49. rowslice = normalize(rowslice, parent.shape[0])
  50. colslice = normalize(colslice, parent.shape[1])
  51. if not (len(rowslice) == len(colslice) == 3):
  52. raise IndexError()
  53. if ((0 > rowslice[0]) == True or
  54. (parent.shape[0] < rowslice[1]) == True or
  55. (0 > colslice[0]) == True or
  56. (parent.shape[1] < colslice[1]) == True):
  57. raise IndexError()
  58. if isinstance(parent, MatrixSlice):
  59. return mat_slice_of_slice(parent, rowslice, colslice)
  60. return Basic.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice))
  61. @property
  62. def shape(self):
  63. rows = self.rowslice[1] - self.rowslice[0]
  64. rows = rows if self.rowslice[2] == 1 else floor(rows/self.rowslice[2])
  65. cols = self.colslice[1] - self.colslice[0]
  66. cols = cols if self.colslice[2] == 1 else floor(cols/self.colslice[2])
  67. return rows, cols
  68. def _entry(self, i, j, **kwargs):
  69. return self.parent._entry(i*self.rowslice[2] + self.rowslice[0],
  70. j*self.colslice[2] + self.colslice[0],
  71. **kwargs)
  72. @property
  73. def on_diag(self):
  74. return self.rowslice == self.colslice
  75. def slice_of_slice(s, t):
  76. start1, stop1, step1 = s
  77. start2, stop2, step2 = t
  78. start = start1 + start2*step1
  79. step = step1 * step2
  80. stop = start1 + step1*stop2
  81. if stop > stop1:
  82. raise IndexError()
  83. return start, stop, step
  84. def mat_slice_of_slice(parent, rowslice, colslice):
  85. """ Collapse nested matrix slices
  86. >>> from sympy import MatrixSymbol
  87. >>> X = MatrixSymbol('X', 10, 10)
  88. >>> X[:, 1:5][5:8, :]
  89. X[5:8, 1:5]
  90. >>> X[1:9:2, 2:6][1:3, 2]
  91. X[3:7:2, 4:5]
  92. """
  93. row = slice_of_slice(parent.rowslice, rowslice)
  94. col = slice_of_slice(parent.colslice, colslice)
  95. return MatrixSlice(parent.parent, row, col)