|
8 | 8 | import numpy as np
|
9 | 9 |
|
10 | 10 | from pandas._libs import Timedelta, Timestamp, internals as libinternals, lib
|
11 |
| -from pandas._typing import ArrayLike, DtypeObj, Label |
| 11 | +from pandas._typing import ArrayLike, DtypeObj, Label, Scalar |
12 | 12 | from pandas.util._validators import validate_bool_kwarg
|
13 | 13 |
|
14 | 14 | from pandas.core.dtypes.cast import (
|
@@ -1942,7 +1942,15 @@ def _compare_or_regex_search(a, b, regex=False):
|
1942 | 1942 | mask : array_like of bool
|
1943 | 1943 | """
|
1944 | 1944 |
|
1945 |
| - def _check(result, a, b): |
| 1945 | + def _check_comparison_types( |
| 1946 | + result: Union[ArrayLike, Scalar], |
| 1947 | + a: Union[ArrayLike, Scalar], |
| 1948 | + b: Union[ArrayLike, Scalar], |
| 1949 | + ) -> Union[ArrayLike, Scalar]: |
| 1950 | + """ |
| 1951 | + Raises an error if the two arrays cannot be compared, |
| 1952 | + otherwise returns the comparison result as expected. |
| 1953 | + """ |
1946 | 1954 | if is_scalar(result) and (
|
1947 | 1955 | isinstance(a, np.ndarray) or isinstance(b, np.ndarray)
|
1948 | 1956 | ):
|
@@ -1983,15 +1991,18 @@ def _check(result, a, b):
|
1983 | 1991 |
|
1984 | 1992 | if is_datetimelike_v_numeric(a, b) or is_numeric_v_string_like(a, b):
|
1985 | 1993 | # GH#29553 avoid deprecation warnings from numpy
|
1986 |
| - return _check(False, a, b) |
1987 |
| - else: |
1988 |
| - result = op(a) |
1989 |
| - if isinstance(result, np.ndarray): |
1990 |
| - tmp = np.zeros(mask.shape, dtype=np.bool) |
1991 |
| - tmp[mask] = result |
1992 |
| - result = tmp |
| 1994 | + return _check_comparison_types(False, a, b) |
| 1995 | + |
| 1996 | + result = op(a) |
| 1997 | + |
| 1998 | + if isinstance(result, np.ndarray): |
| 1999 | + # The shape of the mask can differ to that of the result |
| 2000 | + # since we may compare only a subset of a's or b's elements |
| 2001 | + tmp = np.zeros(mask.shape, dtype=np.bool) |
| 2002 | + tmp[mask] = result |
| 2003 | + result = tmp |
1993 | 2004 |
|
1994 |
| - return _check(result, a, b) |
| 2005 | + return _check_comparison_types(result, a, b) |
1995 | 2006 |
|
1996 | 2007 |
|
1997 | 2008 | def _fast_count_smallints(arr):
|
|
0 commit comments