Skip to content

REF/BUG: ensure we get DTA/TDA/PA back from Index._intersection #40064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@
PandasObject,
)
import pandas.core.common as com
from pandas.core.construction import extract_array
from pandas.core.construction import (
ensure_wrapped_if_datetimelike,
extract_array,
)
from pandas.core.indexers import deprecate_ndim_indexing
from pandas.core.indexes.frozen import FrozenList
from pandas.core.ops import get_op_result_name
Expand Down Expand Up @@ -2912,7 +2915,7 @@ def union(self, other, sort=None):

return self._wrap_setop_result(other, result)

def _union(self, other, sort):
def _union(self, other: Index, sort):
"""
Specific union logic should go here. In subclasses, union behavior
should be overwritten here rather than in `self.union`.
Expand Down Expand Up @@ -3041,7 +3044,7 @@ def intersection(self, other, sort=False):
result = self._intersection(other, sort=sort)
return self._wrap_setop_result(other, result)

def _intersection(self, other, sort=False):
def _intersection(self, other: Index, sort=False):
"""
intersection specialized to the case with matching dtypes.
"""
Expand All @@ -3055,13 +3058,14 @@ def _intersection(self, other, sort=False):
except TypeError:
pass
else:
return algos.unique1d(result)
# TODO: algos.unique1d should preserve DTA/TDA
res = algos.unique1d(result)
return ensure_wrapped_if_datetimelike(res)

try:
indexer = other.get_indexer(lvals)
except (InvalidIndexError, IncompatibleFrequency):
except InvalidIndexError:
# InvalidIndexError raised by get_indexer if non-unique
# IncompatibleFrequency raised by PeriodIndex.get_indexer
indexer, _ = other.get_indexer_non_unique(lvals)

mask = indexer != -1
Expand Down
25 changes: 11 additions & 14 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
is_scalar,
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.arrays import (
DatetimeArray,
Expand Down Expand Up @@ -86,24 +85,20 @@ def _join_i8_wrapper(joinf, with_indexers: bool = True):
@staticmethod # type: ignore[misc]
def wrapper(left, right):
# Note: these only get called with left.dtype == right.dtype
if isinstance(
left, (np.ndarray, DatetimeIndexOpsMixin, ABCSeries, DatetimeLikeArrayMixin)
):
left = left.view("i8")
if isinstance(
right,
(np.ndarray, DatetimeIndexOpsMixin, ABCSeries, DatetimeLikeArrayMixin),
):
right = right.view("i8")
orig_left = left

left = left.view("i8")
right = right.view("i8")

results = joinf(left, right)
if with_indexers:
# dtype should be timedelta64[ns] for TimedeltaIndex
# and datetime64[ns] for DatetimeIndex
dtype = cast(np.dtype, left.dtype).base

join_index, left_indexer, right_indexer = results
join_index = join_index.view(dtype)
if not isinstance(orig_left, np.ndarray):
# When called from Index._intersection/_union, we have the EA
join_index = join_index.view(orig_left._ndarray.dtype)
join_index = orig_left._from_backing_data(join_index)

return join_index, left_indexer, right_indexer
return results

Expand Down Expand Up @@ -708,6 +703,8 @@ def _intersection(self, other: Index, sort=False) -> Index:
# We need to invalidate the freq because Index._intersection
# uses _shallow_copy on a view of self._data, which will preserve
# self.freq if we're not careful.
# At this point we should have result.dtype == self.dtype
# and type(result) is type(self._data)
result = self._wrap_setop_result(other, result)
return result._with_freq(None)._with_freq("infer")

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def equals(self, other: object) -> bool:
# --------------------------------------------------------------------
# Set Operations

def _intersection(self, other, sort=False):
def _intersection(self, other: Index, sort=False):

if not isinstance(other, RangeIndex):
# Int64Index
Expand Down Expand Up @@ -602,7 +602,7 @@ def _extended_gcd(self, a, b):
old_t, t = t, old_t - quotient * t
return old_r, old_s, old_t

def _union(self, other, sort):
def _union(self, other: Index, sort):
"""
Form the union of two Index objects and sorts if possible

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/indexes/period/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_join_outer_indexer(self):
pi = period_range("1/1/2000", "1/20/2000", freq="D")

result = pi._outer_indexer(pi._values, pi._values)
tm.assert_numpy_array_equal(result[0], pi.asi8)
tm.assert_extension_array_equal(result[0], pi._values)
tm.assert_numpy_array_equal(result[1], np.arange(len(pi), dtype=np.int64))
tm.assert_numpy_array_equal(result[2], np.arange(len(pi), dtype=np.int64))

Expand Down