diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py
index b79905796f7cd..c315a5c03256c 100644
--- a/pandas/core/algorithms.py
+++ b/pandas/core/algorithms.py
@@ -49,7 +49,6 @@
 )
 from pandas.core.dtypes.generic import (
     ABCExtensionArray,
-    ABCIndex,
     ABCIndexClass,
     ABCMultiIndex,
     ABCSeries,
@@ -60,7 +59,7 @@
 from pandas.core.indexers import validate_indices
 
 if TYPE_CHECKING:
-    from pandas import Categorical, DataFrame, Series
+    from pandas import Categorical, DataFrame, Index, Series
 
 _shared_docs: Dict[str, str] = {}
 
@@ -69,7 +68,7 @@
 # dtype access    #
 # --------------- #
 def _ensure_data(
-    values, dtype: Optional[DtypeObj] = None
+    values: ArrayLike, dtype: Optional[DtypeObj] = None
 ) -> Tuple[np.ndarray, DtypeObj]:
     """
     routine to ensure that our data is of the correct
@@ -95,6 +94,12 @@ def _ensure_data(
     pandas_dtype : np.dtype or ExtensionDtype
     """
 
+    if dtype is not None:
+        # We only have non-None dtype when called from `isin`, and
+        #  both Datetimelike and Categorical dispatch before getting here.
+        assert not needs_i8_conversion(dtype)
+        assert not is_categorical_dtype(dtype)
+
     if not isinstance(values, ABCMultiIndex):
         # extract_array would raise
         values = extract_array(values, extract_numpy=True)
@@ -131,21 +136,20 @@ def _ensure_data(
         return ensure_object(values), np.dtype("object")
 
     # datetimelike
-    vals_dtype = getattr(values, "dtype", None)
-    if needs_i8_conversion(vals_dtype) or needs_i8_conversion(dtype):
-        if is_period_dtype(vals_dtype) or is_period_dtype(dtype):
+    if needs_i8_conversion(values.dtype) or needs_i8_conversion(dtype):
+        if is_period_dtype(values.dtype) or is_period_dtype(dtype):
             from pandas import PeriodIndex
 
-            values = PeriodIndex(values)
+            values = PeriodIndex(values)._data
             dtype = values.dtype
-        elif is_timedelta64_dtype(vals_dtype) or is_timedelta64_dtype(dtype):
+        elif is_timedelta64_dtype(values.dtype) or is_timedelta64_dtype(dtype):
             from pandas import TimedeltaIndex
 
-            values = TimedeltaIndex(values)
+            values = TimedeltaIndex(values)._data
             dtype = values.dtype
         else:
             # Datetime
-            if values.ndim > 1 and is_datetime64_ns_dtype(vals_dtype):
+            if values.ndim > 1 and is_datetime64_ns_dtype(values.dtype):
                 # Avoid calling the DatetimeIndex constructor as it is 1D only
                 # Note: this is reached by DataFrame.rank calls GH#27027
                 # TODO(EA2D): special case not needed with 2D EAs
@@ -155,12 +159,12 @@ def _ensure_data(
 
             from pandas import DatetimeIndex
 
-            values = DatetimeIndex(values)
+            values = DatetimeIndex(values)._data
             dtype = values.dtype
 
         return values.asi8, dtype
 
-    elif is_categorical_dtype(vals_dtype) and (
+    elif is_categorical_dtype(values.dtype) and (
         is_categorical_dtype(dtype) or dtype is None
     ):
         values = values.codes
@@ -237,11 +241,11 @@ def _ensure_arraylike(values):
 }
 
 
-def _get_hashtable_algo(values):
+def _get_hashtable_algo(values: np.ndarray):
     """
     Parameters
     ----------
-    values : arraylike
+    values : np.ndarray
 
     Returns
     -------
@@ -255,15 +259,15 @@ def _get_hashtable_algo(values):
     return htable, values
 
 
-def _get_values_for_rank(values):
+def _get_values_for_rank(values: ArrayLike):
     if is_categorical_dtype(values):
-        values = values._values_for_rank()
+        values = cast("Categorical", values)._values_for_rank()
 
     values, _ = _ensure_data(values)
     return values
 
 
-def get_data_algo(values):
+def get_data_algo(values: ArrayLike):
     values = _get_values_for_rank(values)
 
     ndtype = _check_object_for_strings(values)
@@ -421,20 +425,28 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
             f"to isin(), you passed a [{type(values).__name__}]"
         )
 
-    if not isinstance(values, (ABCIndex, ABCSeries, ABCExtensionArray, np.ndarray)):
+    if not isinstance(
+        values, (ABCIndexClass, ABCSeries, ABCExtensionArray, np.ndarray)
+    ):
         values = construct_1d_object_array_from_listlike(list(values))
         # TODO: could use ensure_arraylike here
+    elif isinstance(values, ABCMultiIndex):
+        # Avoid raising in extract_array
+        values = np.array(values)
 
     comps = _ensure_arraylike(comps)
     comps = extract_array(comps, extract_numpy=True)
-    if is_categorical_dtype(comps):
+    if is_categorical_dtype(comps.dtype):
         # TODO(extension)
         # handle categoricals
         return cast("Categorical", comps).isin(values)
 
-    if needs_i8_conversion(comps):
+    if needs_i8_conversion(comps.dtype):
         # Dispatch to DatetimeLikeArrayMixin.isin
         return array(comps).isin(values)
+    elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps.dtype):
+        # e.g. comps are integers and values are datetime64s
+        return np.zeros(comps.shape, dtype=bool)
 
     comps, dtype = _ensure_data(comps)
     values, _ = _ensure_data(values, dtype=dtype)
@@ -474,7 +486,7 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
 
 
 def factorize_array(
-    values, na_sentinel: int = -1, size_hint=None, na_value=None, mask=None
+    values: np.ndarray, na_sentinel: int = -1, size_hint=None, na_value=None, mask=None
 ) -> Tuple[np.ndarray, np.ndarray]:
     """
     Factorize an array-like to codes and uniques.
@@ -540,7 +552,7 @@ def factorize(
     sort: bool = False,
     na_sentinel: Optional[int] = -1,
     size_hint: Optional[int] = None,
-) -> Tuple[np.ndarray, Union[np.ndarray, ABCIndex]]:
+) -> Tuple[np.ndarray, Union[np.ndarray, "Index"]]:
     """
     Encode the object as an enumerated type or categorical variable.
 
@@ -838,7 +850,7 @@ def value_counts_arraylike(values, dropna: bool):
     return keys, counts
 
 
-def duplicated(values, keep="first") -> np.ndarray:
+def duplicated(values: ArrayLike, keep: str = "first") -> np.ndarray:
     """
     Return boolean ndarray denoting duplicate values.
 
diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py
index c76369c213a70..89d0a6723c890 100644
--- a/pandas/tests/test_algos.py
+++ b/pandas/tests/test_algos.py
@@ -842,6 +842,27 @@ def test_i8(self):
         expected = np.array([True, True, False])
         tm.assert_numpy_array_equal(result, expected)
 
+    @pytest.mark.parametrize("dtype1", ["m8[ns]", "M8[ns]", "M8[ns, UTC]", "period[D]"])
+    @pytest.mark.parametrize("dtype", ["i8", "f8", "u8"])
+    def test_isin_datetimelike_values_numeric_comps(self, dtype, dtype1):
+        # Anything but object and we get all-False shortcut
+
+        dta = date_range("2013-01-01", periods=3)._values
+        if dtype1 == "period[D]":
+            # TODO: fix Series.view to get this on its own
+            arr = dta.to_period("D")
+        elif dtype1 == "M8[ns, UTC]":
+            # TODO: fix Series.view to get this on its own
+            arr = dta.tz_localize("UTC")
+        else:
+            arr = Series(dta.view("i8")).view(dtype1)._values
+
+        comps = arr.view("i8").astype(dtype)
+
+        result = algos.isin(comps, arr)
+        expected = np.zeros(comps.shape, dtype=bool)
+        tm.assert_numpy_array_equal(result, expected)
+
     def test_large(self):
         s = date_range("20000101", periods=2000000, freq="s").values
         result = algos.isin(s, s[0:2])