Skip to content

BUG: Index.intersection casting to object instead of numeric #38122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ Other
- Fixed metadata propagation in the :class:`Series.dt`, :class:`Series.str` accessors, :class:`DataFrame.duplicated`, :class:`DataFrame.stack`, :class:`DataFrame.unstack`, :class:`DataFrame.pivot`, :class:`DataFrame.append`, :class:`DataFrame.diff`, :class:`DataFrame.applymap` and :class:`DataFrame.update` methods (:issue:`28283`, :issue:`37381`)
- Fixed metadata propagation when selecting columns with ``DataFrame.__getitem__`` (:issue:`28283`)
- Bug in :meth:`Index.union` behaving differently depending on whether operand is an :class:`Index` or other list-like (:issue:`36384`)
- Bug in :meth:`Index.intersection` with non-matching numeric dtypes casting to ``object`` dtype instead of minimal common dtype (:issue:`38122`)
- Passing an array with 2 or more dimensions to the :class:`Series` constructor now raises the more specific ``ValueError`` rather than a bare ``Exception`` (:issue:`35744`)
- Bug in ``dir`` where ``dir(obj)`` wouldn't show attributes defined on the instance for pandas objects (:issue:`37173`)
- Bug in :meth:`RangeIndex.difference` returning :class:`Int64Index` in some cases where it should return :class:`RangeIndex` (:issue:`38028`)
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pandas.util._decorators import Appender, cache_readonly, doc

from pandas.core.dtypes.cast import (
find_common_type,
maybe_cast_to_integer_array,
validate_numeric_casting,
)
Expand Down Expand Up @@ -2826,8 +2827,9 @@ def intersection(self, other, sort=False):
return self._get_reconciled_name_object(other)

if not is_dtype_equal(self.dtype, other.dtype):
this = self.astype("O")
other = other.astype("O")
dtype = find_common_type([self.dtype, other.dtype])
this = self.astype(dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not a big deal, but this could pass copy=False.

other = other.astype(dtype)
return this.intersection(other, sort=sort)

result = self._intersection(other, sort=sort)
Expand Down
10 changes: 4 additions & 6 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3713,16 +3713,14 @@ def _convert_can_do_setop(self, other):
if not isinstance(other, Index):

if len(other) == 0:
other = MultiIndex(
levels=[[]] * self.nlevels,
codes=[[]] * self.nlevels,
verify_integrity=False,
)
return self[:0], self.names
else:
msg = "other must be a MultiIndex or a list of tuples"
try:
other = MultiIndex.from_tuples(other)
except TypeError as err:
except (ValueError, TypeError) as err:
# ValueError raised by tuples_to_object_array if we
# have non-object dtype
raise TypeError(msg) from err
else:
result_names = get_unanimous_names(self, other)
Expand Down
20 changes: 19 additions & 1 deletion pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

import pandas as pd
from pandas import MultiIndex, Series
from pandas import Index, MultiIndex, Series
import pandas._testing as tm


Expand Down Expand Up @@ -294,6 +294,24 @@ def test_intersection(idx, sort):
# assert result.equals(tuples)


def test_intersection_non_object(idx, sort):
    """Check MultiIndex.intersection against operands that are not MultiIndexes.

    Three operands are exercised in turn: a named flat integer ``Index``,
    a zero-length ndarray, and a non-empty ndarray.
    """
    flat = Index(range(3), name="foo")
    empty_codes = [[]] * idx.nlevels

    # Flat Index operand: the expected result has no entries and no names.
    from_index = idx.intersection(flat, sort=sort)
    nameless = MultiIndex(levels=idx.levels, codes=empty_codes, names=None)
    tm.assert_index_equal(from_index, nameless, exact=True)

    # Zero-length ndarray operand: an ndarray carries no name, so the
    # expected result keeps idx.names.
    from_empty = idx.intersection(np.asarray(flat)[:0], sort=sort)
    named = MultiIndex(levels=idx.levels, codes=empty_codes, names=idx.names)
    tm.assert_index_equal(from_empty, named, exact=True)

    # Non-empty ndarray operand: conversion to tuples is attempted and is
    # expected to fail with the TypeError message below.
    msg = "other must be a MultiIndex or a list of tuples"
    with pytest.raises(TypeError, match=msg):
        idx.intersection(np.asarray(flat), sort=sort)


def test_intersect_equal_sort():
# GH-24959
idx = pd.MultiIndex.from_product([[1, 0], ["a", "b"]])
Expand Down
31 changes: 30 additions & 1 deletion pandas/tests/indexes/ranges/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,40 @@
import numpy as np
import pytest

from pandas import Index, Int64Index, RangeIndex
from pandas import Index, Int64Index, RangeIndex, UInt64Index
import pandas._testing as tm


class TestRangeIndexSetOps:
@pytest.mark.parametrize("klass", [RangeIndex, Int64Index, UInt64Index])
def test_intersection_mismatched_dtype(self, klass):
    """Intersecting an integer-like index with a float64 copy of itself.

    Every expected value below is either one of the original operands or a
    slice of the float64 index, i.e. the result is never object dtype.
    """
    ints = klass(RangeIndex(start=0, stop=20, step=2, name="foo"))
    floats = ints.astype(np.float64)

    # The two operands compare equal, so each call is expected to hand back
    # its own left operand unchanged (original note: "fastpath").
    tm.assert_index_equal(ints.intersection(floats), ints, exact=True)
    tm.assert_index_equal(floats.intersection(ints), floats, exact=True)

    # First a non-empty, non-equal float slice, then an empty one; in both
    # operand orders the expected result is the float slice itself.
    for part in (floats[1:], floats[:0]):
        tm.assert_index_equal(ints.intersection(part), part, exact=True)
        tm.assert_index_equal(part.intersection(ints), part, exact=True)

def test_intersection(self, sort):
# intersect with Int64Index
index = RangeIndex(start=0, stop=20, step=2)
Expand Down