# testing.pyx (5.7 KB)
  1. import cmath
  2. import math
  3. import numpy as np
  4. from numpy cimport import_array
  5. import_array()
  6. from pandas._libs.lib import is_complex
  7. from pandas._libs.util cimport (
  8. is_array,
  9. is_real_number_object,
  10. )
  11. from pandas.core.dtypes.common import is_dtype_equal
  12. from pandas.core.dtypes.missing import (
  13. array_equivalent,
  14. isna,
  15. )
  cdef bint isiterable(obj):
      """Return whether ``obj`` has an ``__iter__`` attribute."""
      if hasattr(obj, '__iter__'):
          return True
      return False
  cdef bint has_length(obj):
      """Return whether ``obj`` has a ``__len__`` attribute."""
      cdef bint defines_len = hasattr(obj, '__len__')
      return defines_len
  cdef bint is_dictlike(obj):
      """Return whether ``obj`` has both a ``keys`` and a ``__getitem__`` attribute."""
      # Checks 'keys' first and stops at the first missing attribute,
      # exactly like the chained ``and`` form.
      return all(hasattr(obj, attr) for attr in ('keys', '__getitem__'))
  cpdef assert_dict_equal(a, b, bint compare_keys=True,
                          rtol=1.e-5, atol=1.e-8):
      """
      Check that two dict-like objects hold almost-equal values.

      Parameters
      ----------
      a : object
          Must expose ``keys`` and ``__getitem__``.
      b : object
          Must expose ``keys`` and ``__getitem__``.
      compare_keys : bool, default True
          If True, require the key sets of ``a`` and ``b`` to be identical.
          If False the key sets are not compared, but every key of ``a`` is
          still looked up in ``b``, so a key missing from ``b`` surfaces as
          whatever ``b[k]`` raises.
      rtol : float, default 1e-5
          Relative tolerance used when comparing each pair of values.
      atol : float, default 1e-8
          Absolute tolerance used when comparing each pair of values.

      Returns
      -------
      bool
          Always True; failures are reported by raising.

      Raises
      ------
      AssertionError
          If either object is not dict-like, the key sets differ (only when
          ``compare_keys`` is True), or any pair of values is not almost
          equal.
      """
      # Raise explicitly instead of using ``assert`` so the check still runs
      # when assertions are disabled.
      if not (is_dictlike(a) and is_dictlike(b)):
          raise AssertionError(
              "Cannot compare dict objects, one or both is not dict-like"
          )

      a_keys = frozenset(a.keys())
      b_keys = frozenset(b.keys())

      if compare_keys and a_keys != b_keys:
          raise AssertionError(
              f"dict keys are different: left only {set(a_keys - b_keys)}, "
              f"right only {set(b_keys - a_keys)}"
          )

      for k in a_keys:
          # Use the caller's tolerances for the nested values too.
          assert_almost_equal(a[k], b[k], rtol=rtol, atol=atol)
      return True


  cpdef assert_almost_equal(a, b,
                            rtol=1.e-5, atol=1.e-8,
                            bint check_dtype=True,
                            obj=None, lobj=None, robj=None, index_values=None):
      """
      Check that left and right objects are almost equal.

      Dicts are compared value by value, strings exactly, iterables element
      by element (recursively), and real/complex scalars within
      ``rtol``/``atol``.

      Parameters
      ----------
      a : object
      b : object
      rtol : float, default 1e-5
          Relative tolerance.

          .. versionadded:: 1.1.0
      atol : float, default 1e-8
          Absolute tolerance.

          .. versionadded:: 1.1.0
      check_dtype : bool, default True
          Check dtype if both a and b are np.ndarray.
      obj : str, default None
          Specify object name being compared, internally used to show
          appropriate assertion message.
      lobj : str, default None
          Specify left object name being compared, internally used to show
          appropriate assertion message.
      robj : str, default None
          Specify right object name being compared, internally used to show
          appropriate assertion message.
      index_values : ndarray, default None
          Specify shared index values of objects being compared, internally used
          to show appropriate assertion message.

          .. versionadded:: 1.1.0

      Returns
      -------
      bool
          Always True; failures are reported by raising.

      Raises
      ------
      AssertionError
          If the objects are not almost equal.
      """
      cdef:
          double diff = 0.0
          Py_ssize_t i, na, nb
          double fa, fb
          bint is_unequal = False, a_is_ndarray, b_is_ndarray

      if lobj is None:
          lobj = a
      if robj is None:
          robj = b

      if isinstance(a, dict) or isinstance(b, dict):
          # Forward the tolerances so dict values are compared like any
          # other element.
          return assert_dict_equal(a, b, rtol=rtol, atol=atol)

      if isinstance(a, str) or isinstance(b, str):
          if not (a == b):
              raise AssertionError(f"{a} != {b}")
          return True

      a_is_ndarray = is_array(a)
      b_is_ndarray = is_array(b)

      if obj is None:
          if a_is_ndarray or b_is_ndarray:
              obj = 'numpy array'
          else:
              obj = 'Iterable'

      if isiterable(a):

          if not isiterable(b):
              from pandas._testing import assert_class_equal

              # classes can't be the same, to raise error
              assert_class_equal(a, b, obj=obj)

          if not (has_length(a) and has_length(b)):
              raise AssertionError(
                  f"Can't compare objects without length, one or both is invalid: ({a}, {b})"
              )

          if a_is_ndarray and b_is_ndarray:
              # total element counts; shapes are checked separately below
              na, nb = a.size, b.size
              if a.shape != b.shape:
                  from pandas._testing import raise_assert_detail
                  raise_assert_detail(
                      obj, f'{obj} shapes are different', a.shape, b.shape)

              if check_dtype and not is_dtype_equal(a.dtype, b.dtype):
                  from pandas._testing import assert_attr_equal
                  assert_attr_equal('dtype', a, b, obj=obj)

              # fast path: whole-array comparison (NaNs in matching positions
              # compare equal with strict_nan)
              if array_equivalent(a, b, strict_nan=True):
                  return True

          else:
              na, nb = len(a), len(b)

          if na != nb:
              from pandas._testing import raise_assert_detail

              # if we have a small diff set, print it
              r = None
              if abs(na - nb) < 10:
                  try:
                      r = list(set(a) ^ set(b))
                  except TypeError:
                      # Unhashable elements (e.g. nested lists): report only
                      # the lengths instead of surfacing a TypeError in place
                      # of the AssertionError.
                      r = None

              raise_assert_detail(obj, f"{obj} length are different", na, nb, r)

          # Recurse along the first axis, counting mismatching items.
          for i in range(len(a)):
              try:
                  assert_almost_equal(a[i], b[i], rtol=rtol, atol=atol)
              except AssertionError:
                  is_unequal = True
                  diff += 1

          if is_unequal:
              from pandas._testing import raise_assert_detail
              # NOTE(review): diff counts first-axis items while na is
              # a.size for ndarrays, so for multi-dimensional arrays the
              # percentage is not a strict per-element ratio.
              msg = (f"{obj} values are different "
                     f"({np.round(diff * 100.0 / na, 5)} %)")
              raise_assert_detail(obj, msg, lobj, robj, index_values=index_values)

          return True

      elif isiterable(b):
          from pandas._testing import assert_class_equal

          # classes can't be the same, to raise error
          assert_class_equal(a, b, obj=obj)

      if isna(a) and isna(b):
          # TODO: Should require same-dtype NA?
          # nan / None comparison
          return True

      if a == b:
          # object comparison
          return True

      if is_real_number_object(a) and is_real_number_object(b):
          if array_equivalent(a, b, strict_nan=True):
              # inf comparison
              return True

          fa, fb = a, b

          if not math.isclose(fa, fb, rel_tol=rtol, abs_tol=atol):
              raise AssertionError(f"expected {fb:.5f} but got {fa:.5f}, "
                                   f"with rtol={rtol}, atol={atol}")
          return True

      if is_complex(a) and is_complex(b):
          if array_equivalent(a, b, strict_nan=True):
              # inf comparison
              return True

          if not cmath.isclose(a, b, rel_tol=rtol, abs_tol=atol):
              raise AssertionError(f"expected {b:.5f} but got {a:.5f}, "
                                   f"with rtol={rtol}, atol={atol}")
          return True

      raise AssertionError(f"{a} != {b}")