diff --git a/pandas/_testing.py b/pandas/_testing.py index 33ec4e4886aa6..5e94ac3b3d108 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -34,6 +34,7 @@ is_interval_dtype, is_list_like, is_number, + is_numeric_dtype, is_period_dtype, is_sequence, is_timedelta64_dtype, @@ -1064,7 +1065,6 @@ def assert_series_equal( right, check_dtype=True, check_index_type="equiv", - check_series_type=True, check_less_precise=False, check_names=True, check_exact=False, @@ -1085,8 +1085,6 @@ def assert_series_equal( check_index_type : bool or {'equiv'}, default 'equiv' Whether to check the Index class, dtype and inferred_type are identical. - check_series_type : bool, default True - Whether to check the Series class is identical. check_less_precise : bool or int, default False Specify comparison precision. Only used when check_exact is False. 5 digits (False) or 3 digits (True) after decimal points are compared. @@ -1118,11 +1116,10 @@ def assert_series_equal( # instance validation _check_isinstance(left, right, Series) - if check_series_type: - # ToDo: There are some tests using rhs is sparse - # lhs is dense. Should use assert_class_equal in future - assert isinstance(left, type(right)) - # assert_class_equal(left, right, obj=obj) + # TODO: There are some tests using rhs is sparse + # lhs is dense. Should use assert_class_equal in future + assert isinstance(left, type(right)) + # assert_class_equal(left, right, obj=obj) # length comparison if len(left) != len(right): @@ -1147,8 +1144,8 @@ def assert_series_equal( # is False. We'll still raise if only one is a `Categorical`, # regardless of `check_categorical` if ( - is_categorical_dtype(left) - and is_categorical_dtype(right) + is_categorical_dtype(left.dtype) + and is_categorical_dtype(right.dtype) and not check_categorical ): pass @@ -1156,38 +1153,31 @@ def assert_series_equal( assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}") if check_exact: + if not is_numeric_dtype(left.dtype): + raise AssertionError("check_exact may only be used with numeric Series") + assert_numpy_array_equal( - left._internal_get_values(), - right._internal_get_values(), - check_dtype=check_dtype, - obj=str(obj), + left._values, right._values, check_dtype=check_dtype, obj=str(obj) ) - elif check_datetimelike_compat: + elif check_datetimelike_compat and ( + needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype) + ): # we want to check only if we have compat dtypes # e.g. integer and M|m are NOT compat, but we can simply check # the values in that case - if needs_i8_conversion(left) or needs_i8_conversion(right): - - # datetimelike may have different objects (e.g. datetime.datetime - # vs Timestamp) but will compare equal - if not Index(left._values).equals(Index(right._values)): - msg = ( - f"[datetimelike_compat=True] {left._values} " - f"is not equal to {right._values}." - ) - raise AssertionError(msg) - else: - assert_numpy_array_equal( - left._internal_get_values(), - right._internal_get_values(), - check_dtype=check_dtype, + + # datetimelike may have different objects (e.g. datetime.datetime + # vs Timestamp) but will compare equal + if not Index(left._values).equals(Index(right._values)): + msg = ( + f"[datetimelike_compat=True] {left._values} " + f"is not equal to {right._values}." ) - elif is_interval_dtype(left) or is_interval_dtype(right): + raise AssertionError(msg) + elif is_interval_dtype(left.dtype) or is_interval_dtype(right.dtype): assert_interval_array_equal(left.array, right.array) - elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype): + elif is_datetime64tz_dtype(left.dtype): # .values is an ndarray, but ._values is the ExtensionArray. - # TODO: Use .array - assert is_extension_array_dtype(right.dtype) assert_extension_array_equal(left._values, right._values) elif ( is_extension_array_dtype(left) diff --git a/pandas/tests/io/pytables/test_store.py b/pandas/tests/io/pytables/test_store.py index fd585a73f6ce6..888222b503b10 100644 --- a/pandas/tests/io/pytables/test_store.py +++ b/pandas/tests/io/pytables/test_store.py @@ -2312,7 +2312,7 @@ def test_index_types(self, setup_path): values = np.random.randn(2) func = lambda l, r: tm.assert_series_equal( - l, r, check_dtype=True, check_index_type=True, check_series_type=True + l, r, check_dtype=True, check_index_type=True ) with catch_warnings(record=True):