Skip to content

[backport 2.3.x] String dtype: propagate NaNs as False in predicate methods (e.g. .str.startswith) (#59616) #60014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
44 changes: 27 additions & 17 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@

import numpy as np

from pandas._libs import lib
from pandas.compat import (
pa_version_under10p1,
pa_version_under11p0,
pa_version_under13p0,
pa_version_under17p0,
)

from pandas.core.dtypes.missing import isna

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc
Expand All @@ -38,7 +37,7 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError

Expand Down Expand Up @@ -212,7 +211,9 @@ def _str_removesuffix(self, suffix: str):
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_startswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
Expand All @@ -225,11 +226,11 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="startswith")

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_endswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
Expand All @@ -242,9 +243,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="endswith")

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
Expand Down Expand Up @@ -283,7 +282,12 @@ def _str_isupper(self):
return self._convert_bool_result(result)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
regex: bool = True,
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")
Expand All @@ -293,19 +297,25 @@ def _str_contains(
else:
pa_contains = pc.match_substring
result = pa_contains(self._pa_array, pat, ignore_case=not case)
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="contains")

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat: str,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,7 +2285,11 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
for chunk in self._pa_array.iterchunks()
]

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(
na
): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return type(self)(result)

def _convert_int_result(self, result):
Expand Down
20 changes: 16 additions & 4 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2675,16 +2675,28 @@ def _replace(self, *, to_replace, value, inplace: bool = False):
# ------------------------------------------------------------------------
# String methods interface
def _str_map(
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
self, f, na_value=lib.no_default, dtype=np.dtype("object"), convert: bool = True
):
# Optimization to apply the callable `f` to the categories once
# and rebuild the result by `take`ing from the result with the codes.
# Returns the same type as the object-dtype implementation though.
from pandas.core.arrays import NumpyExtensionArray

categories = self.categories
codes = self.codes
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
if categories.dtype == "string":
result = categories.array._str_map(f, na_value, dtype) # type: ignore[attr-defined]
if (
categories.dtype.na_value is np.nan # type: ignore[union-attr]
and is_bool_dtype(dtype)
and (na_value is lib.no_default or isna(na_value))
):
# NaN propagates as False for functions with boolean return type
na_value = False
else:
from pandas.core.arrays import NumpyExtensionArray

result = NumpyExtensionArray(categories.to_numpy())._str_map(
f, na_value, dtype
)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|"):
Expand Down
33 changes: 22 additions & 11 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
return cls._from_sequence(scalars, dtype=dtype)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
self,
f,
na_value=lib.no_default,
dtype: Dtype | None = None,
convert: bool = True,
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
Expand All @@ -388,7 +392,7 @@ def _str_map(

if dtype is None:
dtype = self.dtype
if na_value is None:
if na_value is lib.no_default:
na_value = self.dtype.na_value

mask = isna(self)
Expand Down Expand Up @@ -458,12 +462,20 @@ def _str_map_str_or_object(
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
self,
f,
na_value=lib.no_default,
dtype: Dtype | None = None,
convert: bool = True,
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value
if na_value is lib.no_default:
if is_bool_dtype(dtype):
# NaN propagates as False
na_value = False
else:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
Expand All @@ -474,7 +486,8 @@ def _str_map_nan_semantics(
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True
# NaN propagates as False
na_value = False

result = lib.map_infer_mask(
arr,
Expand All @@ -484,15 +497,13 @@ def _str_map_nan_semantics(
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
# TODO: we could alternatively do this check before map_infer_mask
# and adjust the dtype/na_value we pass there. Which is more
# performant?
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result = result.astype("float64")
result[mask] = np.nan

return result

else:
Expand Down
42 changes: 28 additions & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,29 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _convert_bool_result(self, values):
def _convert_bool_result(self, values, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(na) and not isinstance(na, bool):
# GH#59561
warnings.warn(
f"Allowing a non-bool 'na' in obj.str.{method_name} is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

if self.dtype.na_value is np.nan:
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
if na is lib.no_default or isna(na):
# NaN propagates as False
values = values.fill_null(False)
else:
values = values.fill_null(na)
return values.to_numpy()
else:
if na is not lib.no_default and not isna(
na
): # pyright: ignore [reportGeneralTypeIssues]
values = values.fill_null(na)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
Expand Down Expand Up @@ -309,22 +329,16 @@ def _data(self):
_str_slice = ArrowStringArrayMixin._str_slice

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na=lib.no_default,
regex: bool = True,
):
if flags:
return super()._str_contains(pat, case, flags, na, regex)

if not isna(na):
if not isinstance(na, bool):
# GH#59561
warnings.warn(
"Allowing a non-bool 'na' in obj.str.contains is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)

def _str_replace(
Expand Down
Loading
Loading