Skip to content

Commit b62ad89

Browse files
committed
BUG: Fix replacing in string series with NA (pandas-dev#32621)
1 parent 0a76844 commit b62ad89

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

pandas/core/internals/managers.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from pandas._libs import Timedelta, Timestamp, internals as libinternals, lib
11-
from pandas._typing import ArrayLike, DtypeObj, Label
11+
from pandas._typing import ArrayLike, DtypeObj, Label, Scalar
1212
from pandas.util._validators import validate_bool_kwarg
1313

1414
from pandas.core.dtypes.cast import (
@@ -1942,7 +1942,15 @@ def _compare_or_regex_search(a, b, regex=False):
19421942
mask : array_like of bool
19431943
"""
19441944

1945-
def _check(result, a, b):
1945+
def _check_comparison_types(
1946+
result: Union[ArrayLike, Scalar],
1947+
a: Union[ArrayLike, Scalar],
1948+
b: Union[ArrayLike, Scalar],
1949+
) -> Union[ArrayLike, Scalar]:
1950+
"""
1951+
Raises an error if the two arrays cannot be compared,
1952+
otherwise returns the comparison result as expected.
1953+
"""
19461954
if is_scalar(result) and (
19471955
isinstance(a, np.ndarray) or isinstance(b, np.ndarray)
19481956
):
@@ -1983,15 +1991,18 @@ def _check(result, a, b):
19831991

19841992
if is_datetimelike_v_numeric(a, b) or is_numeric_v_string_like(a, b):
19851993
# GH#29553 avoid deprecation warnings from numpy
1986-
return _check(False, a, b)
1987-
else:
1988-
result = op(a)
1989-
if isinstance(result, np.ndarray):
1990-
tmp = np.zeros(mask.shape, dtype=np.bool)
1991-
tmp[mask] = result
1992-
result = tmp
1994+
return _check_comparison_types(False, a, b)
1995+
1996+
result = op(a)
1997+
1998+
if isinstance(result, np.ndarray):
1999+
# The shape of the mask can differ to that of the result
2000+
# since we may compare only a subset of a's or b's elements
2001+
tmp = np.zeros(mask.shape, dtype=np.bool)
2002+
tmp[mask] = result
2003+
result = tmp
19932004

1994-
return _check(result, a, b)
2005+
return _check_comparison_types(result, a, b)
19952006

19962007

19972008
def _fast_count_smallints(arr):

0 commit comments

Comments
 (0)