@@ -319,6 +319,8 @@ class BaseStringArray(ExtensionArray):
319
319
Mixin class for StringArray, ArrowStringArray.
320
320
"""
321
321
322
+ dtype : StringDtype
323
+
322
324
@doc (ExtensionArray .tolist )
323
325
def tolist (self ) -> list :
324
326
if self .ndim > 1 :
@@ -332,6 +334,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
332
334
raise ValueError
333
335
return cls ._from_sequence (scalars , dtype = dtype )
334
336
337
def _str_map_str_or_object(
    self,
    dtype,
    na_value,
    arr: np.ndarray,
    f,
    mask: npt.NDArray[np.bool_],
    convert: bool,
):
    """
    Map ``f`` over ``arr`` when the requested dtype is string or object.

    Helper for ``_str_map`` covering the two result kinds that are not
    handled as integer/boolean results by the callers.

    Parameters
    ----------
    dtype
        Requested result dtype. A string dtype that is not object dtype
        selects the string path; anything else selects the object path.
    na_value
        Forwarded to ``lib.map_infer_mask`` as ``na_value`` on the string
        path. Not used on the object path.
    arr : np.ndarray
        Values handed to ``lib.map_infer_mask``.
    f : callable
        Function handed to ``lib.map_infer_mask``.
    mask : np.ndarray of bool
        Boolean mask, forwarded to ``lib.map_infer_mask`` viewed as uint8.
    convert : bool
        Accepted so callers can pass it through; not consulted here.

    Returns
    -------
    type(self) or the result of ``lib.map_infer_mask``
        A new instance of ``type(self)`` on the string path, otherwise the
        return value of ``lib.map_infer_mask`` unchanged.
    """
    mask_u8 = mask.view("uint8")
    wants_string = is_string_dtype(dtype) and not is_object_dtype(dtype)

    if not wants_string:
        # Object result. We land here when either
        # -> the result type is truly object (e.g. .encode returns bytes
        #    or .findall returns a list), or
        # -> the result type is unknown. E.g. `.get` can return anything.
        return lib.map_infer_mask(arr, f, mask_u8)

    # i.e. StringDtype
    mapped = lib.map_infer_mask(arr, f, mask_u8, convert=False, na_value=na_value)

    if self.dtype.storage == "pyarrow":
        # Imported lazily, only when the pyarrow storage is actually in use.
        import pyarrow as pa

        mapped = pa.array(
            mapped, mask=mask, type=pa.large_string(), from_pandas=True
        )

    # error: Too many arguments for "BaseStringArray"
    return type(self)(mapped)  # type: ignore[call-arg]
367
+
335
368
336
369
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
337
370
# incompatible with definition in base class "ExtensionArray"
@@ -697,9 +730,53 @@ def _cmp_method(self, other, op):
697
730
# base class "NumpyExtensionArray" defined the type as "float")
698
731
_str_na_value = libmissing .NA # type: ignore[assignment]
699
732
733
+ def _str_map_nan_semantics (
734
+ self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
735
+ ):
736
+ if dtype is None :
737
+ dtype = self .dtype
738
+ if na_value is None :
739
+ na_value = self .dtype .na_value
740
+
741
+ mask = isna (self )
742
+ arr = np .asarray (self )
743
+ convert = convert and not np .all (mask )
744
+
745
+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
746
+ na_value_is_na = isna (na_value )
747
+ if na_value_is_na :
748
+ if is_integer_dtype (dtype ):
749
+ na_value = 0
750
+ else :
751
+ na_value = True
752
+
753
+ result = lib .map_infer_mask (
754
+ arr ,
755
+ f ,
756
+ mask .view ("uint8" ),
757
+ convert = False ,
758
+ na_value = na_value ,
759
+ dtype = np .dtype (cast (type , dtype )),
760
+ )
761
+ if na_value_is_na and mask .any ():
762
+ if is_integer_dtype (dtype ):
763
+ result = result .astype ("float64" )
764
+ else :
765
+ result = result .astype ("object" )
766
+ result [mask ] = np .nan
767
+ return result
768
+
769
+ else :
770
+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
771
+
700
772
def _str_map (
701
773
self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
702
774
):
775
+ if self .dtype .na_value is np .nan :
776
+ return self ._str_map_nan_semantics (
777
+ f , na_value = na_value , dtype = dtype , convert = convert
778
+ )
779
+
703
780
from pandas .arrays import BooleanArray
704
781
705
782
if dtype is None :
@@ -739,18 +816,8 @@ def _str_map(
739
816
740
817
return constructor (result , mask )
741
818
742
- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
743
- # i.e. StringDtype
744
- result = lib .map_infer_mask (
745
- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
746
- )
747
- return StringArray (result )
748
819
else :
749
- # This is when the result type is object. We reach this when
750
- # -> We know the result type is truly object (e.g. .encode returns bytes
751
- # or .findall returns a list).
752
- # -> We don't know the result type. E.g. `.get` can return anything.
753
- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
820
+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
754
821
755
822
756
823
class StringArrayNumpySemantics (StringArray ):
@@ -817,52 +884,3 @@ def value_counts(self, dropna: bool = True) -> Series:
817
884
# ------------------------------------------------------------------------
818
885
# String methods interface
819
886
_str_na_value = np .nan
820
-
821
- def _str_map (
822
- self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
823
- ):
824
- if dtype is None :
825
- dtype = self .dtype
826
- if na_value is None :
827
- na_value = self .dtype .na_value
828
-
829
- mask = isna (self )
830
- arr = np .asarray (self )
831
- convert = convert and not np .all (mask )
832
-
833
- if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
834
- na_value_is_na = isna (na_value )
835
- if na_value_is_na :
836
- if is_integer_dtype (dtype ):
837
- na_value = 0
838
- else :
839
- na_value = True
840
-
841
- result = lib .map_infer_mask (
842
- arr ,
843
- f ,
844
- mask .view ("uint8" ),
845
- convert = False ,
846
- na_value = na_value ,
847
- dtype = np .dtype (cast (type , dtype )),
848
- )
849
- if na_value_is_na and mask .any ():
850
- if is_integer_dtype (dtype ):
851
- result = result .astype ("float64" )
852
- else :
853
- result = result .astype ("object" )
854
- result [mask ] = np .nan
855
- return result
856
-
857
- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
858
- # i.e. StringDtype
859
- result = lib .map_infer_mask (
860
- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
861
- )
862
- return type (self )(result )
863
- else :
864
- # This is when the result type is object. We reach this when
865
- # -> We know the result type is truly object (e.g. .encode returns bytes
866
- # or .findall returns a list).
867
- # -> We don't know the result type. E.g. `.get` can return anything.
868
- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
0 commit comments