From d550b4b53fca661d02206ba508370db1dfa10bc7 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 25 Feb 2021 11:00:28 -0800 Subject: [PATCH] REF/BUG: ensure we get DTA/TDA/PA back from Index._intersection --- pandas/core/indexes/base.py | 16 +++++++++------ pandas/core/indexes/datetimelike.py | 25 +++++++++++------------- pandas/core/indexes/range.py | 4 ++-- pandas/tests/indexes/period/test_join.py | 2 +- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index e633d6b28a8c5..b9f340d15b566 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -141,7 +141,10 @@ PandasObject, ) import pandas.core.common as com -from pandas.core.construction import extract_array +from pandas.core.construction import ( + ensure_wrapped_if_datetimelike, + extract_array, +) from pandas.core.indexers import deprecate_ndim_indexing from pandas.core.indexes.frozen import FrozenList from pandas.core.ops import get_op_result_name @@ -2912,7 +2915,7 @@ def union(self, other, sort=None): return self._wrap_setop_result(other, result) - def _union(self, other, sort): + def _union(self, other: Index, sort): """ Specific union logic should go here. In subclasses, union behavior should be overwritten here rather than in `self.union`. @@ -3041,7 +3044,7 @@ def intersection(self, other, sort=False): result = self._intersection(other, sort=sort) return self._wrap_setop_result(other, result) - def _intersection(self, other, sort=False): + def _intersection(self, other: Index, sort=False): """ intersection specialized to the case with matching dtypes. """ @@ -3055,13 +3058,14 @@ def _intersection(self, other, sort=False): except TypeError: pass else: - return algos.unique1d(result) + # TODO: algos.unique1d should preserve DTA/TDA + res = algos.unique1d(result) + return ensure_wrapped_if_datetimelike(res) try: indexer = other.get_indexer(lvals) - except (InvalidIndexError, IncompatibleFrequency): + except InvalidIndexError: # InvalidIndexError raised by get_indexer if non-unique - # IncompatibleFrequency raised by PeriodIndex.get_indexer indexer, _ = other.get_indexer_non_unique(lvals) mask = indexer != -1 diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 6d5992540ef49..29df20c609a4f 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -47,7 +47,6 @@ is_scalar, ) from pandas.core.dtypes.concat import concat_compat -from pandas.core.dtypes.generic import ABCSeries from pandas.core.arrays import ( DatetimeArray, @@ -86,24 +85,20 @@ def _join_i8_wrapper(joinf, with_indexers: bool = True): @staticmethod # type: ignore[misc] def wrapper(left, right): # Note: these only get called with left.dtype == right.dtype - if isinstance( - left, (np.ndarray, DatetimeIndexOpsMixin, ABCSeries, DatetimeLikeArrayMixin) - ): - left = left.view("i8") - if isinstance( - right, - (np.ndarray, DatetimeIndexOpsMixin, ABCSeries, DatetimeLikeArrayMixin), - ): - right = right.view("i8") + orig_left = left + + left = left.view("i8") + right = right.view("i8") results = joinf(left, right) if with_indexers: - # dtype should be timedelta64[ns] for TimedeltaIndex - # and datetime64[ns] for DatetimeIndex - dtype = cast(np.dtype, left.dtype).base join_index, left_indexer, right_indexer = results - join_index = join_index.view(dtype) + if not isinstance(orig_left, np.ndarray): + # When called from Index._intersection/_union, we have the EA + join_index = join_index.view(orig_left._ndarray.dtype) + join_index = orig_left._from_backing_data(join_index) + return join_index, left_indexer, right_indexer return results @@ -708,6 +703,8 @@ def _intersection(self, other: Index, sort=False) -> Index: # We need to invalidate the freq because Index._intersection # uses _shallow_copy on a view of self._data, which will preserve # self.freq if we're not careful. + # At this point we should have result.dtype == self.dtype + # and type(result) is type(self._data) result = self._wrap_setop_result(other, result) return result._with_freq(None)._with_freq("infer") diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index a0f546a6bd748..67af4b628e552 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -527,7 +527,7 @@ def equals(self, other: object) -> bool: # -------------------------------------------------------------------- # Set Operations - def _intersection(self, other, sort=False): + def _intersection(self, other: Index, sort=False): if not isinstance(other, RangeIndex): # Int64Index @@ -602,7 +602,7 @@ def _extended_gcd(self, a, b): old_t, t = t, old_t - quotient * t return old_r, old_s, old_t - def _union(self, other, sort): + def _union(self, other: Index, sort): """ Form the union of two Index objects and sorts if possible diff --git a/pandas/tests/indexes/period/test_join.py b/pandas/tests/indexes/period/test_join.py index 2f16daa36d1fd..aa2393aceee52 100644 --- a/pandas/tests/indexes/period/test_join.py +++ b/pandas/tests/indexes/period/test_join.py @@ -16,7 +16,7 @@ def test_join_outer_indexer(self): pi = period_range("1/1/2000", "1/20/2000", freq="D") result = pi._outer_indexer(pi._values, pi._values) - tm.assert_numpy_array_equal(result[0], pi.asi8) + tm.assert_extension_array_equal(result[0], pi._values) tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.int64)) tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.int64))