Skip to content

Commit bd34d30

Browse files
committed
Moved function groupby_result_dtype to _GroupBy._result_dtype
- Reverted maybe_downcast_numeric to its original state
- Parameterized tests
1 parent 8e31908 commit bd34d30

File tree

5 files changed

+47
-72
lines changed

5 files changed

+47
-72
lines changed

pandas/core/dtypes/cast.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
ensure_int64,
3131
ensure_object,
3232
ensure_str,
33-
groupby_result_dtype,
3433
is_bool,
3534
is_bool_dtype,
3635
is_complex,
@@ -173,9 +172,7 @@ def maybe_downcast_to_dtype(result, dtype):
173172
return result
174173

175174

176-
def maybe_downcast_numeric(
177-
result, dtype, do_round: bool = False, how: str = "",
178-
):
175+
def maybe_downcast_numeric(result, dtype, do_round: bool = False):
179176
"""
180177
Subset of maybe_downcast_to_dtype restricted to numeric dtypes.
181178
@@ -184,7 +181,6 @@ def maybe_downcast_numeric(
184181
result : ndarray or ExtensionArray
185182
dtype : np.dtype or ExtensionDtype
186183
do_round : bool
187-
how : str
188184
189185
Returns
190186
-------
@@ -199,8 +195,6 @@ def maybe_downcast_numeric(
199195
# earlier
200196
result = np.array(result)
201197

202-
dtype = groupby_result_dtype(dtype, how)
203-
204198
def trans(x):
205199
if do_round:
206200
return x.round()

pandas/core/dtypes/common.py

-24
Original file line numberDiff line numberDiff line change
@@ -1786,27 +1786,3 @@ def pandas_dtype(dtype) -> DtypeObj:
17861786
raise TypeError(f"dtype '{dtype}' not understood")
17871787

17881788
return npdtype
1789-
1790-
1791-
def groupby_result_dtype(dtype, how) -> DtypeObj:
1792-
"""
1793-
Get the desired dtype of an aggregation result based on the
1794-
input dtype and how the aggregation is done.
1795-
1796-
Parameters
1797-
----------
1798-
dtype : dtype, type
1799-
The input dtype for the groupby.
1800-
how : str
1801-
How the aggregation is performed.
1802-
1803-
Returns
1804-
-------
1805-
The desired dtype of the aggregation result.
1806-
"""
1807-
d = {
1808-
(np.dtype(np.bool), "add"): np.dtype(np.int64),
1809-
(np.dtype(np.bool), "cumsum"): np.dtype(np.int64),
1810-
(np.dtype(np.bool), "sum"): np.dtype(np.int64),
1811-
}
1812-
return d.get((dtype, how), dtype)

pandas/core/groupby/generic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1072,8 +1072,10 @@ def _cython_agg_blocks(
10721072
assert not isinstance(result, DataFrame)
10731073

10741074
if result is not no_result:
1075-
# see if we can cast the block back to the original dtype
1076-
result = maybe_downcast_numeric(result, block.dtype, how=how)
1075+
# see if we can cast the block to the desired dtype
1076+
# this may not be the original dtype
1077+
dtype = self._result_dtype(block.dtype, how)
1078+
result = maybe_downcast_numeric(result, dtype)
10771079

10781080
if block.is_extension and isinstance(result, np.ndarray):
10791081
# e.g. block.values was an IntegerArray

pandas/core/groupby/groupby.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class providing the base-class of operations.
3333

3434
from pandas._libs import Timestamp
3535
import pandas._libs.groupby as libgroupby
36-
from pandas._typing import FrameOrSeries, Scalar
36+
from pandas._typing import DtypeObj, FrameOrSeries, Scalar
3737
from pandas.compat import set_function_name
3838
from pandas.compat.numpy import function as nv
3939
from pandas.errors import AbstractMethodError
@@ -42,7 +42,6 @@ class providing the base-class of operations.
4242
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
4343
from pandas.core.dtypes.common import (
4444
ensure_float,
45-
groupby_result_dtype,
4645
is_datetime64_dtype,
4746
is_extension_array_dtype,
4847
is_integer_dtype,
@@ -795,7 +794,7 @@ def _cumcount_array(self, ascending: bool = True):
795794

796795
def _try_cast(self, result, obj, numeric_only: bool = False, how: str = ""):
797796
"""
798-
Try to cast the result to our obj original type,
797+
Try to cast the result to the desired type,
799798
we may have roundtripped through object in the mean-time.
800799
801800
If numeric_only is True, then only try to cast numerics
@@ -806,8 +805,7 @@ def _try_cast(self, result, obj, numeric_only: bool = False, how: str = ""):
806805
dtype = obj._values.dtype
807806
else:
808807
dtype = obj.dtype
809-
810-
dtype = groupby_result_dtype(dtype, how)
808+
dtype = self._result_dtype(dtype, how)
811809

812810
if not is_scalar(result):
813811
if is_extension_array_dtype(dtype) and dtype.kind != "M":
@@ -1028,6 +1026,30 @@ def _apply_filter(self, indices, dropna):
10281026
filtered = self._selected_obj.where(mask) # Fill with NaNs.
10291027
return filtered
10301028

1029+
@staticmethod
1030+
def _result_dtype(dtype, how) -> DtypeObj:
1031+
"""
1032+
Get the desired dtype of a groupby result based on the
1033+
input dtype and how the aggregation is done.
1034+
1035+
Parameters
1036+
----------
1037+
dtype : dtype, type
1038+
The input dtype of the groupby.
1039+
how : str
1040+
How the aggregation is performed.
1041+
1042+
Returns
1043+
-------
1044+
The desired dtype of the aggregation result.
1045+
"""
1046+
d = {
1047+
(np.dtype(np.bool), "add"): np.dtype(np.int64),
1048+
(np.dtype(np.bool), "cumsum"): np.dtype(np.int64),
1049+
(np.dtype(np.bool), "sum"): np.dtype(np.int64),
1050+
}
1051+
return d.get((dtype, how), dtype)
1052+
10311053

10321054
class GroupBy(_GroupBy):
10331055
"""

pandas/tests/groupby/test_groupby.py

+15-34
Original file line numberDiff line numberDiff line change
@@ -2061,44 +2061,25 @@ def test_groups_repr_truncates(max_seq_items, expected):
20612061
assert result == expected
20622062

20632063

2064-
def test_bool_agg_dtype():
2064+
@pytest.mark.parametrize(
2065+
"op",
2066+
[
2067+
lambda x: x.sum(),
2068+
lambda x: x.cumsum(),
2069+
lambda x: x.transform("sum"),
2070+
lambda x: x.transform("cumsum"),
2071+
lambda x: x.agg("sum"),
2072+
lambda x: x.agg("cumsum"),
2073+
],
2074+
)
2075+
def test_bool_agg_dtype(op):
20652076
# GH 7001
2066-
# Bool aggregation results in int
2077+
# Bool sum aggregations result in int
20672078
df = pd.DataFrame({"a": [1, 1], "b": [False, True]})
20682079
s = df.set_index("a")["b"]
20692080

2070-
result = df.groupby("a").sum()["b"].dtype
2071-
assert is_integer_dtype(result)
2072-
2073-
result = s.groupby("a").sum().dtype
2074-
assert is_integer_dtype(result)
2075-
2076-
result = df.groupby("a").cumsum()["b"].dtype
2077-
assert is_integer_dtype(result)
2078-
2079-
result = s.groupby("a").cumsum().dtype
2080-
assert is_integer_dtype(result)
2081-
2082-
result = df.groupby("a").agg("sum")["b"].dtype
2083-
assert is_integer_dtype(result)
2084-
2085-
result = s.groupby("a").agg("sum").dtype
2086-
assert is_integer_dtype(result)
2087-
2088-
result = df.groupby("a").agg("cumsum")["b"].dtype
2089-
assert is_integer_dtype(result)
2090-
2091-
result = s.groupby("a").agg("cumsum").dtype
2092-
assert is_integer_dtype(result)
2093-
2094-
result = df.groupby("a").transform("sum")["b"].dtype
2095-
assert is_integer_dtype(result)
2096-
2097-
result = s.groupby("a").transform("sum").dtype
2098-
assert is_integer_dtype(result)
2099-
2100-
result = df.groupby("a").transform("cumsum")["b"].dtype
2081+
result = op(df.groupby("a"))["b"].dtype
21012082
assert is_integer_dtype(result)
21022083

2103-
result = s.groupby("a").transform("cumsum").dtype
2084+
result = op(s.groupby("a")).dtype
21042085
assert is_integer_dtype(result)

0 commit comments

Comments
 (0)