Skip to content

Commit df64bcd

Browse files
committed
API: CategoricalDtype.__eq__ with categories=None stricter
1 parent dc4eaf3 commit df64bcd

File tree

5 files changed

+37
-14
lines changed

5 files changed

+37
-14
lines changed

pandas/core/dtypes/dtypes.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,10 @@ def __eq__(self, other: Any) -> bool:
354354
elif not (hasattr(other, "ordered") and hasattr(other, "categories")):
355355
return False
356356
elif self.categories is None or other.categories is None:
357-
# We're forced into a suboptimal corner thanks to math and
358-
# backwards compatibility. We require that `CDT(...) == 'category'`
359-
# for all CDTs **including** `CDT(None, ...)`. Therefore, *all*
360-
# CDT(., .) = CDT(None, False) and *all*
361-
# CDT(., .) = CDT(None, True).
362-
return True
357+
# For non-fully-initialized dtypes, these are only equal to
358+
# - the string "category" (handled above)
359+
# - other CategoricalDtype with categories=None
360+
return self.categories is other.categories
363361
elif self.ordered or other.ordered:
364362
# At least one has ordered=True; equal if both have ordered=True
365363
# and the same values for categories in the same order.

pandas/tests/dtypes/cast/test_infer_dtype.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pandas import (
1010
Categorical,
11+
CategoricalDtype,
1112
Interval,
1213
Period,
1314
Series,
@@ -149,8 +150,8 @@ def test_infer_dtype_from_scalar_errors():
149150
(np.array([[1.0, 2.0]]), np.float_, False),
150151
(Categorical(list("aabc")), np.object_, False),
151152
(Categorical([1, 2, 3]), np.int64, False),
152-
(Categorical(list("aabc")), "category", True),
153-
(Categorical([1, 2, 3]), "category", True),
153+
(Categorical(list("aabc")), CategoricalDtype(categories=["a", "b", "c"]), True),
154+
(Categorical([1, 2, 3]), CategoricalDtype(categories=[1, 2, 3]), True),
154155
(Timestamp("20160101"), np.object_, False),
155156
(np.datetime64("2016-01-01"), np.dtype("=M8[D]"), False),
156157
(date_range("20160101", periods=3), np.dtype("=M8[ns]"), False),

pandas/tests/dtypes/test_common.py

-1
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,6 @@ def test_is_complex_dtype():
641641
(pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])),
642642
(pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])),
643643
(CategoricalDtype(), CategoricalDtype()),
644-
(CategoricalDtype(["a", "b"]), CategoricalDtype()),
645644
(pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")),
646645
(pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")),
647646
("<M8[ns]", np.dtype("<M8[ns]")),

pandas/tests/dtypes/test_dtypes.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -834,10 +834,27 @@ def test_categorical_equality(self, ordered1, ordered2):
834834
c1 = CategoricalDtype(list("abc"), ordered1)
835835
c2 = CategoricalDtype(None, ordered2)
836836
c3 = CategoricalDtype(None, ordered1)
837-
assert c1 == c2
838-
assert c2 == c1
837+
assert c1 != c2
838+
assert c2 != c1
839839
assert c2 == c3
840840

841+
def test_categorical_dtype_equality_requires_categories(self):
842+
# CategoricalDtype with categories=None is *not* equal to
843+
# any fully-initialized CategoricalDtype
844+
first = CategoricalDtype(["a", "b"])
845+
second = CategoricalDtype()
846+
third = CategoricalDtype(ordered=True)
847+
848+
assert second == second
849+
assert third == third
850+
851+
assert first != second
852+
assert second != first
853+
assert first != third
854+
assert third != first
855+
assert second == third
856+
assert third == second
857+
841858
@pytest.mark.parametrize("categories", [list("abc"), None])
842859
@pytest.mark.parametrize("other", ["category", "not a category"])
843860
def test_categorical_equality_strings(self, categories, ordered, other):

pandas/tests/reshape/merge/test_merge.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ def test_identical(self, left):
16221622
merged = pd.merge(left, left, on="X")
16231623
result = merged.dtypes.sort_index()
16241624
expected = Series(
1625-
[CategoricalDtype(), np.dtype("O"), np.dtype("O")],
1625+
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
16261626
index=["X", "Y_x", "Y_y"],
16271627
)
16281628
tm.assert_series_equal(result, expected)
@@ -1633,7 +1633,11 @@ def test_basic(self, left, right):
16331633
merged = pd.merge(left, right, on="X")
16341634
result = merged.dtypes.sort_index()
16351635
expected = Series(
1636-
[CategoricalDtype(), np.dtype("O"), np.dtype("int64")],
1636+
[
1637+
CategoricalDtype(categories=["foo", "bar"]),
1638+
np.dtype("O"),
1639+
np.dtype("int64"),
1640+
],
16371641
index=["X", "Y", "Z"],
16381642
)
16391643
tm.assert_series_equal(result, expected)
@@ -1713,7 +1717,11 @@ def test_other_columns(self, left, right):
17131717
merged = pd.merge(left, right, on="X")
17141718
result = merged.dtypes.sort_index()
17151719
expected = Series(
1716-
[CategoricalDtype(), np.dtype("O"), CategoricalDtype()],
1720+
[
1721+
CategoricalDtype(categories=["foo", "bar"]),
1722+
np.dtype("O"),
1723+
CategoricalDtype(categories=[1, 2]),
1724+
],
17171725
index=["X", "Y", "Z"],
17181726
)
17191727
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)