Skip to content

Commit 3754267

Browse files
authored
REF (string dtype): de-duplicate _str_map methods (#59443)
* REF: de-duplicate _str_map methods * mypy fixup
1 parent d0cb205 commit 3754267

File tree

2 files changed

+123
-128
lines changed

2 files changed

+123
-128
lines changed

pandas/core/arrays/string_.py

Lines changed: 78 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ class BaseStringArray(ExtensionArray):
319319
Mixin class for StringArray, ArrowStringArray.
320320
"""
321321

322+
dtype: StringDtype
323+
322324
@doc(ExtensionArray.tolist)
323325
def tolist(self) -> list:
324326
if self.ndim > 1:
@@ -332,6 +334,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
332334
raise ValueError
333335
return cls._from_sequence(scalars, dtype=dtype)
334336

337+
def _str_map_str_or_object(
338+
self,
339+
dtype,
340+
na_value,
341+
arr: np.ndarray,
342+
f,
343+
mask: npt.NDArray[np.bool_],
344+
convert: bool,
345+
):
346+
# _str_map helper for case where dtype is either string dtype or object
347+
if is_string_dtype(dtype) and not is_object_dtype(dtype):
348+
# i.e. StringDtype
349+
result = lib.map_infer_mask(
350+
arr, f, mask.view("uint8"), convert=False, na_value=na_value
351+
)
352+
if self.dtype.storage == "pyarrow":
353+
import pyarrow as pa
354+
355+
result = pa.array(
356+
result, mask=mask, type=pa.large_string(), from_pandas=True
357+
)
358+
# error: Too many arguments for "BaseStringArray"
359+
return type(self)(result) # type: ignore[call-arg]
360+
361+
else:
362+
# This is when the result type is object. We reach this when
363+
# -> We know the result type is truly object (e.g. .encode returns bytes
364+
# or .findall returns a list).
365+
# -> We don't know the result type. E.g. `.get` can return anything.
366+
return lib.map_infer_mask(arr, f, mask.view("uint8"))
367+
335368

336369
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
337370
# incompatible with definition in base class "ExtensionArray"
@@ -697,9 +730,53 @@ def _cmp_method(self, other, op):
697730
# base class "NumpyExtensionArray" defined the type as "float")
698731
_str_na_value = libmissing.NA # type: ignore[assignment]
699732

733+
def _str_map_nan_semantics(
734+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
735+
):
736+
if dtype is None:
737+
dtype = self.dtype
738+
if na_value is None:
739+
na_value = self.dtype.na_value
740+
741+
mask = isna(self)
742+
arr = np.asarray(self)
743+
convert = convert and not np.all(mask)
744+
745+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
746+
na_value_is_na = isna(na_value)
747+
if na_value_is_na:
748+
if is_integer_dtype(dtype):
749+
na_value = 0
750+
else:
751+
na_value = True
752+
753+
result = lib.map_infer_mask(
754+
arr,
755+
f,
756+
mask.view("uint8"),
757+
convert=False,
758+
na_value=na_value,
759+
dtype=np.dtype(cast(type, dtype)),
760+
)
761+
if na_value_is_na and mask.any():
762+
if is_integer_dtype(dtype):
763+
result = result.astype("float64")
764+
else:
765+
result = result.astype("object")
766+
result[mask] = np.nan
767+
return result
768+
769+
else:
770+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
771+
700772
def _str_map(
701773
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
702774
):
775+
if self.dtype.na_value is np.nan:
776+
return self._str_map_nan_semantics(
777+
f, na_value=na_value, dtype=dtype, convert=convert
778+
)
779+
703780
from pandas.arrays import BooleanArray
704781

705782
if dtype is None:
@@ -739,18 +816,8 @@ def _str_map(
739816

740817
return constructor(result, mask)
741818

742-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
743-
# i.e. StringDtype
744-
result = lib.map_infer_mask(
745-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
746-
)
747-
return StringArray(result)
748819
else:
749-
# This is when the result type is object. We reach this when
750-
# -> We know the result type is truly object (e.g. .encode returns bytes
751-
# or .findall returns a list).
752-
# -> We don't know the result type. E.g. `.get` can return anything.
753-
return lib.map_infer_mask(arr, f, mask.view("uint8"))
820+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
754821

755822

756823
class StringArrayNumpySemantics(StringArray):
@@ -817,52 +884,3 @@ def value_counts(self, dropna: bool = True) -> Series:
817884
# ------------------------------------------------------------------------
818885
# String methods interface
819886
_str_na_value = np.nan
820-
821-
def _str_map(
822-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
823-
):
824-
if dtype is None:
825-
dtype = self.dtype
826-
if na_value is None:
827-
na_value = self.dtype.na_value
828-
829-
mask = isna(self)
830-
arr = np.asarray(self)
831-
convert = convert and not np.all(mask)
832-
833-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
834-
na_value_is_na = isna(na_value)
835-
if na_value_is_na:
836-
if is_integer_dtype(dtype):
837-
na_value = 0
838-
else:
839-
na_value = True
840-
841-
result = lib.map_infer_mask(
842-
arr,
843-
f,
844-
mask.view("uint8"),
845-
convert=False,
846-
na_value=na_value,
847-
dtype=np.dtype(cast(type, dtype)),
848-
)
849-
if na_value_is_na and mask.any():
850-
if is_integer_dtype(dtype):
851-
result = result.astype("float64")
852-
else:
853-
result = result.astype("object")
854-
result[mask] = np.nan
855-
return result
856-
857-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
858-
# i.e. StringDtype
859-
result = lib.map_infer_mask(
860-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
861-
)
862-
return type(self)(result)
863-
else:
864-
# This is when the result type is object. We reach this when
865-
# -> We know the result type is truly object (e.g. .encode returns bytes
866-
# or .findall returns a list).
867-
# -> We don't know the result type. E.g. `.get` can return anything.
868-
return lib.map_infer_mask(arr, f, mask.view("uint8"))

pandas/core/arrays/string_arrow.py

Lines changed: 45 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
from pandas.core.dtypes.common import (
2626
is_bool_dtype,
2727
is_integer_dtype,
28-
is_object_dtype,
2928
is_scalar,
30-
is_string_dtype,
3129
pandas_dtype,
3230
)
3331
from pandas.core.dtypes.missing import isna
@@ -281,9 +279,53 @@ def astype(self, dtype, copy: bool = True):
281279
# base class "ObjectStringArrayMixin" defined the type as "float")
282280
_str_na_value = libmissing.NA # type: ignore[assignment]
283281

282+
def _str_map_nan_semantics(
283+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
284+
):
285+
if dtype is None:
286+
dtype = self.dtype
287+
if na_value is None:
288+
na_value = self.dtype.na_value
289+
290+
mask = isna(self)
291+
arr = np.asarray(self)
292+
293+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
294+
if is_integer_dtype(dtype):
295+
na_value = np.nan
296+
else:
297+
na_value = False
298+
299+
dtype = np.dtype(cast(type, dtype))
300+
if mask.any():
301+
# numpy int/bool dtypes cannot hold NaNs so we must convert to
302+
# float64 for int (to match maybe_convert_objects) or
303+
# object for bool (again to match maybe_convert_objects)
304+
if is_integer_dtype(dtype):
305+
dtype = np.dtype("float64")
306+
else:
307+
dtype = np.dtype(object)
308+
result = lib.map_infer_mask(
309+
arr,
310+
f,
311+
mask.view("uint8"),
312+
convert=False,
313+
na_value=na_value,
314+
dtype=dtype,
315+
)
316+
return result
317+
318+
else:
319+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
320+
284321
def _str_map(
285322
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
286323
):
324+
if self.dtype.na_value is np.nan:
325+
return self._str_map_nan_semantics(
326+
f, na_value=na_value, dtype=dtype, convert=convert
327+
)
328+
287329
# TODO: de-duplicate with StringArray method. This method is more or less copy and
288330
# paste.
289331

@@ -327,21 +369,8 @@ def _str_map(
327369

328370
return constructor(result, mask)
329371

330-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
331-
# i.e. StringDtype
332-
result = lib.map_infer_mask(
333-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
334-
)
335-
result = pa.array(
336-
result, mask=mask, type=pa.large_string(), from_pandas=True
337-
)
338-
return type(self)(result)
339372
else:
340-
# This is when the result type is object. We reach this when
341-
# -> We know the result type is truly object (e.g. .encode returns bytes
342-
# or .findall returns a list).
343-
# -> We don't know the result type. E.g. `.get` can return anything.
344-
return lib.map_infer_mask(arr, f, mask.view("uint8"))
373+
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
345374

346375
def _str_contains(
347376
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
@@ -614,58 +643,6 @@ def __getattribute__(self, item):
614643
return partial(getattr(ArrowStringArrayMixin, item), self)
615644
return super().__getattribute__(item)
616645

617-
def _str_map(
618-
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
619-
):
620-
if dtype is None:
621-
dtype = self.dtype
622-
if na_value is None:
623-
na_value = self.dtype.na_value
624-
625-
mask = isna(self)
626-
arr = np.asarray(self)
627-
628-
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
629-
if is_integer_dtype(dtype):
630-
na_value = np.nan
631-
else:
632-
na_value = False
633-
634-
dtype = np.dtype(cast(type, dtype))
635-
if mask.any():
636-
# numpy int/bool dtypes cannot hold NaNs so we must convert to
637-
# float64 for int (to match maybe_convert_objects) or
638-
# object for bool (again to match maybe_convert_objects)
639-
if is_integer_dtype(dtype):
640-
dtype = np.dtype("float64")
641-
else:
642-
dtype = np.dtype(object)
643-
result = lib.map_infer_mask(
644-
arr,
645-
f,
646-
mask.view("uint8"),
647-
convert=False,
648-
na_value=na_value,
649-
dtype=dtype,
650-
)
651-
return result
652-
653-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
654-
# i.e. StringDtype
655-
result = lib.map_infer_mask(
656-
arr, f, mask.view("uint8"), convert=False, na_value=na_value
657-
)
658-
result = pa.array(
659-
result, mask=mask, type=pa.large_string(), from_pandas=True
660-
)
661-
return type(self)(result)
662-
else:
663-
# This is when the result type is object. We reach this when
664-
# -> We know the result type is truly object (e.g. .encode returns bytes
665-
# or .findall returns a list).
666-
# -> We don't know the result type. E.g. `.get` can return anything.
667-
return lib.map_infer_mask(arr, f, mask.view("uint8"))
668-
669646
def _convert_int_dtype(self, result):
670647
if isinstance(result, pa.Array):
671648
result = result.to_numpy(zero_copy_only=False)

0 commit comments

Comments
 (0)