arrayterator.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
"""
A buffered iterator for big arrays.

This module solves the problem of iterating over a big file-based array
without having to read it into memory. The `Arrayterator` class wraps
an array object, and when iterated it will return sub-arrays with at most
a user-specified number of elements (its ``buf_size`` argument). Indexing
an `Arrayterator` returns another `Arrayterator` over the selected part of
the wrapped object instead of reading data.

"""
import operator
from functools import reduce
from operator import mul
  10. __all__ = ['Arrayterator']
  11. class Arrayterator:
  12. """
  13. Buffered iterator for big arrays.
  14. `Arrayterator` creates a buffered iterator for reading big arrays in small
  15. contiguous blocks. The class is useful for objects stored in the
  16. file system. It allows iteration over the object *without* reading
  17. everything in memory; instead, small blocks are read and iterated over.
  18. `Arrayterator` can be used with any object that supports multidimensional
  19. slices. This includes NumPy arrays, but also variables from
  20. Scientific.IO.NetCDF or pynetcdf for example.
  21. Parameters
  22. ----------
  23. var : array_like
  24. The object to iterate over.
  25. buf_size : int, optional
  26. The buffer size. If `buf_size` is supplied, the maximum amount of
  27. data that will be read into memory is `buf_size` elements.
  28. Default is None, which will read as many element as possible
  29. into memory.
  30. Attributes
  31. ----------
  32. var
  33. buf_size
  34. start
  35. stop
  36. step
  37. shape
  38. flat
  39. See Also
  40. --------
  41. ndenumerate : Multidimensional array iterator.
  42. flatiter : Flat array iterator.
  43. memmap : Create a memory-map to an array stored in a binary file on disk.
  44. Notes
  45. -----
  46. The algorithm works by first finding a "running dimension", along which
  47. the blocks will be extracted. Given an array of dimensions
  48. ``(d1, d2, ..., dn)``, e.g. if `buf_size` is smaller than ``d1``, the
  49. first dimension will be used. If, on the other hand,
  50. ``d1 < buf_size < d1*d2`` the second dimension will be used, and so on.
  51. Blocks are extracted along this dimension, and when the last block is
  52. returned the process continues from the next dimension, until all
  53. elements have been read.
  54. Examples
  55. --------
  56. >>> a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
  57. >>> a_itor = np.lib.Arrayterator(a, 2)
  58. >>> a_itor.shape
  59. (3, 4, 5, 6)
  60. Now we can iterate over ``a_itor``, and it will return arrays of size
  61. two. Since `buf_size` was smaller than any dimension, the first
  62. dimension will be iterated over first:
  63. >>> for subarr in a_itor:
  64. ... if not subarr.all():
  65. ... print(subarr, subarr.shape) # doctest: +SKIP
  66. >>> # [[[[0 1]]]] (1, 1, 1, 2)
  67. """
  68. def __init__(self, var, buf_size=None):
  69. self.var = var
  70. self.buf_size = buf_size
  71. self.start = [0 for dim in var.shape]
  72. self.stop = [dim for dim in var.shape]
  73. self.step = [1 for dim in var.shape]
  74. def __getattr__(self, attr):
  75. return getattr(self.var, attr)
  76. def __getitem__(self, index):
  77. """
  78. Return a new arrayterator.
  79. """
  80. # Fix index, handling ellipsis and incomplete slices.
  81. if not isinstance(index, tuple):
  82. index = (index,)
  83. fixed = []
  84. length, dims = len(index), self.ndim
  85. for slice_ in index:
  86. if slice_ is Ellipsis:
  87. fixed.extend([slice(None)] * (dims-length+1))
  88. length = len(fixed)
  89. elif isinstance(slice_, int):
  90. fixed.append(slice(slice_, slice_+1, 1))
  91. else:
  92. fixed.append(slice_)
  93. index = tuple(fixed)
  94. if len(index) < dims:
  95. index += (slice(None),) * (dims-len(index))
  96. # Return a new arrayterator object.
  97. out = self.__class__(self.var, self.buf_size)
  98. for i, (start, stop, step, slice_) in enumerate(
  99. zip(self.start, self.stop, self.step, index)):
  100. out.start[i] = start + (slice_.start or 0)
  101. out.step[i] = step * (slice_.step or 1)
  102. out.stop[i] = start + (slice_.stop or stop-start)
  103. out.stop[i] = min(stop, out.stop[i])
  104. return out
  105. def __array__(self):
  106. """
  107. Return corresponding data.
  108. """
  109. slice_ = tuple(slice(*t) for t in zip(
  110. self.start, self.stop, self.step))
  111. return self.var[slice_]
  112. @property
  113. def flat(self):
  114. """
  115. A 1-D flat iterator for Arrayterator objects.
  116. This iterator returns elements of the array to be iterated over in
  117. `Arrayterator` one by one. It is similar to `flatiter`.
  118. See Also
  119. --------
  120. Arrayterator
  121. flatiter
  122. Examples
  123. --------
  124. >>> a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
  125. >>> a_itor = np.lib.Arrayterator(a, 2)
  126. >>> for subarr in a_itor.flat:
  127. ... if not subarr:
  128. ... print(subarr, type(subarr))
  129. ...
  130. 0 <class 'numpy.int64'>
  131. """
  132. for block in self:
  133. yield from block.flat
  134. @property
  135. def shape(self):
  136. """
  137. The shape of the array to be iterated over.
  138. For an example, see `Arrayterator`.
  139. """
  140. return tuple(((stop-start-1)//step+1) for start, stop, step in
  141. zip(self.start, self.stop, self.step))
  142. def __iter__(self):
  143. # Skip arrays with degenerate dimensions
  144. if [dim for dim in self.shape if dim <= 0]:
  145. return
  146. start = self.start[:]
  147. stop = self.stop[:]
  148. step = self.step[:]
  149. ndims = self.var.ndim
  150. while True:
  151. count = self.buf_size or reduce(mul, self.shape)
  152. # iterate over each dimension, looking for the
  153. # running dimension (ie, the dimension along which
  154. # the blocks will be built from)
  155. rundim = 0
  156. for i in range(ndims-1, -1, -1):
  157. # if count is zero we ran out of elements to read
  158. # along higher dimensions, so we read only a single position
  159. if count == 0:
  160. stop[i] = start[i]+1
  161. elif count <= self.shape[i]:
  162. # limit along this dimension
  163. stop[i] = start[i] + count*step[i]
  164. rundim = i
  165. else:
  166. # read everything along this dimension
  167. stop[i] = self.stop[i]
  168. stop[i] = min(self.stop[i], stop[i])
  169. count = count//self.shape[i]
  170. # yield a block
  171. slice_ = tuple(slice(*t) for t in zip(start, stop, step))
  172. yield self.var[slice_]
  173. # Update start position, taking care of overflow to
  174. # other dimensions
  175. start[rundim] = stop[rundim] # start where we stopped
  176. for i in range(ndims-1, 0, -1):
  177. if start[i] >= self.stop[i]:
  178. start[i] = self.start[i]
  179. start[i-1] += self.step[i-1]
  180. if start[0] >= self.stop[0]:
  181. return