Skip to content

String dtype: propagate NaNs as False in predicate methods (eg .str.startswith) #59616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
9620e00
String dtype: propagate NaNs as False in predicate methods (eg .str.s…
jorisvandenbossche Aug 26, 2024
b06764e
use no_default for ArrowEA._str_endswith as well
jorisvandenbossche Aug 26, 2024
b235735
update type annotations
jorisvandenbossche Aug 26, 2024
562118e
update docstrings
jorisvandenbossche Aug 26, 2024
ef05ade
more type annotations
jorisvandenbossche Aug 27, 2024
7d2a746
Merge remote-tracking branch 'upstream/main' into string-dtype-predic…
jorisvandenbossche Aug 27, 2024
b9612fc
test and fix startswith/endswith
jorisvandenbossche Aug 27, 2024
cf242a2
test ismethods
jorisvandenbossche Aug 27, 2024
f9ffff7
Merge remote-tracking branch 'upstream/main' into string-dtype-predic…
jorisvandenbossche Aug 31, 2024
ad0d6e1
fix warnings
jorisvandenbossche Aug 31, 2024
bf02000
try fix typing
jorisvandenbossche Sep 2, 2024
b650064
Merge remote-tracking branch 'upstream/main' into string-dtype-predic…
jorisvandenbossche Sep 6, 2024
377ff3a
follow same behaviour for categorical[str]
jorisvandenbossche Sep 6, 2024
2dfd50b
simplify fill_null calls for string[pyarrow] case
jorisvandenbossche Sep 6, 2024
adf2b99
fix na_value handling for categorical case + update tests for expecte…
jorisvandenbossche Sep 6, 2024
ddd531a
fix typing + fix conversion for old pyarrow
jorisvandenbossche Sep 6, 2024
e401b55
Merge remote-tracking branch 'upstream/main' into string-dtype-predic…
jorisvandenbossche Sep 16, 2024
3cb1b55
Merge remote-tracking branch 'upstream/main' into string-dtype-predic…
jorisvandenbossche Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@

import numpy as np

from pandas._libs import lib
from pandas.compat import (
pa_version_under10p1,
pa_version_under11p0,
pa_version_under13p0,
pa_version_under17p0,
)

from pandas.core.dtypes.missing import isna

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc
Expand All @@ -38,7 +37,7 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError

Expand Down Expand Up @@ -212,7 +211,9 @@ def _str_removesuffix(self, suffix: str):
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_startswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
Expand All @@ -225,11 +226,11 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="startswith")

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_endswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
Expand All @@ -242,9 +243,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="endswith")

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
Expand Down Expand Up @@ -283,7 +282,12 @@ def _str_isupper(self):
return self._convert_bool_result(result)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
regex: bool = True,
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")
Expand All @@ -293,19 +297,25 @@ def _str_contains(
else:
pa_contains = pc.match_substring
result = pa_contains(self._pa_array, pat, ignore_case=not case)
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="contains")

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat: str,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,7 +2318,9 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
for chunk in self._pa_array.iterchunks()
]

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return type(self)(result)

def _convert_int_result(self, result):
Expand Down
20 changes: 16 additions & 4 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,16 +2679,28 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
# ------------------------------------------------------------------------
# String methods interface
def _str_map(
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
self, f, na_value=lib.no_default, dtype=np.dtype("object"), convert: bool = True
):
# Optimization to apply the callable `f` to the categories once
# and rebuild the result by `take`ing from the result with the codes.
# Returns the same type as the object-dtype implementation though.
from pandas.core.arrays import NumpyExtensionArray

categories = self.categories
codes = self.codes
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
if categories.dtype == "string":
result = categories.array._str_map(f, na_value, dtype) # type: ignore[attr-defined]
if (
categories.dtype.na_value is np.nan # type: ignore[union-attr]
and is_bool_dtype(dtype)
and (na_value is lib.no_default or isna(na_value))
):
# NaN propagates as False for functions with boolean return type
na_value = False
else:
from pandas.core.arrays import NumpyExtensionArray

result = NumpyExtensionArray(categories.to_numpy())._str_map(
f, na_value, dtype
)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
Expand Down
31 changes: 20 additions & 11 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
return cls._from_sequence(scalars, dtype=dtype)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
self,
f,
na_value=lib.no_default,
dtype: Dtype | None = None,
convert: bool = True,
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(f, na_value=na_value, dtype=dtype)
Expand All @@ -390,7 +394,7 @@ def _str_map(

if dtype is None:
dtype = self.dtype
if na_value is None:
if na_value is lib.no_default:
na_value = self.dtype.na_value

mask = isna(self)
Expand Down Expand Up @@ -459,11 +463,17 @@ def _str_map_str_or_object(
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map_nan_semantics(
self, f, na_value=lib.no_default, dtype: Dtype | None = None
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value
if na_value is lib.no_default:
if is_bool_dtype(dtype):
# NaN propagates as False
na_value = False
else:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
Expand All @@ -474,7 +484,8 @@ def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True
# NaN propagates as False
na_value = False

result = lib.map_infer_mask(
arr,
Expand All @@ -484,15 +495,13 @@ def _str_map_nan_semantics(self, f, na_value=None, dtype: Dtype | None = None):
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
# TODO: we could alternatively do this check before map_infer_mask
# and adjust the dtype/na_value we pass there. Which is more
# performant?
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result = result.astype("float64")
result[mask] = np.nan

return result

else:
Expand Down
40 changes: 26 additions & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,27 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _convert_bool_result(self, values):
def _convert_bool_result(self, values, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(na) and not isinstance(na, bool):
# GH#59561
warnings.warn(
f"Allowing a non-bool 'na' in obj.str.{method_name} is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

if self.dtype.na_value is np.nan:
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
if na is lib.no_default or isna(na):
# NaN propagates as False
values = values.fill_null(False)
else:
values = values.fill_null(na)
return values.to_numpy()
else:
if na is not lib.no_default and not isna(na): # pyright: ignore [reportGeneralTypeIssues]
values = values.fill_null(na)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
Expand Down Expand Up @@ -306,22 +324,16 @@ def astype(self, dtype, copy: bool = True):
_str_slice = ArrowStringArrayMixin._str_slice

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na=lib.no_default,
regex: bool = True,
):
if flags:
return super()._str_contains(pat, case, flags, na, regex)

if not isna(na):
if not isinstance(na, bool):
# GH#59561
warnings.warn(
"Allowing a non-bool 'na' in obj.str.contains is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)

def _str_replace(
Expand Down
40 changes: 25 additions & 15 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,12 @@ def join(self, sep: str):

@forbid_nonstring_types(["bytes"])
def contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na=lib.no_default,
regex: bool = True,
):
r"""
Test if pattern or regex is contained within a string of a Series or Index.
Expand All @@ -1243,8 +1248,9 @@ def contains(
Flags to pass through to the re module, e.g. re.IGNORECASE.
na : scalar, optional
Fill value for missing values. The default depends on dtype of the
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
``pandas.NA`` is used.
array. For object-dtype, ``numpy.nan`` is used. For the nullable
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
``False`` is used.
regex : bool, default True
If True, assumes the pat is a regular expression.

Expand Down Expand Up @@ -1362,7 +1368,7 @@ def contains(
return self._wrap_result(result, fill_value=na, returns_string=False)

@forbid_nonstring_types(["bytes"])
def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
def match(self, pat: str, case: bool = True, flags: int = 0, na=lib.no_default):
"""
Determine if each string starts with a match of a regular expression.

Expand All @@ -1376,8 +1382,9 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
Regex module flags, e.g. re.IGNORECASE.
na : scalar, optional
Fill value for missing values. The default depends on dtype of the
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
``pandas.NA`` is used.
array. For object-dtype, ``numpy.nan`` is used. For the nullable
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
``False`` is used.

Returns
-------
Expand Down Expand Up @@ -1406,7 +1413,7 @@ def match(self, pat: str, case: bool = True, flags: int = 0, na=None):
return self._wrap_result(result, fill_value=na, returns_string=False)

@forbid_nonstring_types(["bytes"])
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
def fullmatch(self, pat, case: bool = True, flags: int = 0, na=lib.no_default):
"""
Determine if each string entirely matches a regular expression.

Expand All @@ -1420,8 +1427,9 @@ def fullmatch(self, pat, case: bool = True, flags: int = 0, na=None):
Regex module flags, e.g. re.IGNORECASE.
na : scalar, optional
Fill value for missing values. The default depends on dtype of the
array. For object-dtype, ``numpy.nan`` is used. For ``StringDtype``,
``pandas.NA`` is used.
array. For object-dtype, ``numpy.nan`` is used. For the nullable
``StringDtype``, ``pandas.NA`` is used. For the ``"str"`` dtype,
``False`` is used.

Returns
-------
Expand Down Expand Up @@ -2612,7 +2620,7 @@ def count(self, pat, flags: int = 0):

@forbid_nonstring_types(["bytes"])
def startswith(
self, pat: str | tuple[str, ...], na: Scalar | None = None
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
) -> Series | Index:
"""
Test if the start of each string element matches a pattern.
Expand All @@ -2624,10 +2632,11 @@ def startswith(
pat : str or tuple[str, ...]
Character sequence or tuple of strings. Regular expressions are not
accepted.
na : object, default NaN
na : scalar, optional
Object shown if element tested is not a string. The default depends
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
For ``StringDtype``, ``pandas.NA`` is used.
For the nullable ``StringDtype``, ``pandas.NA`` is used.
For the ``"str"`` dtype, ``False`` is used.

Returns
-------
Expand Down Expand Up @@ -2682,7 +2691,7 @@ def startswith(

@forbid_nonstring_types(["bytes"])
def endswith(
self, pat: str | tuple[str, ...], na: Scalar | None = None
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
) -> Series | Index:
"""
Test if the end of each string element matches a pattern.
Expand All @@ -2694,10 +2703,11 @@ def endswith(
pat : str or tuple[str, ...]
Character sequence or tuple of strings. Regular expressions are not
accepted.
na : object, default NaN
na : scalar, optional
Object shown if element tested is not a string. The default depends
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
For ``StringDtype``, ``pandas.NA`` is used.
For the nullable ``StringDtype``, ``pandas.NA`` is used.
For the ``"str"`` dtype, ``False`` is used.

Returns
-------
Expand Down
Loading
Loading