Skip to content

Commit d82e540

Browse files
authored
BUG: NaT.__cmp__(invalid) should raise TypeError (#35585)
1 parent c4691d6 commit d82e540

File tree

3 files changed

+73
-21
lines changed

3 files changed

+73
-21
lines changed

doc/source/whatsnew/v1.2.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Categorical
6060
Datetimelike
6161
^^^^^^^^^^^^
6262
- Bug in :attr:`DatetimeArray.date` where a ``ValueError`` would be raised with a read-only backing array (:issue:`33530`)
63+
- Bug in ``NaT`` comparisons failing to raise ``TypeError`` on invalid inequality comparisons (:issue:`35046`)
6364
-
6465

6566
Timedelta

pandas/_libs/tslibs/nattype.pyx

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,30 +107,25 @@ cdef class _NaT(datetime):
107107
__array_priority__ = 100
108108

109109
def __richcmp__(_NaT self, object other, int op):
110-
cdef:
111-
int ndim = getattr(other, "ndim", -1)
110+
if util.is_datetime64_object(other) or PyDateTime_Check(other):
111+
# We treat NaT as datetime-like for this comparison
112+
return _nat_scalar_rules[op]
112113

113-
if ndim == -1:
114+
elif util.is_timedelta64_object(other) or PyDelta_Check(other):
115+
# We treat NaT as timedelta-like for this comparison
114116
return _nat_scalar_rules[op]
115117

116118
elif util.is_array(other):
117-
result = np.empty(other.shape, dtype=np.bool_)
118-
result.fill(_nat_scalar_rules[op])
119+
if other.dtype.kind in "mM":
120+
result = np.empty(other.shape, dtype=np.bool_)
121+
result.fill(_nat_scalar_rules[op])
122+
elif other.dtype.kind == "O":
123+
result = np.array([PyObject_RichCompare(self, x, op) for x in other])
124+
else:
125+
return NotImplemented
119126
return result
120127

121-
elif ndim == 0:
122-
if util.is_datetime64_object(other):
123-
return _nat_scalar_rules[op]
124-
else:
125-
raise TypeError(
126-
f"Cannot compare type {type(self).__name__} "
127-
f"with type {type(other).__name__}"
128-
)
129-
130-
# Note: instead of passing "other, self, _reverse_ops[op]", we observe
131-
# that `_nat_scalar_rules` is invariant under `_reverse_ops`,
132-
# rendering it unnecessary.
133-
return PyObject_RichCompare(other, self, op)
128+
return NotImplemented
134129

135130
def __add__(self, other):
136131
if self is not c_NaT:

pandas/tests/scalar/test_nat.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,11 +513,67 @@ def test_to_numpy_alias():
513513
assert isna(expected) and isna(result)
514514

515515

516-
@pytest.mark.parametrize("other", [Timedelta(0), Timestamp(0)])
516+
@pytest.mark.parametrize(
517+
"other",
518+
[
519+
Timedelta(0),
520+
Timedelta(0).to_pytimedelta(),
521+
pytest.param(
522+
Timedelta(0).to_timedelta64(),
523+
marks=pytest.mark.xfail(
524+
reason="td64 doesnt return NotImplemented, see numpy#17017"
525+
),
526+
),
527+
Timestamp(0),
528+
Timestamp(0).to_pydatetime(),
529+
pytest.param(
530+
Timestamp(0).to_datetime64(),
531+
marks=pytest.mark.xfail(
532+
reason="dt64 doesnt return NotImplemented, see numpy#17017"
533+
),
534+
),
535+
Timestamp(0).tz_localize("UTC"),
536+
NaT,
537+
],
538+
)
517539
def test_nat_comparisons(compare_operators_no_eq_ne, other):
518540
# GH 26039
519-
assert getattr(NaT, compare_operators_no_eq_ne)(other) is False
520-
assert getattr(other, compare_operators_no_eq_ne)(NaT) is False
541+
opname = compare_operators_no_eq_ne
542+
543+
assert getattr(NaT, opname)(other) is False
544+
545+
op = getattr(operator, opname.strip("_"))
546+
assert op(NaT, other) is False
547+
assert op(other, NaT) is False
548+
549+
550+
@pytest.mark.parametrize("other", [np.timedelta64(0, "ns"), np.datetime64("now", "ns")])
551+
def test_nat_comparisons_numpy(other):
552+
# Once numpy#17017 is fixed and the xfailed cases in test_nat_comparisons
553+
# pass, this test can be removed
554+
assert not NaT == other
555+
assert NaT != other
556+
assert not NaT < other
557+
assert not NaT > other
558+
assert not NaT <= other
559+
assert not NaT >= other
560+
561+
562+
@pytest.mark.parametrize("other", ["foo", 2, 2.0])
563+
@pytest.mark.parametrize("op", [operator.le, operator.lt, operator.ge, operator.gt])
564+
def test_nat_comparisons_invalid(other, op):
565+
# GH#35585
566+
assert not NaT == other
567+
assert not other == NaT
568+
569+
assert NaT != other
570+
assert other != NaT
571+
572+
with pytest.raises(TypeError):
573+
op(NaT, other)
574+
575+
with pytest.raises(TypeError):
576+
op(other, NaT)
521577

522578

523579
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)