Skip to content

Commit 38c6864

Browse files
committed
[BUG] Aggregated bool has inconsistent dtype
Addresses: GH7001
1 parent bed9103 commit 38c6864

File tree

5 files changed

+86
-8
lines changed

5 files changed

+86
-8
lines changed

pandas/core/dtypes/cast.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ensure_int64,
3131
ensure_object,
3232
ensure_str,
33+
groupby_result_dtype,
3334
is_bool,
3435
is_bool_dtype,
3536
is_complex,
@@ -172,7 +173,9 @@ def maybe_downcast_to_dtype(result, dtype):
172173
return result
173174

174175

175-
def maybe_downcast_numeric(result, dtype, do_round: bool = False):
176+
def maybe_downcast_numeric(
177+
result, dtype, do_round: bool = False, how: str = "",
178+
):
176179
"""
177180
Subset of maybe_downcast_to_dtype restricted to numeric dtypes.
178181
@@ -181,6 +184,7 @@ def maybe_downcast_numeric(result, dtype, do_round: bool = False):
181184
result : ndarray or ExtensionArray
182185
dtype : np.dtype or ExtensionDtype
183186
do_round : bool
187+
how : str
184188
185189
Returns
186190
-------
@@ -195,6 +199,8 @@ def maybe_downcast_numeric(result, dtype, do_round: bool = False):
195199
# earlier
196200
result = np.array(result)
197201

202+
dtype = groupby_result_dtype(dtype, how)
203+
198204
def trans(x):
199205
if do_round:
200206
return x.round()

pandas/core/dtypes/common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,3 +1788,27 @@ def pandas_dtype(dtype) -> DtypeObj:
17881788
raise TypeError(f"dtype '{dtype}' not understood")
17891789

17901790
return npdtype
1791+
1792+
1793+
def groupby_result_dtype(dtype, how) -> DtypeObj:
    """
    Get the desired dtype of an aggregation result based on the
    input dtype and how the aggregation is done.

    Additive operations ("add", "sum", "cumsum") over a NumPy bool input
    are mapped to int64; every other ``(dtype, how)`` combination returns
    ``dtype`` unchanged.

    Parameters
    ----------
    dtype : dtype, type
        The input dtype for the groupby.
    how : str
        How the aggregation is performed.

    Returns
    -------
    The desired dtype of the aggregation result.
    """
    # Use ``np.bool_`` rather than the ``np.bool`` alias: the alias was
    # deprecated in NumPy 1.20 and is missing from NumPy 1.24-1.26, where
    # touching it raises AttributeError.
    bool_dtype = np.dtype(np.bool_)
    int64_dtype = np.dtype(np.int64)

    # (input dtype, operation) -> dtype the result should be reported as.
    overrides = {
        (bool_dtype, "add"): int64_dtype,
        (bool_dtype, "cumsum"): int64_dtype,
        (bool_dtype, "sum"): int64_dtype,
    }
    # Anything not listed keeps the dtype it came in with.
    return overrides.get((dtype, how), dtype)

pandas/core/groupby/generic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def _transform_fast(self, result, func_nm: str) -> Series:
526526
cast = self._transform_should_cast(func_nm)
527527
out = algorithms.take_1d(result._values, ids)
528528
if cast:
529-
out = self._try_cast(out, self.obj)
529+
out = self._try_cast(out, self.obj, how=func_nm)
530530
return Series(out, index=self.obj.index, name=self.obj.name)
531531

532532
def filter(self, func, dropna=True, *args, **kwargs):
@@ -1073,7 +1073,7 @@ def _cython_agg_blocks(
10731073

10741074
if result is not no_result:
10751075
# see if we can cast the block back to the original dtype
1076-
result = maybe_downcast_numeric(result, block.dtype)
1076+
result = maybe_downcast_numeric(result, block.dtype, how=how)
10771077

10781078
if block.is_extension and isinstance(result, np.ndarray):
10791079
# e.g. block.values was an IntegerArray
@@ -1460,7 +1460,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
14601460
# TODO: we have no test cases that get here with EA dtypes;
14611461
# try_cast may not be needed if EAs never get here
14621462
if cast:
1463-
res = self._try_cast(res, obj.iloc[:, i])
1463+
res = self._try_cast(res, obj.iloc[:, i], how=func_nm)
14641464
output.append(res)
14651465

14661466
return DataFrame._from_arrays(output, columns=result.columns, index=obj.index)

pandas/core/groupby/groupby.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class providing the base-class of operations.
4242
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
4343
from pandas.core.dtypes.common import (
4444
ensure_float,
45+
groupby_result_dtype,
4546
is_datetime64_dtype,
4647
is_extension_array_dtype,
4748
is_integer_dtype,
@@ -792,7 +793,7 @@ def _cumcount_array(self, ascending: bool = True):
792793
rev[sorter] = np.arange(count, dtype=np.intp)
793794
return out[rev].astype(np.int64, copy=False)
794795

795-
def _try_cast(self, result, obj, numeric_only: bool = False):
796+
def _try_cast(self, result, obj, numeric_only: bool = False, how: str = ""):
796797
"""
797798
Try to cast the result to our obj original type,
798799
we may have roundtripped through object in the mean-time.
@@ -806,6 +807,8 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
806807
else:
807808
dtype = obj.dtype
808809

810+
dtype = groupby_result_dtype(dtype, how)
811+
809812
if not is_scalar(result):
810813
if is_extension_array_dtype(dtype) and dtype.kind != "M":
811814
# The function can return something of any type, so check
@@ -852,7 +855,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
852855
continue
853856

854857
if self._transform_should_cast(how):
855-
result = self._try_cast(result, obj)
858+
result = self._try_cast(result, obj, how=how)
856859

857860
key = base.OutputKey(label=name, position=idx)
858861
output[key] = result
@@ -895,12 +898,12 @@ def _cython_agg_general(
895898
assert len(agg_names) == result.shape[1]
896899
for result_column, result_name in zip(result.T, agg_names):
897900
key = base.OutputKey(label=result_name, position=idx)
898-
output[key] = self._try_cast(result_column, obj)
901+
output[key] = self._try_cast(result_column, obj, how=how)
899902
idx += 1
900903
else:
901904
assert result.ndim == 1
902905
key = base.OutputKey(label=name, position=idx)
903-
output[key] = self._try_cast(result, obj)
906+
output[key] = self._try_cast(result, obj, how=how)
904907
idx += 1
905908

906909
if len(output) == 0:

pandas/tests/groupby/test_groupby.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from pandas.errors import PerformanceWarning
99

10+
from pandas.core.dtypes.common import is_integer_dtype
11+
1012
import pandas as pd
1113
from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range, read_csv
1214
import pandas._testing as tm
@@ -2057,3 +2059,46 @@ def test_groups_repr_truncates(max_seq_items, expected):
20572059

20582060
result = df.groupby(np.array(df.a)).groups.__repr__()
20592061
assert result == expected
2062+
2063+
2064+
def test_bool_agg_dtype():
    # GH 7001
    # Summing a bool column within groups must give an integer dtype,
    # whichever groupby entry point is used to request the sum.
    frame = pd.DataFrame({"a": [1, 1], "b": [False, True]})
    ser = frame.set_index("a")["b"]

    # Each entry applies one sum-like operation to a groupby object; the
    # DataFrame variant is checked first, then the Series variant.
    operations = [
        lambda grouped: grouped.sum(),
        lambda grouped: grouped.cumsum(),
        lambda grouped: grouped.agg("sum"),
        lambda grouped: grouped.agg("cumsum"),
        lambda grouped: grouped.transform("sum"),
        lambda grouped: grouped.transform("cumsum"),
    ]

    for apply_op in operations:
        frame_dtype = apply_op(frame.groupby("a"))["b"].dtype
        assert is_integer_dtype(frame_dtype)

        series_dtype = apply_op(ser.groupby("a")).dtype
        assert is_integer_dtype(series_dtype)

0 commit comments

Comments
 (0)