contexts.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. from __future__ import annotations
  2. from contextlib import contextmanager
  3. import os
  4. from pathlib import Path
  5. import random
  6. from shutil import rmtree
  7. import string
  8. import tempfile
  9. from typing import (
  10. IO,
  11. Any,
  12. )
  13. import numpy as np
  14. from pandas.io.common import get_handle
  15. @contextmanager
  16. def decompress_file(path, compression):
  17. """
  18. Open a compressed file and return a file object.
  19. Parameters
  20. ----------
  21. path : str
  22. The path where the file is read from.
  23. compression : {'gzip', 'bz2', 'zip', 'xz', None}
  24. Name of the decompression to use
  25. Returns
  26. -------
  27. file object
  28. """
  29. with get_handle(path, "rb", compression=compression, is_text=False) as handle:
  30. yield handle.handle
  31. @contextmanager
  32. def set_timezone(tz: str):
  33. """
  34. Context manager for temporarily setting a timezone.
  35. Parameters
  36. ----------
  37. tz : str
  38. A string representing a valid timezone.
  39. Examples
  40. --------
  41. >>> from datetime import datetime
  42. >>> from dateutil.tz import tzlocal
  43. >>> tzlocal().tzname(datetime.now())
  44. 'IST'
  45. >>> with set_timezone('US/Eastern'):
  46. ... tzlocal().tzname(datetime.now())
  47. ...
  48. 'EDT'
  49. """
  50. import os
  51. import time
  52. def setTZ(tz):
  53. if tz is None:
  54. try:
  55. del os.environ["TZ"]
  56. except KeyError:
  57. pass
  58. else:
  59. os.environ["TZ"] = tz
  60. time.tzset()
  61. orig_tz = os.environ.get("TZ")
  62. setTZ(tz)
  63. try:
  64. yield
  65. finally:
  66. setTZ(orig_tz)
  67. @contextmanager
  68. def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any):
  69. """
  70. Gets a temporary path and agrees to remove on close.
  71. This implementation does not use tempfile.mkstemp to avoid having a file handle.
  72. If the code using the returned path wants to delete the file itself, windows
  73. requires that no program has a file handle to it.
  74. Parameters
  75. ----------
  76. filename : str (optional)
  77. suffix of the created file.
  78. return_filelike : bool (default False)
  79. if True, returns a file-like which is *always* cleaned. Necessary for
  80. savefig and other functions which want to append extensions.
  81. **kwargs
  82. Additional keywords are passed to open().
  83. """
  84. folder = Path(tempfile.gettempdir())
  85. if filename is None:
  86. filename = ""
  87. filename = (
  88. "".join(random.choices(string.ascii_letters + string.digits, k=30)) + filename
  89. )
  90. path = folder / filename
  91. path.touch()
  92. handle_or_str: str | IO = str(path)
  93. if return_filelike:
  94. kwargs.setdefault("mode", "w+b")
  95. handle_or_str = open(path, **kwargs)
  96. try:
  97. yield handle_or_str
  98. finally:
  99. if not isinstance(handle_or_str, str):
  100. handle_or_str.close()
  101. if path.is_file():
  102. path.unlink()
  103. @contextmanager
  104. def ensure_clean_dir():
  105. """
  106. Get a temporary directory path and agrees to remove on close.
  107. Yields
  108. ------
  109. Temporary directory path
  110. """
  111. directory_name = tempfile.mkdtemp(suffix="")
  112. try:
  113. yield directory_name
  114. finally:
  115. try:
  116. rmtree(directory_name)
  117. except OSError:
  118. pass
  119. @contextmanager
  120. def ensure_safe_environment_variables():
  121. """
  122. Get a context manager to safely set environment variables
  123. All changes will be undone on close, hence environment variables set
  124. within this contextmanager will neither persist nor change global state.
  125. """
  126. saved_environ = dict(os.environ)
  127. try:
  128. yield
  129. finally:
  130. os.environ.clear()
  131. os.environ.update(saved_environ)
  132. @contextmanager
  133. def with_csv_dialect(name, **kwargs):
  134. """
  135. Context manager to temporarily register a CSV dialect for parsing CSV.
  136. Parameters
  137. ----------
  138. name : str
  139. The name of the dialect.
  140. kwargs : mapping
  141. The parameters for the dialect.
  142. Raises
  143. ------
  144. ValueError : the name of the dialect conflicts with a builtin one.
  145. See Also
  146. --------
  147. csv : Python's CSV library.
  148. """
  149. import csv
  150. _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"}
  151. if name in _BUILTIN_DIALECTS:
  152. raise ValueError("Cannot override builtin dialect.")
  153. csv.register_dialect(name, **kwargs)
  154. yield
  155. csv.unregister_dialect(name)
  156. @contextmanager
  157. def use_numexpr(use, min_elements=None):
  158. from pandas.core.computation import expressions as expr
  159. if min_elements is None:
  160. min_elements = expr._MIN_ELEMENTS
  161. olduse = expr.USE_NUMEXPR
  162. oldmin = expr._MIN_ELEMENTS
  163. expr.set_use_numexpr(use)
  164. expr._MIN_ELEMENTS = min_elements
  165. yield
  166. expr._MIN_ELEMENTS = oldmin
  167. expr.set_use_numexpr(olduse)
  168. class RNGContext:
  169. """
  170. Context manager to set the numpy random number generator speed. Returns
  171. to the original value upon exiting the context manager.
  172. Parameters
  173. ----------
  174. seed : int
  175. Seed for numpy.random.seed
  176. Examples
  177. --------
  178. with RNGContext(42):
  179. np.random.randn()
  180. """
  181. def __init__(self, seed):
  182. self.seed = seed
  183. def __enter__(self):
  184. self.start_state = np.random.get_state()
  185. np.random.seed(self.seed)
  186. def __exit__(self, exc_type, exc_value, traceback):
  187. np.random.set_state(self.start_state)