common.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # This file is part of h5py, a Python interface to the HDF5 library.
  2. #
  3. # http://www.h5py.org
  4. #
  5. # Copyright 2008-2013 Andrew Collette and contributors
  6. #
  7. # License: Standard 3-clause BSD; see "license.txt" for full license terms
  8. # and contributor agreement.
  9. import sys
  10. import os
  11. import shutil
  12. import inspect
  13. import tempfile
  14. import subprocess
  15. from contextlib import contextmanager
  16. from functools import wraps
  17. import numpy as np
  18. import h5py
  19. import unittest as ut
  20. # Check if non-ascii filenames are supported
  21. # Evidently this is the most reliable way to check
  22. # See also h5py issue #263 and ipython #466
  23. # To test for this, run the testsuite with LC_ALL=C
  24. try:
  25. testfile, fname = tempfile.mkstemp(chr(0x03b7))
  26. except UnicodeError:
  27. UNICODE_FILENAMES = False
  28. else:
  29. UNICODE_FILENAMES = True
  30. os.close(testfile)
  31. os.unlink(fname)
  32. del fname
  33. del testfile
  34. class TestCase(ut.TestCase):
  35. """
  36. Base class for unit tests.
  37. """
  38. @classmethod
  39. def setUpClass(cls):
  40. cls.tempdir = tempfile.mkdtemp(prefix='h5py-test_')
  41. @classmethod
  42. def tearDownClass(cls):
  43. shutil.rmtree(cls.tempdir)
  44. def mktemp(self, suffix='.hdf5', prefix='', dir=None):
  45. if dir is None:
  46. dir = self.tempdir
  47. return tempfile.mktemp(suffix, prefix, dir=dir)
  48. def mktemp_mpi(self, comm=None, suffix='.hdf5', prefix='', dir=None):
  49. if comm is None:
  50. from mpi4py import MPI
  51. comm = MPI.COMM_WORLD
  52. fname = None
  53. if comm.Get_rank() == 0:
  54. fname = self.mktemp(suffix, prefix, dir)
  55. fname = comm.bcast(fname, 0)
  56. return fname
  57. def setUp(self):
  58. self.f = h5py.File(self.mktemp(), 'w')
  59. def tearDown(self):
  60. try:
  61. if self.f:
  62. self.f.close()
  63. except:
  64. pass
  65. def assertSameElements(self, a, b):
  66. for x in a:
  67. match = False
  68. for y in b:
  69. if x == y:
  70. match = True
  71. if not match:
  72. raise AssertionError("Item '%s' appears in a but not b" % x)
  73. for x in b:
  74. match = False
  75. for y in a:
  76. if x == y:
  77. match = True
  78. if not match:
  79. raise AssertionError("Item '%s' appears in b but not a" % x)
  80. def assertArrayEqual(self, dset, arr, message=None, precision=None):
  81. """ Make sure dset and arr have the same shape, dtype and contents, to
  82. within the given precision.
  83. Note that dset may be a NumPy array or an HDF5 dataset.
  84. """
  85. if precision is None:
  86. precision = 1e-5
  87. if message is None:
  88. message = ''
  89. else:
  90. message = ' (%s)' % message
  91. if np.isscalar(dset) or np.isscalar(arr):
  92. assert np.isscalar(dset) and np.isscalar(arr), \
  93. 'Scalar/array mismatch ("%r" vs "%r")%s' % (dset, arr, message)
  94. assert dset - arr < precision, \
  95. "Scalars differ by more than %.3f%s" % (precision, message)
  96. return
  97. assert dset.shape == arr.shape, \
  98. "Shape mismatch (%s vs %s)%s" % (dset.shape, arr.shape, message)
  99. assert dset.dtype == arr.dtype, \
  100. "Dtype mismatch (%s vs %s)%s" % (dset.dtype, arr.dtype, message)
  101. if arr.dtype.names is not None:
  102. for n in arr.dtype.names:
  103. message = '[FIELD %s] %s' % (n, message)
  104. self.assertArrayEqual(dset[n], arr[n], message=message, precision=precision)
  105. elif arr.dtype.kind in ('i', 'f'):
  106. assert np.all(np.abs(dset[...] - arr[...]) < precision), \
  107. "Arrays differ by more than %.3f%s" % (precision, message)
  108. else:
  109. assert np.all(dset[...] == arr[...]), \
  110. "Arrays are not equal (dtype %s) %s" % (arr.dtype.str, message)
  111. def assertNumpyBehavior(self, dset, arr, s):
  112. """ Apply slicing arguments "s" to both dset and arr.
  113. Succeeds if the results of the slicing are identical, or the
  114. exception raised is of the same type for both.
  115. "arr" must be a Numpy array; "dset" may be a NumPy array or dataset.
  116. """
  117. exc = None
  118. try:
  119. arr_result = arr[s]
  120. except Exception as e:
  121. exc = type(e)
  122. if exc is None:
  123. self.assertArrayEqual(dset[s], arr_result)
  124. else:
  125. with self.assertRaises(exc):
  126. dset[s]
  127. NUMPY_RELEASE_VERSION = tuple([int(i) for i in np.__version__.split(".")[0:2]])
  128. @contextmanager
  129. def closed_tempfile(suffix='', text=None):
  130. """
  131. Context manager which yields the path to a closed temporary file with the
  132. suffix `suffix`. The file will be deleted on exiting the context. An
  133. additional argument `text` can be provided to have the file contain `text`.
  134. """
  135. with tempfile.NamedTemporaryFile(
  136. 'w+t', suffix=suffix, delete=False
  137. ) as test_file:
  138. file_name = test_file.name
  139. if text is not None:
  140. test_file.write(text)
  141. test_file.flush()
  142. yield file_name
  143. shutil.rmtree(file_name, ignore_errors=True)
  144. def insubprocess(f):
  145. """Runs a test in its own subprocess"""
  146. @wraps(f)
  147. def wrapper(request, *args, **kwargs):
  148. curr_test = inspect.getsourcefile(f) + "::" + request.node.name
  149. # get block around test name
  150. insub = "IN_SUBPROCESS_" + curr_test
  151. for c in "/\\,:.":
  152. insub = insub.replace(c, "_")
  153. defined = os.environ.get(insub, None)
  154. # fork process
  155. if defined:
  156. return f(request, *args, **kwargs)
  157. else:
  158. os.environ[insub] = '1'
  159. env = os.environ.copy()
  160. env[insub] = '1'
  161. env.update(getattr(f, 'subproc_env', {}))
  162. with closed_tempfile() as stdout:
  163. with open(stdout, 'w+t') as fh:
  164. rtn = subprocess.call([sys.executable, '-m', 'pytest', curr_test],
  165. stdout=fh, stderr=fh, env=env)
  166. with open(stdout, 'rt') as fh:
  167. out = fh.read()
  168. assert rtn == 0, "\n" + out
  169. return wrapper
  170. def subproc_env(d):
  171. """Set environment variables for the @insubprocess decorator"""
  172. def decorator(f):
  173. f.subproc_env = d
  174. return f
  175. return decorator