overrides.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. """Tools for testing implementations of __array_function__ and ufunc overrides
  2. """
  3. from numpy.core.overrides import ARRAY_FUNCTIONS as _array_functions
  4. from numpy import ufunc as _ufunc
  5. import numpy.core.umath as _umath
  6. def get_overridable_numpy_ufuncs():
  7. """List all numpy ufuncs overridable via `__array_ufunc__`
  8. Parameters
  9. ----------
  10. None
  11. Returns
  12. -------
  13. set
  14. A set containing all overridable ufuncs in the public numpy API.
  15. """
  16. ufuncs = {obj for obj in _umath.__dict__.values()
  17. if isinstance(obj, _ufunc)}
  18. return ufuncs
  19. def allows_array_ufunc_override(func):
  20. """Determine if a function can be overridden via `__array_ufunc__`
  21. Parameters
  22. ----------
  23. func : callable
  24. Function that may be overridable via `__array_ufunc__`
  25. Returns
  26. -------
  27. bool
  28. `True` if `func` is overridable via `__array_ufunc__` and
  29. `False` otherwise.
  30. Notes
  31. -----
  32. This function is equivalent to ``isinstance(func, np.ufunc)`` and
  33. will work correctly for ufuncs defined outside of Numpy.
  34. """
  35. return isinstance(func, np.ufunc)
  36. def get_overridable_numpy_array_functions():
  37. """List all numpy functions overridable via `__array_function__`
  38. Parameters
  39. ----------
  40. None
  41. Returns
  42. -------
  43. set
  44. A set containing all functions in the public numpy API that are
  45. overridable via `__array_function__`.
  46. """
  47. # 'import numpy' doesn't import recfunctions, so make sure it's imported
  48. # so ufuncs defined there show up in the ufunc listing
  49. from numpy.lib import recfunctions
  50. return _array_functions.copy()
  51. def allows_array_function_override(func):
  52. """Determine if a Numpy function can be overridden via `__array_function__`
  53. Parameters
  54. ----------
  55. func : callable
  56. Function that may be overridable via `__array_function__`
  57. Returns
  58. -------
  59. bool
  60. `True` if `func` is a function in the Numpy API that is
  61. overridable via `__array_function__` and `False` otherwise.
  62. """
  63. return func in _array_functions