conftest.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import numpy as np
  2. import pytest
  3. from pandas import Series
  4. from pandas.core import strings as strings
  5. _any_string_method = [
  6. ("cat", (), {"sep": ","}),
  7. ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}),
  8. ("center", (10,), {}),
  9. ("contains", ("a",), {}),
  10. ("count", ("a",), {}),
  11. ("decode", ("UTF-8",), {}),
  12. ("encode", ("UTF-8",), {}),
  13. ("endswith", ("a",), {}),
  14. ("endswith", ("a",), {"na": True}),
  15. ("endswith", ("a",), {"na": False}),
  16. ("extract", ("([a-z]*)",), {"expand": False}),
  17. ("extract", ("([a-z]*)",), {"expand": True}),
  18. ("extractall", ("([a-z]*)",), {}),
  19. ("find", ("a",), {}),
  20. ("findall", ("a",), {}),
  21. ("get", (0,), {}),
  22. # because "index" (and "rindex") fail intentionally
  23. # if the string is not found, search only for empty string
  24. ("index", ("",), {}),
  25. ("join", (",",), {}),
  26. ("ljust", (10,), {}),
  27. ("match", ("a",), {}),
  28. ("fullmatch", ("a",), {}),
  29. ("normalize", ("NFC",), {}),
  30. ("pad", (10,), {}),
  31. ("partition", (" ",), {"expand": False}),
  32. ("partition", (" ",), {"expand": True}),
  33. ("repeat", (3,), {}),
  34. ("replace", ("a", "z"), {}),
  35. ("rfind", ("a",), {}),
  36. ("rindex", ("",), {}),
  37. ("rjust", (10,), {}),
  38. ("rpartition", (" ",), {"expand": False}),
  39. ("rpartition", (" ",), {"expand": True}),
  40. ("slice", (0, 1), {}),
  41. ("slice_replace", (0, 1, "z"), {}),
  42. ("split", (" ",), {"expand": False}),
  43. ("split", (" ",), {"expand": True}),
  44. ("startswith", ("a",), {}),
  45. ("startswith", ("a",), {"na": True}),
  46. ("startswith", ("a",), {"na": False}),
  47. # translating unicode points of "a" to "d"
  48. ("translate", ({97: 100},), {}),
  49. ("wrap", (2,), {}),
  50. ("zfill", (10,), {}),
  51. ] + list(
  52. zip(
  53. [
  54. # methods without positional arguments: zip with empty tuple and empty dict
  55. "capitalize",
  56. "cat",
  57. "get_dummies",
  58. "isalnum",
  59. "isalpha",
  60. "isdecimal",
  61. "isdigit",
  62. "islower",
  63. "isnumeric",
  64. "isspace",
  65. "istitle",
  66. "isupper",
  67. "len",
  68. "lower",
  69. "lstrip",
  70. "partition",
  71. "rpartition",
  72. "rsplit",
  73. "rstrip",
  74. "slice",
  75. "slice_replace",
  76. "split",
  77. "strip",
  78. "swapcase",
  79. "title",
  80. "upper",
  81. "casefold",
  82. ],
  83. [()] * 100,
  84. [{}] * 100,
  85. )
  86. )
  87. ids, _, _ = zip(*_any_string_method) # use method name as fixture-id
  88. missing_methods = {
  89. f for f in dir(strings.StringMethods) if not f.startswith("_")
  90. } - set(ids)
  91. # test that the above list captures all methods of StringMethods
  92. assert not missing_methods
  93. @pytest.fixture(params=_any_string_method, ids=ids)
  94. def any_string_method(request):
  95. """
  96. Fixture for all public methods of `StringMethods`
  97. This fixture returns a tuple of the method name and sample arguments
  98. necessary to call the method.
  99. Returns
  100. -------
  101. method_name : str
  102. The name of the method in `StringMethods`
  103. args : tuple
  104. Sample values for the positional arguments
  105. kwargs : dict
  106. Sample values for the keyword arguments
  107. Examples
  108. --------
  109. >>> def test_something(any_string_method):
  110. ... s = Series(['a', 'b', np.nan, 'd'])
  111. ...
  112. ... method_name, args, kwargs = any_string_method
  113. ... method = getattr(s.str, method_name)
  114. ... # will not raise
  115. ... method(*args, **kwargs)
  116. """
  117. return request.param
  118. # subset of the full set from pandas/conftest.py
  119. _any_allowed_skipna_inferred_dtype = [
  120. ("string", ["a", np.nan, "c"]),
  121. ("bytes", [b"a", np.nan, b"c"]),
  122. ("empty", [np.nan, np.nan, np.nan]),
  123. ("empty", []),
  124. ("mixed-integer", ["a", np.nan, 2]),
  125. ]
  126. ids, _ = zip(*_any_allowed_skipna_inferred_dtype) # use inferred type as id
  127. @pytest.fixture(params=_any_allowed_skipna_inferred_dtype, ids=ids)
  128. def any_allowed_skipna_inferred_dtype(request):
  129. """
  130. Fixture for all (inferred) dtypes allowed in StringMethods.__init__
  131. The covered (inferred) types are:
  132. * 'string'
  133. * 'empty'
  134. * 'bytes'
  135. * 'mixed'
  136. * 'mixed-integer'
  137. Returns
  138. -------
  139. inferred_dtype : str
  140. The string for the inferred dtype from _libs.lib.infer_dtype
  141. values : np.ndarray
  142. An array of object dtype that will be inferred to have
  143. `inferred_dtype`
  144. Examples
  145. --------
  146. >>> import pandas._libs.lib as lib
  147. >>>
  148. >>> def test_something(any_allowed_skipna_inferred_dtype):
  149. ... inferred_dtype, values = any_allowed_skipna_inferred_dtype
  150. ... # will pass
  151. ... assert lib.infer_dtype(values, skipna=True) == inferred_dtype
  152. ...
  153. ... # constructor for .str-accessor will also pass
  154. ... Series(values).str
  155. """
  156. inferred_dtype, values = request.param
  157. values = np.array(values, dtype=object) # object dtype to avoid casting
  158. # correctness of inference tested in tests/dtypes/test_inference.py
  159. return inferred_dtype, values