Skip to content

Commit b3520d1

Browse files
author
Kei
committed
Account for categorical dtype
1 parent a3be335 commit b3520d1

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pandas/core/groupby/groupby.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class providing the base-class of operations.
8484
needs_i8_conversion,
8585
pandas_dtype,
8686
)
87+
from pandas.core.dtypes.dtypes import CategoricalDtype
8788
from pandas.core.dtypes.missing import (
8889
isna,
8990
na_value_for_dtype,
@@ -2009,7 +2010,7 @@ def _convert_result_dtype(
20092010

20102011
converted_result_values = np.empty(out_shape, dtype=out_dtype)
20112012
if func not in cy_op.cast_blocklist:
2012-
res_dtype = cy_op._get_result_dtype(timezone_free_orig_input_values.dtype)
2013+
res_dtype = cy_op._get_result_dtype(input_values.dtype)
20132014
converted_result_values = maybe_downcast_to_dtype(
20142015
converted_result_values, res_dtype
20152016
)
@@ -2052,9 +2053,11 @@ def _preprocess_input_values(self, func, input_values: ArrayLike) -> ArrayLike:
20522053
input_values = input_values.view("int64")
20532054
elif dtype.kind == "b":
20542055
input_values = input_values.view("uint8")
2055-
2056-
if input_values.dtype == "float16":
2056+
elif input_values.dtype == "float16":
20572057
input_values = input_values.astype(np.float32)
2058+
elif isinstance(dtype, CategoricalDtype):
2059+
input_values = input_values[0].astype(bool)
2060+
input_values = input_values[None, :]
20582061

20592062
if func in ["any", "all"]:
20602063
input_values = input_values.astype(bool, copy=False).view(np.int8)

0 commit comments

Comments
 (0)