123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- from __future__ import annotations
- from contextlib import contextmanager
- import os
- from pathlib import Path
- import random
- from shutil import rmtree
- import string
- import tempfile
- from typing import (
- IO,
- Any,
- )
- import numpy as np
- from pandas.io.common import get_handle
@contextmanager
def decompress_file(path, compression):
    """
    Open a (possibly compressed) file for binary reading and yield its handle.

    Parameters
    ----------
    path : str
        The path where the file is read from.
    compression : {'gzip', 'bz2', 'zip', 'xz', None}
        Name of the decompression to use.

    Yields
    ------
    file object
        The ``handle`` attribute of the object returned by ``get_handle``.
    """
    opened = get_handle(path, "rb", compression=compression, is_text=False)
    # Leaving the ``with`` block lets ``get_handle``'s result clean up after
    # the caller is done with the yielded file object.
    with opened as io_handles:
        yield io_handles.handle
@contextmanager
def set_timezone(tz: str):
    """
    Context manager for temporarily setting a timezone.

    The ``TZ`` environment variable is set to ``tz`` for the duration of the
    block and restored (or removed, if it was not set before) on exit, even
    when the block raises.

    Parameters
    ----------
    tz : str
        A string representing a valid timezone.

    Examples
    --------
    >>> from datetime import datetime
    >>> from dateutil.tz import tzlocal
    >>> tzlocal().tzname(datetime.now())  # doctest: +SKIP
    'IST'

    >>> with set_timezone('US/Eastern'):  # doctest: +SKIP
    ...     tzlocal().tzname(datetime.now())
    ...
    'EDT'
    """
    import os
    import time

    def setTZ(tz):
        if tz is None:
            # TZ was not set originally: remove it instead of setting a value.
            os.environ.pop("TZ", None)
        else:
            os.environ["TZ"] = tz
        # Re-read TZ after *either* change so that removing the variable also
        # resets the C library's notion of local time. ``time.tzset`` is only
        # available on Unix, hence the guard.
        if hasattr(time, "tzset"):
            time.tzset()

    orig_tz = os.environ.get("TZ")
    setTZ(tz)
    try:
        yield
    finally:
        setTZ(orig_tz)
@contextmanager
def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any):
    """
    Gets a temporary path and agrees to remove on close.

    This implementation does not use tempfile.mkstemp to avoid having a file
    handle. If the code using the returned path wants to delete the file
    itself, windows requires that no program has a file handle to it.

    Parameters
    ----------
    filename : str (optional)
        suffix of the created file.
    return_filelike : bool (default False)
        if True, returns a file-like which is *always* cleaned. Necessary for
        savefig and other functions which want to append extensions.
    **kwargs
        Additional keywords are passed to open().

    Yields
    ------
    str or file object
        The path of the temporary file as a string, or an open file object
        (default mode ``"w+b"``) when ``return_filelike`` is True.
    """
    folder = Path(tempfile.gettempdir())

    if filename is None:
        filename = ""
    # A random 30-character stem keeps concurrently created files apart.
    filename = (
        "".join(random.choices(string.ascii_letters + string.digits, k=30)) + filename
    )
    path = folder / filename

    path.touch()

    handle_or_str: str | IO = str(path)

    try:
        # Opening happens inside the ``try`` so that a failing ``open`` (for
        # example an invalid ``mode``) still removes the file touched above.
        if return_filelike:
            kwargs.setdefault("mode", "w+b")
            handle_or_str = open(path, **kwargs)
        yield handle_or_str
    finally:
        if not isinstance(handle_or_str, str):
            handle_or_str.close()
        # The caller may already have deleted the file itself.
        if path.is_file():
            path.unlink()
@contextmanager
def ensure_clean_dir():
    """
    Create a temporary directory and remove it when the context exits.

    Yields
    ------
    str
        Path of the temporary directory.
    """
    tmp_dir = tempfile.mkdtemp(suffix="")
    try:
        yield tmp_dir
    finally:
        # Best-effort cleanup: a failure to delete the tree is ignored.
        try:
            rmtree(tmp_dir)
        except OSError:
            pass
@contextmanager
def ensure_safe_environment_variables():
    """
    Run a block whose environment-variable changes are rolled back on exit.

    A snapshot of ``os.environ`` is taken on entry; on exit the environment
    is cleared and repopulated from that snapshot, so variables set, changed
    or deleted inside the block do not persist.
    """
    snapshot = os.environ.copy()
    try:
        yield
    finally:
        # Clear first so variables added inside the block are dropped too.
        os.environ.clear()
        os.environ.update(snapshot)
@contextmanager
def with_csv_dialect(name, **kwargs):
    """
    Context manager to temporarily register a CSV dialect for parsing CSV.

    The dialect is unregistered on exit, including when the block raises.

    Parameters
    ----------
    name : str
        The name of the dialect.
    kwargs : mapping
        The parameters for the dialect.

    Raises
    ------
    ValueError : the name of the dialect conflicts with a builtin one.

    See Also
    --------
    csv : Python's CSV library.
    """
    import csv

    _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"}

    if name in _BUILTIN_DIALECTS:
        raise ValueError("Cannot override builtin dialect.")

    csv.register_dialect(name, **kwargs)
    try:
        yield
    finally:
        # Always unregister so a failing block does not leak the dialect.
        csv.unregister_dialect(name)
@contextmanager
def use_numexpr(use, min_elements=None):
    """
    Context manager to temporarily change the numexpr settings of
    ``pandas.core.computation.expressions``.

    The previous values of ``USE_NUMEXPR`` and ``_MIN_ELEMENTS`` are restored
    on exit, including when the block raises.

    Parameters
    ----------
    use : bool
        Value passed to ``expressions.set_use_numexpr``.
    min_elements : int, optional
        Temporary value for ``expressions._MIN_ELEMENTS``. When None, the
        current value is kept.
    """
    from pandas.core.computation import expressions as expr

    if min_elements is None:
        min_elements = expr._MIN_ELEMENTS

    olduse = expr.USE_NUMEXPR
    oldmin = expr._MIN_ELEMENTS
    expr.set_use_numexpr(use)
    expr._MIN_ELEMENTS = min_elements
    try:
        yield
    finally:
        # Restore even if the block raised, so the settings do not leak.
        expr._MIN_ELEMENTS = oldmin
        expr.set_use_numexpr(olduse)
class RNGContext:
    """
    Context manager to set the numpy random number generator seed. Restores
    the generator's original state upon exiting the context manager.

    Parameters
    ----------
    seed : int
        Seed for numpy.random.seed

    Examples
    --------
    with RNGContext(42):
        np.random.randn()
    """

    def __init__(self, seed) -> None:
        self.seed = seed

    def __enter__(self) -> None:
        # Capture the global state before reseeding so __exit__ can restore it.
        self.start_state = np.random.get_state()
        np.random.seed(self.seed)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        np.random.set_state(self.start_state)
|