Skip to content

Commit 004e2dc

Browse files
committed
Moved cast functions to core.dtypes.cast, tests to test_aggregate
1 parent e325805 commit 004e2dc

File tree

8 files changed

+128
-120
lines changed

8 files changed

+128
-120
lines changed

pandas/core/arrays/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from pandas.core.dtypes.cast import try_cast_to_ea
2+
13
from pandas.core.arrays.base import (
24
ExtensionArray,
35
ExtensionOpsMixin,
46
ExtensionScalarOpsMixin,
5-
try_cast_to_ea,
67
)
78
from pandas.core.arrays.boolean import BooleanArray
89
from pandas.core.arrays.categorical import Categorical

pandas/core/arrays/base.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pandas.util._decorators import Appender, Substitution
2020
from pandas.util._validators import validate_fillna_kwargs
2121

22+
from pandas.core.dtypes.cast import try_cast_to_ea
2223
from pandas.core.dtypes.common import is_array_like, is_list_like
2324
from pandas.core.dtypes.dtypes import ExtensionDtype
2425
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
@@ -32,29 +33,6 @@
3233
_extension_array_shared_docs: Dict[str, str] = dict()
3334

3435

35-
def try_cast_to_ea(cls_or_instance, obj, dtype=None):
36-
"""
37-
Call to `_from_sequence` that returns the object unchanged on Exception.
38-
39-
Parameters
40-
----------
41-
cls_or_instance : ExtensionArray subclass or instance
42-
obj : arraylike
43-
Values to pass to cls._from_sequence
44-
dtype : ExtensionDtype, optional
45-
46-
Returns
47-
-------
48-
ExtensionArray or obj
49-
"""
50-
try:
51-
result = cls_or_instance._from_sequence(obj, dtype=dtype)
52-
except Exception:
53-
# We can't predict what downstream EA constructors may raise
54-
result = obj
55-
return result
56-
57-
5836
class ExtensionArray:
5937
"""
6038
Abstract base class for custom 1-D array types.

pandas/core/dtypes/cast.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,85 @@ def trans(x):
246246
return result
247247

248248

249+
def maybe_cast_result(result, obj, numeric_only: bool = False, how: str = ""):
    """
    Try to cast the result back to the dtype expected for ``obj``.

    The values may have round-tripped through object dtype in the meantime,
    so the target dtype is derived from ``obj`` (adjusted for ``how`` by
    ``maybe_cast_result_dtype``) and the result is converted when possible.

    Parameters
    ----------
    result : array-like or scalar
        The computed values. Scalars are returned untouched.
    obj : object exposing ``ndim`` and ``dtype``
        Source of the original dtype. When ``obj.ndim > 1`` the dtype is
        read from ``obj._values`` instead of ``obj`` itself.
    numeric_only : bool, default False
        If True, then only try to cast numerics and not datetimelikes.
    how : str, default ""
        Name of the operation that produced ``result``; forwarded to
        ``maybe_cast_result_dtype``.

    Returns
    -------
    The cast result, or ``result`` unchanged when no cast applies.
    """
    source_dtype = obj._values.dtype if obj.ndim > 1 else obj.dtype
    target = maybe_cast_result_dtype(source_dtype, how)

    # Scalars are never cast.
    if is_scalar(result):
        return result

    if is_extension_array_dtype(target) and target.kind != "M":
        # The function can return something of any type, so only rebuild
        # the extension array when the first element is already of the
        # dtype's scalar type. datetime64tz is handled correctly in
        # agg_series, hence the kind != "M" exclusion above.
        if len(result) and isinstance(result[0], target.type):
            array_type = target.construct_array_type()
            return try_cast_to_ea(array_type, result, dtype=target)
        return result

    # Same truth table as the original
    # ``numeric_only and is_numeric_dtype(dtype) or not numeric_only``.
    if not numeric_only or is_numeric_dtype(target):
        return maybe_downcast_to_dtype(result, target)

    return result
279+
280+
281+
def maybe_cast_result_dtype(dtype, how):
    """
    Get the desired dtype of a groupby result based on the
    input dtype and how the aggregation is done.

    Parameters
    ----------
    dtype : dtype, type
        The input dtype of the groupby.
    how : str
        How the aggregation is performed.

    Returns
    -------
    The desired dtype of the aggregation result.

    Notes
    -----
    Only the NumPy boolean dtype combined with "add", "cumsum" or "sum" is
    remapped (to int64); every other ``(dtype, how)`` pair returns ``dtype``
    unchanged.
    """
    # Spell the boolean dtype with the builtin ``bool``. The ``np.bool``
    # alias was deprecated in NumPy 1.20 and removed in 1.24, where touching
    # it raises AttributeError; ``np.dtype(bool)`` is the same dtype on
    # every NumPy release.
    bool_dtype = np.dtype(bool)
    int64_dtype = np.dtype(np.int64)

    # Exact-key lookup (hash + equality), so only a real numpy bool dtype
    # object matches -- not look-alikes such as the string "bool".
    overrides = {
        (bool_dtype, "add"): int64_dtype,
        (bool_dtype, "cumsum"): int64_dtype,
        (bool_dtype, "sum"): int64_dtype,
    }
    return overrides.get((dtype, how), dtype)
303+
304+
305+
def try_cast_to_ea(cls_or_instance, obj, dtype=None):
    """
    Build an ExtensionArray via `_from_sequence`, falling back to the input.

    Parameters
    ----------
    cls_or_instance : ExtensionArray subclass or instance
        Object whose ``_from_sequence`` is called.
    obj : arraylike
        Values to pass to cls._from_sequence
    dtype : ExtensionDtype, optional
        Forwarded to ``_from_sequence`` as the ``dtype`` keyword.

    Returns
    -------
    ExtensionArray or obj
        The constructed array, or ``obj`` itself (unchanged) when the
        constructor raises any ``Exception``.
    """
    try:
        return cls_or_instance._from_sequence(obj, dtype=dtype)
    except Exception:
        # Deliberately broad: downstream EA constructors may raise anything,
        # and this helper is best-effort by contract.
        return obj
326+
327+
249328
def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other):
250329
"""
251330
A safe version of putmask that potentially upcasts the result.

pandas/core/groupby/generic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from pandas.util._decorators import Appender, Substitution
3535

3636
from pandas.core.dtypes.cast import (
37+
maybe_cast_result,
38+
maybe_cast_result_dtype,
3739
maybe_convert_objects,
3840
maybe_downcast_numeric,
3941
maybe_downcast_to_dtype,
@@ -526,7 +528,7 @@ def _transform_fast(self, result, func_nm: str) -> Series:
526528
cast = self._transform_should_cast(func_nm)
527529
out = algorithms.take_1d(result._values, ids)
528530
if cast:
529-
out = self._try_cast(out, self.obj, how=func_nm)
531+
out = maybe_cast_result(out, self.obj, how=func_nm)
530532
return Series(out, index=self.obj.index, name=self.obj.name)
531533

532534
def filter(self, func, dropna=True, *args, **kwargs):
@@ -1074,7 +1076,7 @@ def _cython_agg_blocks(
10741076
if result is not no_result:
10751077
# see if we can cast the block to the desired dtype
10761078
# this may not be the original dtype
1077-
dtype = self._result_dtype(block.dtype, how)
1079+
dtype = maybe_cast_result_dtype(block.dtype, how)
10781080
result = maybe_downcast_numeric(result, dtype)
10791081

10801082
if block.is_extension and isinstance(result, np.ndarray):
@@ -1177,7 +1179,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
11771179

11781180
else:
11791181
if cast:
1180-
result[item] = self._try_cast(result[item], data)
1182+
result[item] = maybe_cast_result(result[item], data)
11811183

11821184
result_columns = obj.columns
11831185
if cannot_agg:
@@ -1462,7 +1464,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
14621464
# TODO: we have no test cases that get here with EA dtypes;
14631465
# try_cast may not be needed if EAs never get here
14641466
if cast:
1465-
res = self._try_cast(res, obj.iloc[:, i], how=func_nm)
1467+
res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm)
14661468
output.append(res)
14671469

14681470
return DataFrame._from_arrays(output, columns=result.columns, index=obj.index)

pandas/core/groupby/groupby.py

Lines changed: 8 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,16 @@ class providing the base-class of operations.
3333

3434
from pandas._libs import Timestamp
3535
import pandas._libs.groupby as libgroupby
36-
from pandas._typing import DtypeObj, FrameOrSeries, Scalar
36+
from pandas._typing import FrameOrSeries, Scalar
3737
from pandas.compat import set_function_name
3838
from pandas.compat.numpy import function as nv
3939
from pandas.errors import AbstractMethodError
4040
from pandas.util._decorators import Appender, Substitution, cache_readonly
4141

42-
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
42+
from pandas.core.dtypes.cast import maybe_cast_result
4343
from pandas.core.dtypes.common import (
4444
ensure_float,
4545
is_datetime64_dtype,
46-
is_extension_array_dtype,
4746
is_integer_dtype,
4847
is_numeric_dtype,
4948
is_object_dtype,
@@ -53,7 +52,7 @@ class providing the base-class of operations.
5352

5453
from pandas.core import nanops
5554
import pandas.core.algorithms as algorithms
56-
from pandas.core.arrays import Categorical, DatetimeArray, try_cast_to_ea
55+
from pandas.core.arrays import Categorical, DatetimeArray
5756
from pandas.core.base import DataError, PandasObject, SelectionMixin
5857
import pandas.core.common as com
5958
from pandas.core.frame import DataFrame
@@ -792,37 +791,6 @@ def _cumcount_array(self, ascending: bool = True):
792791
rev[sorter] = np.arange(count, dtype=np.intp)
793792
return out[rev].astype(np.int64, copy=False)
794793

795-
def _try_cast(self, result, obj, numeric_only: bool = False, how: str = ""):
796-
"""
797-
Try to cast the result to the desired type,
798-
we may have roundtripped through object in the mean-time.
799-
800-
If numeric_only is True, then only try to cast numerics
801-
and not datetimelikes.
802-
803-
"""
804-
if obj.ndim > 1:
805-
dtype = obj._values.dtype
806-
else:
807-
dtype = obj.dtype
808-
dtype = self._result_dtype(dtype, how)
809-
810-
if not is_scalar(result):
811-
if is_extension_array_dtype(dtype) and dtype.kind != "M":
812-
# The function can return something of any type, so check
813-
# if the type is compatible with the calling EA.
814-
# datetime64tz is handled correctly in agg_series,
815-
# so is excluded here.
816-
817-
if len(result) and isinstance(result[0], dtype.type):
818-
cls = dtype.construct_array_type()
819-
result = try_cast_to_ea(cls, result, dtype=dtype)
820-
821-
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
822-
result = maybe_downcast_to_dtype(result, dtype)
823-
824-
return result
825-
826794
def _transform_should_cast(self, func_nm: str) -> bool:
827795
"""
828796
Parameters
@@ -853,7 +821,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
853821
continue
854822

855823
if self._transform_should_cast(how):
856-
result = self._try_cast(result, obj, how=how)
824+
result = maybe_cast_result(result, obj, how=how)
857825

858826
key = base.OutputKey(label=name, position=idx)
859827
output[key] = result
@@ -896,12 +864,12 @@ def _cython_agg_general(
896864
assert len(agg_names) == result.shape[1]
897865
for result_column, result_name in zip(result.T, agg_names):
898866
key = base.OutputKey(label=result_name, position=idx)
899-
output[key] = self._try_cast(result_column, obj, how=how)
867+
output[key] = maybe_cast_result(result_column, obj, how=how)
900868
idx += 1
901869
else:
902870
assert result.ndim == 1
903871
key = base.OutputKey(label=name, position=idx)
904-
output[key] = self._try_cast(result, obj, how=how)
872+
output[key] = maybe_cast_result(result, obj, how=how)
905873
idx += 1
906874

907875
if len(output) == 0:
@@ -930,7 +898,7 @@ def _python_agg_general(self, func, *args, **kwargs):
930898

931899
assert result is not None
932900
key = base.OutputKey(label=name, position=idx)
933-
output[key] = self._try_cast(result, obj, numeric_only=True)
901+
output[key] = maybe_cast_result(result, obj, numeric_only=True)
934902

935903
if len(output) == 0:
936904
return self._python_apply_general(f)
@@ -945,7 +913,7 @@ def _python_agg_general(self, func, *args, **kwargs):
945913
if is_numeric_dtype(values.dtype):
946914
values = ensure_float(values)
947915

948-
output[key] = self._try_cast(values[mask], result)
916+
output[key] = maybe_cast_result(values[mask], result)
949917

950918
return self._wrap_aggregated_output(output)
951919

@@ -1026,30 +994,6 @@ def _apply_filter(self, indices, dropna):
1026994
filtered = self._selected_obj.where(mask) # Fill with NaNs.
1027995
return filtered
1028996

1029-
@staticmethod
1030-
def _result_dtype(dtype, how) -> DtypeObj:
1031-
"""
1032-
Get the desired dtype of a groupby result based on the
1033-
input dtype and how the aggregation is done.
1034-
1035-
Parameters
1036-
----------
1037-
dtype : dtype, type
1038-
The input dtype of the groupby.
1039-
how : str
1040-
How the aggregation is performed.
1041-
1042-
Returns
1043-
-------
1044-
The desired dtype of the aggregation result.
1045-
"""
1046-
d = {
1047-
(np.dtype(np.bool), "add"): np.dtype(np.int64),
1048-
(np.dtype(np.bool), "cumsum"): np.dtype(np.int64),
1049-
(np.dtype(np.bool), "sum"): np.dtype(np.int64),
1050-
}
1051-
return d.get((dtype, how), dtype)
1052-
1053997

1054998
class GroupBy(_GroupBy):
1055999
"""

pandas/core/series.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from pandas.util._decorators import Appender, Substitution, doc
2828
from pandas.util._validators import validate_bool_kwarg, validate_percentile
2929

30-
from pandas.core.dtypes.cast import convert_dtypes, validate_numeric_casting
30+
from pandas.core.dtypes.cast import (
31+
convert_dtypes,
32+
try_cast_to_ea,
33+
validate_numeric_casting,
34+
)
3135
from pandas.core.dtypes.common import (
3236
_is_unorderable_exception,
3337
ensure_platform_int,
@@ -59,7 +63,7 @@
5963
import pandas as pd
6064
from pandas.core import algorithms, base, generic, nanops, ops
6165
from pandas.core.accessor import CachedAccessor
62-
from pandas.core.arrays import ExtensionArray, try_cast_to_ea
66+
from pandas.core.arrays import ExtensionArray
6367
from pandas.core.arrays.categorical import CategoricalAccessor
6468
from pandas.core.arrays.sparse import SparseAccessor
6569
import pandas.core.common as com

pandas/tests/groupby/aggregate/test_aggregate.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import pytest
88

9+
from pandas.core.dtypes.common import is_integer_dtype
10+
911
import pandas as pd
1012
from pandas import DataFrame, Index, MultiIndex, Series, concat
1113
import pandas._testing as tm
@@ -340,6 +342,30 @@ def test_groupby_agg_coercing_bools():
340342
tm.assert_series_equal(result, expected)
341343

342344

345+
@pytest.mark.parametrize(
    "op",
    [
        lambda gb: gb.sum(),
        lambda gb: gb.cumsum(),
        lambda gb: gb.transform("sum"),
        lambda gb: gb.transform("cumsum"),
        lambda gb: gb.agg("sum"),
        lambda gb: gb.agg("cumsum"),
    ],
)
def test_bool_agg_dtype(op):
    # GH 7001
    # Summing booleans within groups must yield an integer dtype, both for
    # a DataFrame groupby and for a Series groupby.
    frame = pd.DataFrame({"a": [1, 1], "b": [False, True]})
    ser = frame.set_index("a")["b"]

    # DataFrame path: apply the op to the groupby, then inspect column "b".
    frame_dtype = op(frame.groupby("a"))["b"].dtype
    assert is_integer_dtype(frame_dtype)

    # Series path: group on the index level named "a".
    ser_dtype = op(ser.groupby("a")).dtype
    assert is_integer_dtype(ser_dtype)
367+
368+
343369
def test_order_aggregate_multiple_funcs():
344370
# GH 25692
345371
df = pd.DataFrame({"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]})

0 commit comments

Comments
 (0)