Skip to content

REF (string): de-duplicate ArrowStringArray methods (2) #59556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,54 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _result_converter(self, result: pa.Array, na=None):
# Convert bool-dtype results to the appropriate output type
raise NotImplementedError

def _str_isalnum(self) -> Self:
result = pc.utf8_is_alnum(self._pa_array)
return self._result_converter(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return self._result_converter(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return self._result_converter(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return self._result_converter(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return self._result_converter(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return self._result_converter(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return self._result_converter(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return self._result_converter(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return self._result_converter(result)

def _convert_int_dtype(self, result):
# Convert int-dtype results to the appropriate output type
raise NotImplementedError

def _str_len(self):
result = pc.utf8_length(self._pa_array)
return self._convert_int_dtype(result)

def _str_pad(
self,
width: int,
Expand Down
41 changes: 9 additions & 32 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,7 +1972,7 @@ def _rank(
"""
See Series.rank.__doc__.
"""
return type(self)(
return self._convert_int_dtype(
self._rank_calc(
axis=axis,
method=method,
Expand Down Expand Up @@ -2288,7 +2288,14 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
def _str_count(self, pat: str, flags: int = 0) -> Self:
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))
result = pc.count_substring_regex(self._pa_array, pat)
return self._convert_int_dtype(result)

def _result_converter(self, result, na=None):
return type(self)(result)

def _convert_int_dtype(self, result):
return type(self)(result)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
Expand Down Expand Up @@ -2441,36 +2448,6 @@ def _str_slice(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

def _str_isalnum(self) -> Self:
return type(self)(pc.utf8_is_alnum(self._pa_array))

def _str_isalpha(self) -> Self:
return type(self)(pc.utf8_is_alpha(self._pa_array))

def _str_isdecimal(self) -> Self:
return type(self)(pc.utf8_is_decimal(self._pa_array))

def _str_isdigit(self) -> Self:
return type(self)(pc.utf8_is_digit(self._pa_array))

def _str_islower(self) -> Self:
return type(self)(pc.utf8_is_lower(self._pa_array))

def _str_isnumeric(self) -> Self:
return type(self)(pc.utf8_is_numeric(self._pa_array))

def _str_isspace(self) -> Self:
return type(self)(pc.utf8_is_space(self._pa_array))

def _str_istitle(self) -> Self:
return type(self)(pc.utf8_is_title(self._pa_array))

def _str_isupper(self) -> Self:
return type(self)(pc.utf8_is_upper(self._pa_array))

def _str_len(self) -> Self:
return type(self)(pc.utf8_length(self._pa_array))

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))

Expand Down
149 changes: 26 additions & 123 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
Scalar,
Self,
Expand Down Expand Up @@ -358,121 +357,45 @@ def _str_repeat(self, repeats: int | Sequence[int]):
if not isinstance(repeats, int):
return super()._str_repeat(repeats)
else:
return type(self)(pc.binary_repeat(self._pa_array, repeats))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)
return ArrowExtensionArray._str_repeat(self, repeats=repeats)

def _str_slice(
self, start: int | None = None, stop: int | None = None, step: int | None = None
) -> Self:
if stop is None:
return super()._str_slice(start, stop, step)
if start is None:
start = 0
if step is None:
step = 1
return type(self)(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
return self._result_converter(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return self._result_converter(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return self._result_converter(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return self._result_converter(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return self._result_converter(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return self._result_converter(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return self._result_converter(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return self._result_converter(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return self._result_converter(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
return self._convert_int_dtype(result)

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))

def _str_upper(self) -> Self:
return type(self)(pc.utf8_upper(self._pa_array))

def _str_strip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_trim_whitespace(self._pa_array)
else:
result = pc.utf8_trim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_lstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_ltrim_whitespace(self._pa_array)
else:
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_rstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_rtrim_whitespace(self._pa_array)
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)
return ArrowExtensionArray._str_slice(self, start=start, stop=stop, step=step)

_str_isalnum = ArrowStringArrayMixin._str_isalnum
_str_isalpha = ArrowStringArrayMixin._str_isalpha
_str_isdecimal = ArrowStringArrayMixin._str_isdecimal
_str_isdigit = ArrowStringArrayMixin._str_isdigit
_str_islower = ArrowStringArrayMixin._str_islower
_str_isnumeric = ArrowStringArrayMixin._str_isnumeric
_str_isspace = ArrowStringArrayMixin._str_isspace
_str_istitle = ArrowStringArrayMixin._str_istitle
_str_isupper = ArrowStringArrayMixin._str_isupper

_str_len = ArrowStringArrayMixin._str_len

_str_match = ArrowExtensionArray._str_match
_str_fullmatch = ArrowExtensionArray._str_fullmatch
_str_lower = ArrowExtensionArray._str_lower
_str_upper = ArrowExtensionArray._str_upper
_str_strip = ArrowExtensionArray._str_strip
_str_lstrip = ArrowExtensionArray._str_lstrip
_str_rstrip = ArrowExtensionArray._str_rstrip
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
result = pc.if_else(starts_with, removed, self._pa_array)
return type(self)(result)
return ArrowExtensionArray._str_removeprefix(self, prefix)
return super()._str_removeprefix(prefix)

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_count(self, pat: str, flags: int = 0):
if flags:
return super()._str_count(pat, flags)
result = pc.count_substring_regex(self._pa_array, pat)
return self._convert_int_dtype(result)
return ArrowExtensionArray._str_count(self, pat, flags)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
Expand Down Expand Up @@ -528,27 +451,7 @@ def _reduce(
else:
return result

def _rank(
self,
*,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
pct: bool = False,
):
"""
See Series.rank.__doc__.
"""
return self._convert_int_dtype(
self._rank_calc(
axis=axis,
method=method,
na_option=na_option,
ascending=ascending,
pct=pct,
)
)
_rank = ArrowExtensionArray._rank

def value_counts(self, dropna: bool = True) -> Series:
result = super().value_counts(dropna=dropna)
Expand Down