function.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. """
  2. For compatibility with numpy libraries, pandas functions or methods have to
  3. accept '*args' and '**kwargs' parameters to accommodate numpy arguments that
  4. are not actually used or respected in the pandas implementation.
  5. To ensure that users do not abuse these parameters, validation is performed in
  6. 'validators.py' to make sure that any extra parameters passed correspond ONLY
  7. to those in the numpy signature. Part of that validation includes whether or
  8. not the user attempted to pass in non-default values for these extraneous
  9. parameters. As we want to discourage users from relying on these parameters
  10. when calling the pandas implementation, we want them only to pass in the
  11. default values for these parameters.
  12. This module provides a set of commonly used default arguments for functions and
  13. methods that are spread throughout the codebase. This module will make it
  14. easier to adjust to future upstream changes in the analogous numpy signatures.
  15. """
  16. from __future__ import annotations
  17. from typing import Any
  18. from numpy import ndarray
  19. from pandas._libs.lib import (
  20. is_bool,
  21. is_integer,
  22. )
  23. from pandas.errors import UnsupportedFunctionCall
  24. from pandas.util._validators import (
  25. validate_args,
  26. validate_args_and_kwargs,
  27. validate_kwargs,
  28. )
  29. class CompatValidator:
  30. def __init__(
  31. self,
  32. defaults,
  33. fname=None,
  34. method: str | None = None,
  35. max_fname_arg_count=None,
  36. ):
  37. self.fname = fname
  38. self.method = method
  39. self.defaults = defaults
  40. self.max_fname_arg_count = max_fname_arg_count
  41. def __call__(
  42. self,
  43. args,
  44. kwargs,
  45. fname=None,
  46. max_fname_arg_count=None,
  47. method: str | None = None,
  48. ) -> None:
  49. if args or kwargs:
  50. fname = self.fname if fname is None else fname
  51. max_fname_arg_count = (
  52. self.max_fname_arg_count
  53. if max_fname_arg_count is None
  54. else max_fname_arg_count
  55. )
  56. method = self.method if method is None else method
  57. if method == "args":
  58. validate_args(fname, args, max_fname_arg_count, self.defaults)
  59. elif method == "kwargs":
  60. validate_kwargs(fname, kwargs, self.defaults)
  61. elif method == "both":
  62. validate_args_and_kwargs(
  63. fname, args, kwargs, max_fname_arg_count, self.defaults
  64. )
  65. else:
  66. raise ValueError(f"invalid validation method '{method}'")
  67. ARGMINMAX_DEFAULTS = {"out": None}
  68. validate_argmin = CompatValidator(
  69. ARGMINMAX_DEFAULTS, fname="argmin", method="both", max_fname_arg_count=1
  70. )
  71. validate_argmax = CompatValidator(
  72. ARGMINMAX_DEFAULTS, fname="argmax", method="both", max_fname_arg_count=1
  73. )
  74. def process_skipna(skipna, args):
  75. if isinstance(skipna, ndarray) or skipna is None:
  76. args = (skipna,) + args
  77. skipna = True
  78. return skipna, args
  79. def validate_argmin_with_skipna(skipna, args, kwargs):
  80. """
  81. If 'Series.argmin' is called via the 'numpy' library, the third parameter
  82. in its signature is 'out', which takes either an ndarray or 'None', so
  83. check if the 'skipna' parameter is either an instance of ndarray or is
  84. None, since 'skipna' itself should be a boolean
  85. """
  86. skipna, args = process_skipna(skipna, args)
  87. validate_argmin(args, kwargs)
  88. return skipna
  89. def validate_argmax_with_skipna(skipna, args, kwargs):
  90. """
  91. If 'Series.argmax' is called via the 'numpy' library, the third parameter
  92. in its signature is 'out', which takes either an ndarray or 'None', so
  93. check if the 'skipna' parameter is either an instance of ndarray or is
  94. None, since 'skipna' itself should be a boolean
  95. """
  96. skipna, args = process_skipna(skipna, args)
  97. validate_argmax(args, kwargs)
  98. return skipna
  99. ARGSORT_DEFAULTS: dict[str, int | str | None] = {}
  100. ARGSORT_DEFAULTS["axis"] = -1
  101. ARGSORT_DEFAULTS["kind"] = "quicksort"
  102. ARGSORT_DEFAULTS["order"] = None
  103. ARGSORT_DEFAULTS["kind"] = None
  104. validate_argsort = CompatValidator(
  105. ARGSORT_DEFAULTS, fname="argsort", max_fname_arg_count=0, method="both"
  106. )
  107. # two different signatures of argsort, this second validation for when the
  108. # `kind` param is supported
  109. ARGSORT_DEFAULTS_KIND: dict[str, int | None] = {}
  110. ARGSORT_DEFAULTS_KIND["axis"] = -1
  111. ARGSORT_DEFAULTS_KIND["order"] = None
  112. validate_argsort_kind = CompatValidator(
  113. ARGSORT_DEFAULTS_KIND, fname="argsort", max_fname_arg_count=0, method="both"
  114. )
  115. def validate_argsort_with_ascending(ascending, args, kwargs):
  116. """
  117. If 'Categorical.argsort' is called via the 'numpy' library, the first
  118. parameter in its signature is 'axis', which takes either an integer or
  119. 'None', so check if the 'ascending' parameter has either integer type or is
  120. None, since 'ascending' itself should be a boolean
  121. """
  122. if is_integer(ascending) or ascending is None:
  123. args = (ascending,) + args
  124. ascending = True
  125. validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
  126. return ascending
  127. CLIP_DEFAULTS: dict[str, Any] = {"out": None}
  128. validate_clip = CompatValidator(
  129. CLIP_DEFAULTS, fname="clip", method="both", max_fname_arg_count=3
  130. )
  131. def validate_clip_with_axis(axis, args, kwargs):
  132. """
  133. If 'NDFrame.clip' is called via the numpy library, the third parameter in
  134. its signature is 'out', which can takes an ndarray, so check if the 'axis'
  135. parameter is an instance of ndarray, since 'axis' itself should either be
  136. an integer or None
  137. """
  138. if isinstance(axis, ndarray):
  139. args = (axis,) + args
  140. axis = None
  141. validate_clip(args, kwargs)
  142. return axis
  143. CUM_FUNC_DEFAULTS: dict[str, Any] = {}
  144. CUM_FUNC_DEFAULTS["dtype"] = None
  145. CUM_FUNC_DEFAULTS["out"] = None
  146. validate_cum_func = CompatValidator(
  147. CUM_FUNC_DEFAULTS, method="both", max_fname_arg_count=1
  148. )
  149. validate_cumsum = CompatValidator(
  150. CUM_FUNC_DEFAULTS, fname="cumsum", method="both", max_fname_arg_count=1
  151. )
  152. def validate_cum_func_with_skipna(skipna, args, kwargs, name):
  153. """
  154. If this function is called via the 'numpy' library, the third parameter in
  155. its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
  156. check if the 'skipna' parameter is a boolean or not
  157. """
  158. if not is_bool(skipna):
  159. args = (skipna,) + args
  160. skipna = True
  161. validate_cum_func(args, kwargs, fname=name)
  162. return skipna
  163. ALLANY_DEFAULTS: dict[str, bool | None] = {}
  164. ALLANY_DEFAULTS["dtype"] = None
  165. ALLANY_DEFAULTS["out"] = None
  166. ALLANY_DEFAULTS["keepdims"] = False
  167. ALLANY_DEFAULTS["axis"] = None
  168. validate_all = CompatValidator(
  169. ALLANY_DEFAULTS, fname="all", method="both", max_fname_arg_count=1
  170. )
  171. validate_any = CompatValidator(
  172. ALLANY_DEFAULTS, fname="any", method="both", max_fname_arg_count=1
  173. )
  174. LOGICAL_FUNC_DEFAULTS = {"out": None, "keepdims": False}
  175. validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method="kwargs")
  176. MINMAX_DEFAULTS = {"axis": None, "out": None, "keepdims": False}
  177. validate_min = CompatValidator(
  178. MINMAX_DEFAULTS, fname="min", method="both", max_fname_arg_count=1
  179. )
  180. validate_max = CompatValidator(
  181. MINMAX_DEFAULTS, fname="max", method="both", max_fname_arg_count=1
  182. )
  183. RESHAPE_DEFAULTS: dict[str, str] = {"order": "C"}
  184. validate_reshape = CompatValidator(
  185. RESHAPE_DEFAULTS, fname="reshape", method="both", max_fname_arg_count=1
  186. )
  187. REPEAT_DEFAULTS: dict[str, Any] = {"axis": None}
  188. validate_repeat = CompatValidator(
  189. REPEAT_DEFAULTS, fname="repeat", method="both", max_fname_arg_count=1
  190. )
  191. ROUND_DEFAULTS: dict[str, Any] = {"out": None}
  192. validate_round = CompatValidator(
  193. ROUND_DEFAULTS, fname="round", method="both", max_fname_arg_count=1
  194. )
  195. SORT_DEFAULTS: dict[str, int | str | None] = {}
  196. SORT_DEFAULTS["axis"] = -1
  197. SORT_DEFAULTS["kind"] = "quicksort"
  198. SORT_DEFAULTS["order"] = None
  199. validate_sort = CompatValidator(SORT_DEFAULTS, fname="sort", method="kwargs")
  200. STAT_FUNC_DEFAULTS: dict[str, Any | None] = {}
  201. STAT_FUNC_DEFAULTS["dtype"] = None
  202. STAT_FUNC_DEFAULTS["out"] = None
  203. SUM_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  204. SUM_DEFAULTS["axis"] = None
  205. SUM_DEFAULTS["keepdims"] = False
  206. SUM_DEFAULTS["initial"] = None
  207. PROD_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  208. PROD_DEFAULTS["axis"] = None
  209. PROD_DEFAULTS["keepdims"] = False
  210. PROD_DEFAULTS["initial"] = None
  211. MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  212. MEDIAN_DEFAULTS["overwrite_input"] = False
  213. MEDIAN_DEFAULTS["keepdims"] = False
  214. STAT_FUNC_DEFAULTS["keepdims"] = False
  215. validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS, method="kwargs")
  216. validate_sum = CompatValidator(
  217. SUM_DEFAULTS, fname="sum", method="both", max_fname_arg_count=1
  218. )
  219. validate_prod = CompatValidator(
  220. PROD_DEFAULTS, fname="prod", method="both", max_fname_arg_count=1
  221. )
  222. validate_mean = CompatValidator(
  223. STAT_FUNC_DEFAULTS, fname="mean", method="both", max_fname_arg_count=1
  224. )
  225. validate_median = CompatValidator(
  226. MEDIAN_DEFAULTS, fname="median", method="both", max_fname_arg_count=1
  227. )
  228. STAT_DDOF_FUNC_DEFAULTS: dict[str, bool | None] = {}
  229. STAT_DDOF_FUNC_DEFAULTS["dtype"] = None
  230. STAT_DDOF_FUNC_DEFAULTS["out"] = None
  231. STAT_DDOF_FUNC_DEFAULTS["keepdims"] = False
  232. validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS, method="kwargs")
  233. TAKE_DEFAULTS: dict[str, str | None] = {}
  234. TAKE_DEFAULTS["out"] = None
  235. TAKE_DEFAULTS["mode"] = "raise"
  236. validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")
  237. def validate_take_with_convert(convert, args, kwargs):
  238. """
  239. If this function is called via the 'numpy' library, the third parameter in
  240. its signature is 'axis', which takes either an ndarray or 'None', so check
  241. if the 'convert' parameter is either an instance of ndarray or is None
  242. """
  243. if isinstance(convert, ndarray) or convert is None:
  244. args = (convert,) + args
  245. convert = True
  246. validate_take(args, kwargs, max_fname_arg_count=3, method="both")
  247. return convert
  248. TRANSPOSE_DEFAULTS = {"axes": None}
  249. validate_transpose = CompatValidator(
  250. TRANSPOSE_DEFAULTS, fname="transpose", method="both", max_fname_arg_count=0
  251. )
  252. def validate_window_func(name, args, kwargs) -> None:
  253. numpy_args = ("axis", "dtype", "out")
  254. msg = (
  255. f"numpy operations are not valid with window objects. "
  256. f"Use .{name}() directly instead "
  257. )
  258. if len(args) > 0:
  259. raise UnsupportedFunctionCall(msg)
  260. for arg in numpy_args:
  261. if arg in kwargs:
  262. raise UnsupportedFunctionCall(msg)
  263. def validate_rolling_func(name, args, kwargs) -> None:
  264. numpy_args = ("axis", "dtype", "out")
  265. msg = (
  266. f"numpy operations are not valid with window objects. "
  267. f"Use .rolling(...).{name}() instead "
  268. )
  269. if len(args) > 0:
  270. raise UnsupportedFunctionCall(msg)
  271. for arg in numpy_args:
  272. if arg in kwargs:
  273. raise UnsupportedFunctionCall(msg)
  274. def validate_expanding_func(name, args, kwargs) -> None:
  275. numpy_args = ("axis", "dtype", "out")
  276. msg = (
  277. f"numpy operations are not valid with window objects. "
  278. f"Use .expanding(...).{name}() instead "
  279. )
  280. if len(args) > 0:
  281. raise UnsupportedFunctionCall(msg)
  282. for arg in numpy_args:
  283. if arg in kwargs:
  284. raise UnsupportedFunctionCall(msg)
  285. def validate_groupby_func(name, args, kwargs, allowed=None) -> None:
  286. """
  287. 'args' and 'kwargs' should be empty, except for allowed kwargs because all
  288. of their necessary parameters are explicitly listed in the function
  289. signature
  290. """
  291. if allowed is None:
  292. allowed = []
  293. kwargs = set(kwargs) - set(allowed)
  294. if len(args) + len(kwargs) > 0:
  295. raise UnsupportedFunctionCall(
  296. "numpy operations are not valid with groupby. "
  297. f"Use .groupby(...).{name}() instead"
  298. )
  299. RESAMPLER_NUMPY_OPS = ("min", "max", "sum", "prod", "mean", "std", "var")
  300. def validate_resampler_func(method: str, args, kwargs) -> None:
  301. """
  302. 'args' and 'kwargs' should be empty because all of their necessary
  303. parameters are explicitly listed in the function signature
  304. """
  305. if len(args) + len(kwargs) > 0:
  306. if method in RESAMPLER_NUMPY_OPS:
  307. raise UnsupportedFunctionCall(
  308. "numpy operations are not valid with resample. "
  309. f"Use .resample(...).{method}() instead"
  310. )
  311. else:
  312. raise TypeError("too many arguments passed in")
  313. def validate_minmax_axis(axis: int | None, ndim: int = 1) -> None:
  314. """
  315. Ensure that the axis argument passed to min, max, argmin, or argmax is zero
  316. or None, as otherwise it will be incorrectly ignored.
  317. Parameters
  318. ----------
  319. axis : int or None
  320. ndim : int, default 1
  321. Raises
  322. ------
  323. ValueError
  324. """
  325. if axis is None:
  326. return
  327. if axis >= ndim or (axis < 0 and ndim + axis < 0):
  328. raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})")