Skip to content

Commit bff366a

Browse files
authored
REF/BUG: ensure we get DTA/TDA/PA back from Index._intersection (#40064)
1 parent e5e1fba commit bff366a

File tree

4 files changed

+24
-23
lines changed

4 files changed

+24
-23
lines changed

pandas/core/indexes/base.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -141,7 +141,10 @@
141141
PandasObject,
142142
)
143143
import pandas.core.common as com
144-
from pandas.core.construction import extract_array
144+
from pandas.core.construction import (
145+
ensure_wrapped_if_datetimelike,
146+
extract_array,
147+
)
145148
from pandas.core.indexers import deprecate_ndim_indexing
146149
from pandas.core.indexes.frozen import FrozenList
147150
from pandas.core.ops import get_op_result_name
@@ -2913,7 +2916,7 @@ def union(self, other, sort=None):
29132916

29142917
return self._wrap_setop_result(other, result)
29152918

2916-
def _union(self, other, sort):
2919+
def _union(self, other: Index, sort):
29172920
"""
29182921
Specific union logic should go here. In subclasses, union behavior
29192922
should be overwritten here rather than in `self.union`.
@@ -3042,7 +3045,7 @@ def intersection(self, other, sort=False):
30423045
result = self._intersection(other, sort=sort)
30433046
return self._wrap_setop_result(other, result)
30443047

3045-
def _intersection(self, other, sort=False):
3048+
def _intersection(self, other: Index, sort=False):
30463049
"""
30473050
intersection specialized to the case with matching dtypes.
30483051
"""
@@ -3056,13 +3059,14 @@ def _intersection(self, other, sort=False):
30563059
except TypeError:
30573060
pass
30583061
else:
3059-
return algos.unique1d(result)
3062+
# TODO: algos.unique1d should preserve DTA/TDA
3063+
res = algos.unique1d(result)
3064+
return ensure_wrapped_if_datetimelike(res)
30603065

30613066
try:
30623067
indexer = other.get_indexer(lvals)
3063-
except (InvalidIndexError, IncompatibleFrequency):
3068+
except InvalidIndexError:
30643069
# InvalidIndexError raised by get_indexer if non-unique
3065-
# IncompatibleFrequency raised by PeriodIndex.get_indexer
30663070
indexer, _ = other.get_indexer_non_unique(lvals)
30673071

30683072
mask = indexer != -1

pandas/core/indexes/datetimelike.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@
4747
is_scalar,
4848
)
4949
from pandas.core.dtypes.concat import concat_compat
50-
from pandas.core.dtypes.generic import ABCSeries
5150

5251
from pandas.core.arrays import (
5352
DatetimeArray,
@@ -86,24 +85,20 @@ def _join_i8_wrapper(joinf, with_indexers: bool = True):
8685
@staticmethod # type: ignore[misc]
8786
def wrapper(left, right):
8887
# Note: these only get called with left.dtype == right.dtype
89-
if isinstance(
90-
left, (np.ndarray, DatetimeIndexOpsMixin, ABCSeries, DatetimeLikeArrayMixin)
91-
):
92-
left = left.view("i8")
93-
if isinstance(
94-
right,
95-
(np.ndarray, DatetimeIndexOpsMixin, ABCSeries, DatetimeLikeArrayMixin),
96-
):
97-
right = right.view("i8")
88+
orig_left = left
89+
90+
left = left.view("i8")
91+
right = right.view("i8")
9892

9993
results = joinf(left, right)
10094
if with_indexers:
101-
# dtype should be timedelta64[ns] for TimedeltaIndex
102-
# and datetime64[ns] for DatetimeIndex
103-
dtype = cast(np.dtype, left.dtype).base
10495

10596
join_index, left_indexer, right_indexer = results
106-
join_index = join_index.view(dtype)
97+
if not isinstance(orig_left, np.ndarray):
98+
# When called from Index._intersection/_union, we have the EA
99+
join_index = join_index.view(orig_left._ndarray.dtype)
100+
join_index = orig_left._from_backing_data(join_index)
101+
107102
return join_index, left_indexer, right_indexer
108103
return results
109104

@@ -708,6 +703,8 @@ def _intersection(self, other: Index, sort=False) -> Index:
708703
# We need to invalidate the freq because Index._intersection
709704
# uses _shallow_copy on a view of self._data, which will preserve
710705
# self.freq if we're not careful.
706+
# At this point we should have result.dtype == self.dtype
707+
# and type(result) is type(self._data)
711708
result = self._wrap_setop_result(other, result)
712709
return result._with_freq(None)._with_freq("infer")
713710

pandas/core/indexes/range.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -527,7 +527,7 @@ def equals(self, other: object) -> bool:
527527
# --------------------------------------------------------------------
528528
# Set Operations
529529

530-
def _intersection(self, other, sort=False):
530+
def _intersection(self, other: Index, sort=False):
531531

532532
if not isinstance(other, RangeIndex):
533533
# Int64Index
@@ -602,7 +602,7 @@ def _extended_gcd(self, a, b):
602602
old_t, t = t, old_t - quotient * t
603603
return old_r, old_s, old_t
604604

605-
def _union(self, other, sort):
605+
def _union(self, other: Index, sort):
606606
"""
607607
Form the union of two Index objects and sorts if possible
608608

pandas/tests/indexes/period/test_join.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ def test_join_outer_indexer(self):
1616
pi = period_range("1/1/2000", "1/20/2000", freq="D")
1717

1818
result = pi._outer_indexer(pi._values, pi._values)
19-
tm.assert_numpy_array_equal(result[0], pi.asi8)
19+
tm.assert_extension_array_equal(result[0], pi._values)
2020
tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.int64))
2121
tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.int64))
2222

0 commit comments

Comments
 (0)