common.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """ common utilities """
  2. import itertools
  3. import numpy as np
  4. from pandas import (
  5. DataFrame,
  6. Float64Index,
  7. MultiIndex,
  8. Series,
  9. UInt64Index,
  10. date_range,
  11. )
  12. import pandas._testing as tm
  13. def _mklbl(prefix, n):
  14. return [f"{prefix}{i}" for i in range(n)]
  15. def _axify(obj, key, axis):
  16. # create a tuple accessor
  17. axes = [slice(None)] * obj.ndim
  18. axes[axis] = key
  19. return tuple(axes)
  20. class Base:
  21. """indexing comprehensive base class"""
  22. _kinds = {"series", "frame"}
  23. _typs = {
  24. "ints",
  25. "uints",
  26. "labels",
  27. "mixed",
  28. "ts",
  29. "floats",
  30. "empty",
  31. "ts_rev",
  32. "multi",
  33. }
  34. def setup_method(self, method):
  35. self.series_ints = Series(np.random.rand(4), index=np.arange(0, 8, 2))
  36. self.frame_ints = DataFrame(
  37. np.random.randn(4, 4), index=np.arange(0, 8, 2), columns=np.arange(0, 12, 3)
  38. )
  39. self.series_uints = Series(
  40. np.random.rand(4), index=UInt64Index(np.arange(0, 8, 2))
  41. )
  42. self.frame_uints = DataFrame(
  43. np.random.randn(4, 4),
  44. index=UInt64Index(range(0, 8, 2)),
  45. columns=UInt64Index(range(0, 12, 3)),
  46. )
  47. self.series_floats = Series(
  48. np.random.rand(4), index=Float64Index(range(0, 8, 2))
  49. )
  50. self.frame_floats = DataFrame(
  51. np.random.randn(4, 4),
  52. index=Float64Index(range(0, 8, 2)),
  53. columns=Float64Index(range(0, 12, 3)),
  54. )
  55. m_idces = [
  56. MultiIndex.from_product([[1, 2], [3, 4]]),
  57. MultiIndex.from_product([[5, 6], [7, 8]]),
  58. MultiIndex.from_product([[9, 10], [11, 12]]),
  59. ]
  60. self.series_multi = Series(np.random.rand(4), index=m_idces[0])
  61. self.frame_multi = DataFrame(
  62. np.random.randn(4, 4), index=m_idces[0], columns=m_idces[1]
  63. )
  64. self.series_labels = Series(np.random.randn(4), index=list("abcd"))
  65. self.frame_labels = DataFrame(
  66. np.random.randn(4, 4), index=list("abcd"), columns=list("ABCD")
  67. )
  68. self.series_mixed = Series(np.random.randn(4), index=[2, 4, "null", 8])
  69. self.frame_mixed = DataFrame(np.random.randn(4, 4), index=[2, 4, "null", 8])
  70. self.series_ts = Series(
  71. np.random.randn(4), index=date_range("20130101", periods=4)
  72. )
  73. self.frame_ts = DataFrame(
  74. np.random.randn(4, 4), index=date_range("20130101", periods=4)
  75. )
  76. dates_rev = date_range("20130101", periods=4).sort_values(ascending=False)
  77. self.series_ts_rev = Series(np.random.randn(4), index=dates_rev)
  78. self.frame_ts_rev = DataFrame(np.random.randn(4, 4), index=dates_rev)
  79. self.frame_empty = DataFrame()
  80. self.series_empty = Series(dtype=object)
  81. # form agglomerates
  82. for kind in self._kinds:
  83. d = {}
  84. for typ in self._typs:
  85. d[typ] = getattr(self, f"{kind}_{typ}")
  86. setattr(self, kind, d)
  87. def generate_indices(self, f, values=False):
  88. """
  89. generate the indices
  90. if values is True , use the axis values
  91. is False, use the range
  92. """
  93. axes = f.axes
  94. if values:
  95. axes = (list(range(len(ax))) for ax in axes)
  96. return itertools.product(*axes)
  97. def get_value(self, name, f, i, values=False):
  98. """return the value for the location i"""
  99. # check against values
  100. if values:
  101. return f.values[i]
  102. elif name == "iat":
  103. return f.iloc[i]
  104. else:
  105. assert name == "at"
  106. return f.loc[i]
  107. def check_values(self, f, func, values=False):
  108. if f is None:
  109. return
  110. axes = f.axes
  111. indices = itertools.product(*axes)
  112. for i in indices:
  113. result = getattr(f, func)[i]
  114. # check against values
  115. if values:
  116. expected = f.values[i]
  117. else:
  118. expected = f
  119. for a in reversed(i):
  120. expected = expected.__getitem__(a)
  121. tm.assert_almost_equal(result, expected)
  122. def check_result(self, method, key, typs=None, axes=None, fails=None):
  123. def _eq(axis, obj, key):
  124. """compare equal for these 2 keys"""
  125. axified = _axify(obj, key, axis)
  126. try:
  127. getattr(obj, method).__getitem__(axified)
  128. except (IndexError, TypeError, KeyError) as detail:
  129. # if we are in fails, the ok, otherwise raise it
  130. if fails is not None:
  131. if isinstance(detail, fails):
  132. return
  133. raise
  134. if typs is None:
  135. typs = self._typs
  136. if axes is None:
  137. axes = [0, 1]
  138. else:
  139. assert axes in [0, 1]
  140. axes = [axes]
  141. # check
  142. for kind in self._kinds:
  143. d = getattr(self, kind)
  144. for ax in axes:
  145. for typ in typs:
  146. assert typ in self._typs
  147. obj = d[typ]
  148. if ax < obj.ndim:
  149. _eq(axis=ax, obj=obj, key=key)