Skip to content

Commit cdc4d97

Browse files
PERF: masked ops for reductions (min/max) (#33261)
1 parent f404a3f commit cdc4d97

File tree

6 files changed

+95
-27
lines changed

6 files changed

+95
-27
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ Performance improvements
276276
sparse values from ``scipy.sparse`` matrices using the
277277
:meth:`DataFrame.sparse.from_spmatrix` constructor (:issue:`32821`,
278278
:issue:`32825`, :issue:`32826`, :issue:`32856`, :issue:`32858`).
279-
- Performance improvement in :meth:`Series.sum` for nullable (integer and boolean) dtypes (:issue:`30982`).
279+
- Performance improvement in reductions (sum, min, max) for nullable (integer and boolean) dtypes (:issue:`30982`, :issue:`33261`).
280280

281281

282282
.. ---------------------------------------------------------------------------

pandas/core/array_algos/masked_reductions.py

+41
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,44 @@ def sum(
4545
return np.sum(values[~mask])
4646
else:
4747
return np.sum(values, where=~mask)
48+
49+
50+
def _minmax(func, values: np.ndarray, mask: np.ndarray, skipna: bool = True):
51+
"""
52+
Reduction for 1D masked array.
53+
54+
Parameters
55+
----------
56+
func : np.min or np.max
57+
values : np.ndarray
58+
Numpy array with the values (can be of any dtype that support the
59+
operation).
60+
mask : np.ndarray
61+
Boolean numpy array (True values indicate missing values).
62+
skipna : bool, default True
63+
Whether to skip NA.
64+
"""
65+
if not skipna:
66+
if mask.any():
67+
return libmissing.NA
68+
else:
69+
if values.size:
70+
return func(values)
71+
else:
72+
# min/max with empty array raise in numpy, pandas returns NA
73+
return libmissing.NA
74+
else:
75+
subset = values[~mask]
76+
if subset.size:
77+
return func(values[~mask])
78+
else:
79+
# min/max with empty array raise in numpy, pandas returns NA
80+
return libmissing.NA
81+
82+
83+
def min(values: np.ndarray, mask: np.ndarray, skipna: bool = True):
84+
return _minmax(np.min, values=values, mask=mask, skipna=skipna)
85+
86+
87+
def max(values: np.ndarray, mask: np.ndarray, skipna: bool = True):
88+
return _minmax(np.max, values=values, mask=mask, skipna=skipna)

pandas/core/arrays/boolean.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -696,8 +696,9 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
696696
data = self._data
697697
mask = self._mask
698698

699-
if name == "sum":
700-
return masked_reductions.sum(data, mask, skipna=skipna, **kwargs)
699+
if name in {"sum", "min", "max"}:
700+
op = getattr(masked_reductions, name)
701+
return op(data, mask, skipna=skipna, **kwargs)
701702

702703
# coerce to a nan-aware float if needed
703704
if self._hasna:
@@ -715,9 +716,6 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
715716
if int_result == result:
716717
result = int_result
717718

718-
elif name in ["min", "max"] and notna(result):
719-
result = np.bool_(result)
720-
721719
return result
722720

723721
def _maybe_mask_result(self, result, mask, other, op_name: str):

pandas/core/arrays/integer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,9 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
562562
data = self._data
563563
mask = self._mask
564564

565-
if name == "sum":
566-
return masked_reductions.sum(data, mask, skipna=skipna, **kwargs)
565+
if name in {"sum", "min", "max"}:
566+
op = getattr(masked_reductions, name)
567+
return op(data, mask, skipna=skipna, **kwargs)
567568

568569
# coerce to a nan-aware float if needed
569570
# (we explicitly use NaN within reductions)
@@ -582,7 +583,7 @@ def _reduce(self, name: str, skipna: bool = True, **kwargs):
582583

583584
# if we have a preservable numeric op,
584585
# provide coercion back to an integer type if possible
585-
elif name in ["min", "max", "prod"]:
586+
elif name == "prod":
586587
# GH#31409 more performant than casting-then-checking
587588
result = com.cast_scalar_indexer(result)
588589

pandas/tests/arrays/integer/test_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_preserve_dtypes(op):
3434

3535
# op
3636
result = getattr(df.C, op)()
37-
if op == "sum":
37+
if op in {"sum", "min", "max"}:
3838
assert isinstance(result, np.int64)
3939
else:
4040
assert isinstance(result, int)

pandas/tests/reductions/test_reductions.py

+45-17
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,58 @@ def test_ops(self, opname, obj):
6565
assert result.value == expected
6666

6767
@pytest.mark.parametrize("opname", ["max", "min"])
68-
def test_nanops(self, opname, index_or_series):
68+
@pytest.mark.parametrize(
69+
"dtype, val",
70+
[
71+
("object", 2.0),
72+
("float64", 2.0),
73+
("datetime64[ns]", datetime(2011, 11, 1)),
74+
("Int64", 2),
75+
("boolean", True),
76+
],
77+
)
78+
def test_nanminmax(self, opname, dtype, val, index_or_series):
6979
# GH#7261
7080
klass = index_or_series
71-
arg_op = "arg" + opname if klass is Index else "idx" + opname
7281

73-
obj = klass([np.nan, 2.0])
74-
assert getattr(obj, opname)() == 2.0
82+
if dtype in ["Int64", "boolean"] and klass == pd.Index:
83+
pytest.skip("EAs can't yet be stored in an index")
7584

76-
obj = klass([np.nan])
77-
assert pd.isna(getattr(obj, opname)())
78-
assert pd.isna(getattr(obj, opname)(skipna=False))
85+
def check_missing(res):
86+
if dtype == "datetime64[ns]":
87+
return res is pd.NaT
88+
elif dtype == "Int64":
89+
return res is pd.NA
90+
else:
91+
return pd.isna(res)
7992

80-
obj = klass([], dtype=object)
81-
assert pd.isna(getattr(obj, opname)())
82-
assert pd.isna(getattr(obj, opname)(skipna=False))
93+
obj = klass([None], dtype=dtype)
94+
assert check_missing(getattr(obj, opname)())
95+
assert check_missing(getattr(obj, opname)(skipna=False))
8396

84-
obj = klass([pd.NaT, datetime(2011, 11, 1)])
85-
# check DatetimeIndex monotonic path
86-
assert getattr(obj, opname)() == datetime(2011, 11, 1)
87-
assert getattr(obj, opname)(skipna=False) is pd.NaT
97+
obj = klass([], dtype=dtype)
98+
assert check_missing(getattr(obj, opname)())
99+
assert check_missing(getattr(obj, opname)(skipna=False))
100+
101+
if dtype == "object":
102+
# generic test with object only works for empty / all NaN
103+
return
104+
105+
obj = klass([None, val], dtype=dtype)
106+
assert getattr(obj, opname)() == val
107+
assert check_missing(getattr(obj, opname)(skipna=False))
88108

109+
obj = klass([None, val, None], dtype=dtype)
110+
assert getattr(obj, opname)() == val
111+
assert check_missing(getattr(obj, opname)(skipna=False))
112+
113+
@pytest.mark.parametrize("opname", ["max", "min"])
114+
def test_nanargminmax(self, opname, index_or_series):
115+
# GH#7261
116+
klass = index_or_series
117+
arg_op = "arg" + opname if klass is Index else "idx" + opname
118+
119+
obj = klass([pd.NaT, datetime(2011, 11, 1)])
89120
assert getattr(obj, arg_op)() == 1
90121
result = getattr(obj, arg_op)(skipna=False)
91122
if klass is Series:
@@ -95,9 +126,6 @@ def test_nanops(self, opname, index_or_series):
95126

96127
obj = klass([pd.NaT, datetime(2011, 11, 1), pd.NaT])
97128
# check DatetimeIndex non-monotonic path
98-
assert getattr(obj, opname)(), datetime(2011, 11, 1)
99-
assert getattr(obj, opname)(skipna=False) is pd.NaT
100-
101129
assert getattr(obj, arg_op)() == 1
102130
result = getattr(obj, arg_op)(skipna=False)
103131
if klass is Series:

0 commit comments

Comments
 (0)