@@ -33,7 +33,7 @@ class providing the base-class of operations.
33
33
34
34
from pandas ._libs import Timestamp
35
35
import pandas ._libs .groupby as libgroupby
36
- from pandas ._typing import FrameOrSeries , Scalar
36
+ from pandas ._typing import DtypeObj , FrameOrSeries , Scalar
37
37
from pandas .compat import set_function_name
38
38
from pandas .compat .numpy import function as nv
39
39
from pandas .errors import AbstractMethodError
@@ -42,7 +42,6 @@ class providing the base-class of operations.
42
42
from pandas .core .dtypes .cast import maybe_downcast_to_dtype
43
43
from pandas .core .dtypes .common import (
44
44
ensure_float ,
45
- groupby_result_dtype ,
46
45
is_datetime64_dtype ,
47
46
is_extension_array_dtype ,
48
47
is_integer_dtype ,
@@ -795,7 +794,7 @@ def _cumcount_array(self, ascending: bool = True):
795
794
796
795
def _try_cast (self , result , obj , numeric_only : bool = False , how : str = "" ):
797
796
"""
798
- Try to cast the result to our obj original type,
797
+ Try to cast the result to the desired type,
799
798
we may have roundtripped through object in the mean-time.
800
799
801
800
If numeric_only is True, then only try to cast numerics
@@ -806,8 +805,7 @@ def _try_cast(self, result, obj, numeric_only: bool = False, how: str = ""):
806
805
dtype = obj ._values .dtype
807
806
else :
808
807
dtype = obj .dtype
809
-
810
- dtype = groupby_result_dtype (dtype , how )
808
+ dtype = self ._result_dtype (dtype , how )
811
809
812
810
if not is_scalar (result ):
813
811
if is_extension_array_dtype (dtype ) and dtype .kind != "M" :
@@ -1028,6 +1026,30 @@ def _apply_filter(self, indices, dropna):
1028
1026
filtered = self ._selected_obj .where (mask ) # Fill with NaNs.
1029
1027
return filtered
1030
1028
1029
+ @staticmethod
1030
+ def _result_dtype (dtype , how ) -> DtypeObj :
1031
+ """
1032
+ Get the desired dtype of a groupby result based on the
1033
+ input dtype and how the aggregation is done.
1034
+
1035
+ Parameters
1036
+ ----------
1037
+ dtype : dtype, type
1038
+ The input dtype of the groupby.
1039
+ how : str
1040
+ How the aggregation is performed.
1041
+
1042
+ Returns
1043
+ -------
1044
+ The desired dtype of the aggregation result.
1045
+ """
1046
+ d = {
1047
+ (np .dtype (np .bool ), "add" ): np .dtype (np .int64 ),
1048
+ (np .dtype (np .bool ), "cumsum" ): np .dtype (np .int64 ),
1049
+ (np .dtype (np .bool ), "sum" ): np .dtype (np .int64 ),
1050
+ }
1051
+ return d .get ((dtype , how ), dtype )
1052
+
1031
1053
1032
1054
class GroupBy (_GroupBy ):
1033
1055
"""
0 commit comments