test_slicing.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. # This file is part of h5py, a Python interface to the HDF5 library.
  2. #
  3. # http://www.h5py.org
  4. #
  5. # Copyright 2008-2013 Andrew Collette and contributors
  6. #
  7. # License: Standard 3-clause BSD; see "license.txt" for full license terms
  8. # and contributor agreement.
  9. """
  10. Dataset slicing test module.
  11. Tests all supported slicing operations, including read/write and
  12. broadcasting operations. Does not test type conversion except for
  13. corner cases overlapping with slicing; for example, when selecting
  14. specific fields of a compound type.
  15. """
  16. import numpy as np
  17. from .common import ut, TestCase
  18. import h5py
  19. from h5py import h5s, h5t, h5d
  20. from h5py import File, MultiBlockSlice
  21. class BaseSlicing(TestCase):
  22. def setUp(self):
  23. self.f = File(self.mktemp(), 'w')
  24. def tearDown(self):
  25. if self.f:
  26. self.f.close()
  27. class TestSingleElement(BaseSlicing):
  28. """
  29. Feature: Retrieving a single element works with NumPy semantics
  30. """
  31. def test_single_index(self):
  32. """ Single-element selection with [index] yields array scalar """
  33. dset = self.f.create_dataset('x', (1,), dtype='i1')
  34. out = dset[0]
  35. self.assertIsInstance(out, np.int8)
  36. def test_single_null(self):
  37. """ Single-element selection with [()] yields ndarray """
  38. dset = self.f.create_dataset('x', (1,), dtype='i1')
  39. out = dset[()]
  40. self.assertIsInstance(out, np.ndarray)
  41. self.assertEqual(out.shape, (1,))
  42. def test_scalar_index(self):
  43. """ Slicing with [...] yields scalar ndarray """
  44. dset = self.f.create_dataset('x', shape=(), dtype='f')
  45. out = dset[...]
  46. self.assertIsInstance(out, np.ndarray)
  47. self.assertEqual(out.shape, ())
  48. def test_scalar_null(self):
  49. """ Slicing with [()] yields array scalar """
  50. dset = self.f.create_dataset('x', shape=(), dtype='i1')
  51. out = dset[()]
  52. self.assertIsInstance(out, np.int8)
  53. def test_compound(self):
  54. """ Compound scalar is numpy.void, not tuple (issue 135) """
  55. dt = np.dtype([('a','i4'),('b','f8')])
  56. v = np.ones((4,), dtype=dt)
  57. dset = self.f.create_dataset('foo', (4,), data=v)
  58. self.assertEqual(dset[0], v[0])
  59. self.assertIsInstance(dset[0], np.void)
  60. class TestObjectIndex(BaseSlicing):
  61. """
  62. Feature: numpy.object_ subtypes map to real Python objects
  63. """
  64. def test_reference(self):
  65. """ Indexing a reference dataset returns a h5py.Reference instance """
  66. dset = self.f.create_dataset('x', (1,), dtype=h5py.ref_dtype)
  67. dset[0] = self.f.ref
  68. self.assertEqual(type(dset[0]), h5py.Reference)
  69. def test_regref(self):
  70. """ Indexing a region reference dataset returns a h5py.RegionReference
  71. """
  72. dset1 = self.f.create_dataset('x', (10,10))
  73. regref = dset1.regionref[...]
  74. dset2 = self.f.create_dataset('y', (1,), dtype=h5py.regionref_dtype)
  75. dset2[0] = regref
  76. self.assertEqual(type(dset2[0]), h5py.RegionReference)
  77. def test_reference_field(self):
  78. """ Compound types of which a reference is an element work right """
  79. dt = np.dtype([('a', 'i'),('b', h5py.ref_dtype)])
  80. dset = self.f.create_dataset('x', (1,), dtype=dt)
  81. dset[0] = (42, self.f['/'].ref)
  82. out = dset[0]
  83. self.assertEqual(type(out[1]), h5py.Reference) # isinstance does NOT work
  84. def test_scalar(self):
  85. """ Indexing returns a real Python object on scalar datasets """
  86. dset = self.f.create_dataset('x', (), dtype=h5py.ref_dtype)
  87. dset[()] = self.f.ref
  88. self.assertEqual(type(dset[()]), h5py.Reference)
  89. def test_bytestr(self):
  90. """ Indexing a byte string dataset returns a real python byte string
  91. """
  92. dset = self.f.create_dataset('x', (1,), dtype=h5py.string_dtype(encoding='ascii'))
  93. dset[0] = b"Hello there!"
  94. self.assertEqual(type(dset[0]), bytes)
  95. class TestSimpleSlicing(TestCase):
  96. """
  97. Feature: Simple NumPy-style slices (start:stop:step) are supported.
  98. """
  99. def setUp(self):
  100. self.f = File(self.mktemp(), 'w')
  101. self.arr = np.arange(10)
  102. self.dset = self.f.create_dataset('x', data=self.arr)
  103. def tearDown(self):
  104. if self.f:
  105. self.f.close()
  106. def test_negative_stop(self):
  107. """ Negative stop indexes work as they do in NumPy """
  108. self.assertArrayEqual(self.dset[2:-2], self.arr[2:-2])
  109. def test_write(self):
  110. """Assigning to a 1D slice of a 2D dataset
  111. """
  112. dset = self.f.create_dataset('x2', (10, 2))
  113. x = np.zeros((10, 1))
  114. dset[:, 0] = x[:, 0]
  115. with self.assertRaises(TypeError):
  116. dset[:, 1] = x
  117. class TestArraySlicing(BaseSlicing):
  118. """
  119. Feature: Array types are handled appropriately
  120. """
  121. def test_read(self):
  122. """ Read arrays tack array dimensions onto end of shape tuple """
  123. dt = np.dtype('(3,)f8')
  124. dset = self.f.create_dataset('x',(10,),dtype=dt)
  125. self.assertEqual(dset.shape, (10,))
  126. self.assertEqual(dset.dtype, dt)
  127. # Full read
  128. out = dset[...]
  129. self.assertEqual(out.dtype, np.dtype('f8'))
  130. self.assertEqual(out.shape, (10,3))
  131. # Single element
  132. out = dset[0]
  133. self.assertEqual(out.dtype, np.dtype('f8'))
  134. self.assertEqual(out.shape, (3,))
  135. # Range
  136. out = dset[2:8:2]
  137. self.assertEqual(out.dtype, np.dtype('f8'))
  138. self.assertEqual(out.shape, (3,3))
  139. def test_write_broadcast(self):
  140. """ Array fill from constant is not supported (issue 211).
  141. """
  142. dt = np.dtype('(3,)i')
  143. dset = self.f.create_dataset('x', (10,), dtype=dt)
  144. with self.assertRaises(TypeError):
  145. dset[...] = 42
  146. def test_write_element(self):
  147. """ Write a single element to the array
  148. Issue 211.
  149. """
  150. dt = np.dtype('(3,)f8')
  151. dset = self.f.create_dataset('x', (10,), dtype=dt)
  152. data = np.array([1,2,3.0])
  153. dset[4] = data
  154. out = dset[4]
  155. self.assertTrue(np.all(out == data))
  156. def test_write_slices(self):
  157. """ Write slices to array type """
  158. dt = np.dtype('(3,)i')
  159. data1 = np.ones((2,), dtype=dt)
  160. data2 = np.ones((4,5), dtype=dt)
  161. dset = self.f.create_dataset('x', (10,9,11), dtype=dt)
  162. dset[0,0,2:4] = data1
  163. self.assertArrayEqual(dset[0,0,2:4], data1)
  164. dset[3, 1:5, 6:11] = data2
  165. self.assertArrayEqual(dset[3, 1:5, 6:11], data2)
  166. def test_roundtrip(self):
  167. """ Read the contents of an array and write them back
  168. Issue 211.
  169. """
  170. dt = np.dtype('(3,)f8')
  171. dset = self.f.create_dataset('x', (10,), dtype=dt)
  172. out = dset[...]
  173. dset[...] = out
  174. self.assertTrue(np.all(dset[...] == out))
  175. class TestZeroLengthSlicing(BaseSlicing):
  176. """
  177. Slices resulting in empty arrays
  178. """
  179. def test_slice_zero_length_dimension(self):
  180. """ Slice a dataset with a zero in its shape vector
  181. along the zero-length dimension """
  182. for i, shape in enumerate([(0,), (0, 3), (0, 2, 1)]):
  183. dset = self.f.create_dataset('x%d'%i, shape, dtype=int, maxshape=(None,)*len(shape))
  184. self.assertEqual(dset.shape, shape)
  185. out = dset[...]
  186. self.assertIsInstance(out, np.ndarray)
  187. self.assertEqual(out.shape, shape)
  188. out = dset[:]
  189. self.assertIsInstance(out, np.ndarray)
  190. self.assertEqual(out.shape, shape)
  191. if len(shape) > 1:
  192. out = dset[:, :1]
  193. self.assertIsInstance(out, np.ndarray)
  194. self.assertEqual(out.shape[:2], (0, 1))
  195. def test_slice_other_dimension(self):
  196. """ Slice a dataset with a zero in its shape vector
  197. along a non-zero-length dimension """
  198. for i, shape in enumerate([(3, 0), (1, 2, 0), (2, 0, 1)]):
  199. dset = self.f.create_dataset('x%d'%i, shape, dtype=int, maxshape=(None,)*len(shape))
  200. self.assertEqual(dset.shape, shape)
  201. out = dset[:1]
  202. self.assertIsInstance(out, np.ndarray)
  203. self.assertEqual(out.shape, (1,)+shape[1:])
  204. def test_slice_of_length_zero(self):
  205. """ Get a slice of length zero from a non-empty dataset """
  206. for i, shape in enumerate([(3,), (2, 2,), (2, 1, 5)]):
  207. dset = self.f.create_dataset('x%d'%i, data=np.zeros(shape, int), maxshape=(None,)*len(shape))
  208. self.assertEqual(dset.shape, shape)
  209. out = dset[1:1]
  210. self.assertIsInstance(out, np.ndarray)
  211. self.assertEqual(out.shape, (0,)+shape[1:])
  212. class TestFieldNames(BaseSlicing):
  213. """
  214. Field names for read & write
  215. """
  216. dt = np.dtype([('a', 'f'), ('b', 'i'), ('c', 'f4')])
  217. data = np.ones((100,), dtype=dt)
  218. def setUp(self):
  219. BaseSlicing.setUp(self)
  220. self.dset = self.f.create_dataset('x', (100,), dtype=self.dt)
  221. self.dset[...] = self.data
  222. def test_read(self):
  223. """ Test read with field selections """
  224. self.assertArrayEqual(self.dset['a'], self.data['a'])
  225. def test_unicode_names(self):
  226. """ Unicode field names for for read and write """
  227. self.assertArrayEqual(self.dset['a'], self.data['a'])
  228. self.dset['a'] = 42
  229. data = self.data.copy()
  230. data['a'] = 42
  231. self.assertArrayEqual(self.dset['a'], data['a'])
  232. def test_write(self):
  233. """ Test write with field selections """
  234. data2 = self.data.copy()
  235. data2['a'] *= 2
  236. self.dset['a'] = data2
  237. self.assertTrue(np.all(self.dset[...] == data2))
  238. data2['b'] *= 4
  239. self.dset['b'] = data2
  240. self.assertTrue(np.all(self.dset[...] == data2))
  241. data2['a'] *= 3
  242. data2['c'] *= 3
  243. self.dset['a','c'] = data2
  244. self.assertTrue(np.all(self.dset[...] == data2))
  245. def test_write_noncompound(self):
  246. """ Test write with non-compound source (single-field) """
  247. data2 = self.data.copy()
  248. data2['b'] = 1.0
  249. self.dset['b'] = 1.0
  250. self.assertTrue(np.all(self.dset[...] == data2))
  251. class TestMultiBlockSlice(BaseSlicing):
  252. def setUp(self):
  253. super(TestMultiBlockSlice, self).setUp()
  254. self.arr = np.arange(10)
  255. self.dset = self.f.create_dataset('x', data=self.arr)
  256. def test_default(self):
  257. # Default selects entire dataset as one block
  258. mbslice = MultiBlockSlice()
  259. self.assertEqual(mbslice.indices(10), (0, 1, 10, 1))
  260. np.testing.assert_array_equal(self.dset[mbslice], self.arr)
  261. def test_default_explicit(self):
  262. mbslice = MultiBlockSlice(start=0, count=10, stride=1, block=1)
  263. self.assertEqual(mbslice.indices(10), (0, 1, 10, 1))
  264. np.testing.assert_array_equal(self.dset[mbslice], self.arr)
  265. def test_start(self):
  266. mbslice = MultiBlockSlice(start=4)
  267. self.assertEqual(mbslice.indices(10), (4, 1, 6, 1))
  268. np.testing.assert_array_equal(self.dset[mbslice], np.array([4, 5, 6, 7, 8, 9]))
  269. def test_count(self):
  270. mbslice = MultiBlockSlice(count=7)
  271. self.assertEqual(mbslice.indices(10), (0, 1, 7, 1))
  272. np.testing.assert_array_equal(
  273. self.dset[mbslice], np.array([0, 1, 2, 3, 4, 5, 6])
  274. )
  275. def test_count_more_than_length_error(self):
  276. mbslice = MultiBlockSlice(count=11)
  277. with self.assertRaises(ValueError):
  278. mbslice.indices(10)
  279. def test_stride(self):
  280. mbslice = MultiBlockSlice(stride=2)
  281. self.assertEqual(mbslice.indices(10), (0, 2, 5, 1))
  282. np.testing.assert_array_equal(self.dset[mbslice], np.array([0, 2, 4, 6, 8]))
  283. def test_stride_zero_error(self):
  284. with self.assertRaises(ValueError):
  285. # This would cause a ZeroDivisionError if not caught
  286. MultiBlockSlice(stride=0, block=0).indices(10)
  287. def test_stride_block_equal(self):
  288. mbslice = MultiBlockSlice(stride=2, block=2)
  289. self.assertEqual(mbslice.indices(10), (0, 2, 5, 2))
  290. np.testing.assert_array_equal(self.dset[mbslice], self.arr)
  291. def test_block_more_than_stride_error(self):
  292. with self.assertRaises(ValueError):
  293. MultiBlockSlice(block=3)
  294. with self.assertRaises(ValueError):
  295. MultiBlockSlice(stride=2, block=3)
  296. def test_stride_more_than_block(self):
  297. mbslice = MultiBlockSlice(stride=3, block=2)
  298. self.assertEqual(mbslice.indices(10), (0, 3, 3, 2))
  299. np.testing.assert_array_equal(self.dset[mbslice], np.array([0, 1, 3, 4, 6, 7]))
  300. def test_block_overruns_extent_error(self):
  301. # If fully described then must fit within extent
  302. mbslice = MultiBlockSlice(start=2, count=2, stride=5, block=4)
  303. with self.assertRaises(ValueError):
  304. mbslice.indices(10)
  305. def test_fully_described(self):
  306. mbslice = MultiBlockSlice(start=1, count=2, stride=5, block=4)
  307. self.assertEqual(mbslice.indices(10), (1, 5, 2, 4))
  308. np.testing.assert_array_equal(
  309. self.dset[mbslice], np.array([1, 2, 3, 4, 6, 7, 8, 9])
  310. )
  311. def test_count_calculated(self):
  312. # If not given, count should be calculated to select as many full blocks as possible
  313. mbslice = MultiBlockSlice(start=1, stride=3, block=2)
  314. self.assertEqual(mbslice.indices(10), (1, 3, 3, 2))
  315. np.testing.assert_array_equal(self.dset[mbslice], np.array([1, 2, 4, 5, 7, 8]))
  316. def test_zero_count_calculated_error(self):
  317. # In this case, there is no possible count to select even one block, so error
  318. mbslice = MultiBlockSlice(start=8, stride=4, block=3)
  319. with self.assertRaises(ValueError):
  320. mbslice.indices(10)