diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index c352a36bf6de1..c768da7c9f81f 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -39,8 +39,58 @@ Backwards incompatible API changes .. _whatsnew_1000.api.other: -- :class:`pandas.core.groupby.GroupBy.transform` now raises on invalid operation names (:issue:`27489`). -- +Groupby.transform(str) validates name is an aggregation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In previous releases, :meth:`DataFrameGroupBy.transform` and +:meth:`SeriesGroupBy.transform` did not validate that the function name +passed was actually the name of an aggregation. As a result, users might get a +cryptic error or worse, erroneous results. Starting with this release, these +methods will rised if the name of a non-aggregation is passed to them. There +is no change in the behavior associated with passing a callable. + +Users who relied on :meth:`DataFrameGroupBy.transform` or :meth:`SeriesGroupBy.transform` +for transformations such as :meth:`DataFrameGroupBy.rank`, :meth:`DataFrameGroupBy.ffill`, +etc, should instead call these method directly +(:issue:`27597`) (:issue:`14274`) (:issue:`19354`) (:issue:`22509`). + +.. ipython:: python + + df = pd.DataFrame([0, 1, 100, 99]) + labels = [0, 0, 1, 1] + g = df.groupby(labels) + +*Previous behavior*: + +.. code-block:: ipython + + In [1]: g.transform('ers >= Decepticons') + AttributeError: 'DataFrameGroupBy' object has no attribute 'ers >= Decepticons' + + g.transform('rank') + Out[14]: + 0 + 0 1.0 + 1 1.0 + 2 2.0 + 3 2.0 + + g.rank() + Out[15]: + 0 + 0 1.0 + 1 2.0 + 2 2.0 + 3 1.0 + +*New behavior*: + +.. ipython:: python + :okexcept: + + g.transform('ers >= Decepticons') + g.transform('rank') + Other API changes ^^^^^^^^^^^^^^^^^ @@ -78,6 +128,7 @@ Performance improvements Bug fixes ~~~~~~~~~ +- Categorical ^^^^^^^^^^^ diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index fc3bb69afd0cb..b23a75747e737 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -144,7 +144,6 @@ def _gotitem(self, key, ndim, subset=None): [ "backfill", "bfill", - "corrwith", "cumcount", "cummax", "cummin", @@ -173,6 +172,8 @@ def _gotitem(self, key, ndim, subset=None): # are neither a transformation nor a reduction "corr", "cov", + # corrwith does not preserve shape, depending on `other` + "corrwith", "describe", "dtypes", "expanding", @@ -197,4 +198,4 @@ def _gotitem(self, key, ndim, subset=None): # Valid values of `name` for `groupby.transform(name)` # NOTE: do NOT edit this directly. New additions should be inserted # into the appropriate list above. -transform_kernel_whitelist = reduction_kernels | transformation_kernels +groupby_transform_whitelist = reduction_kernels diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 1fef65349976b..647dc63326767 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -576,8 +576,17 @@ def transform(self, func, *args, **kwargs): func = self._get_cython_func(func) or func if isinstance(func, str): - if not (func in base.transform_kernel_whitelist): - msg = "'{func}' is not a valid function name for transform(name)" + if not (func in base.groupby_transform_whitelist): + msg = ( + "`g.transform('func')` is used exclusively for " + "computing aggregations and broadcasting " + "the results across groups. " + "'{func}' is not a valid aggregation name. " + ) + + if func in base.transformation_kernels | base.groupby_other_methods: + msg += "Perhaps you should try .{func}() instead?" + msg = msg.format(func=func) raise ValueError(msg.format(func=func)) if func in base.cythonized_kernels: # cythonized transformation or canned "reduction+broadcast" @@ -615,7 +624,11 @@ def _transform_fast(self, result, obj, func_nm): ids, _, ngroup = self.grouper.group_info output = [] for i, _ in enumerate(result.columns): - res = algorithms.take_1d(result.iloc[:, i].values, ids) + if func_nm in base.reduction_kernels: + # only broadcast results if we performed a reduction + res = algorithms.take_1d(result.iloc[:, i]._values, ids) + else: + res = result.iloc[:, i].values if cast: res = self._try_cast(res, obj.iloc[:, i]) output.append(res) @@ -1014,8 +1027,17 @@ def transform(self, func, *args, **kwargs): func = self._get_cython_func(func) or func if isinstance(func, str): - if not (func in base.transform_kernel_whitelist): - msg = "'{func}' is not a valid function name for transform(name)" + if not (func in base.groupby_transform_whitelist): + msg = ( + "`g.transform('func')` is used exclusively for " + "computing aggregations and broadcasting " + "the results across groups. " + "'{func}' is not a valid aggregation name. " + ) + + if func in base.transformation_kernels | base.groupby_other_methods: + msg += "Perhaps you should try .{func}() instead?" + msg = msg.format(func=func) raise ValueError(msg.format(func=func)) if func in base.cythonized_kernels: # cythonized transform or canned "agg+broadcast" @@ -1072,7 +1094,10 @@ def _transform_fast(self, func, func_nm): ids, _, ngroup = self.grouper.group_info cast = self._transform_should_cast(func_nm) - out = algorithms.take_1d(func()._values, ids) + out = func() + if func_nm in base.reduction_kernels: + # only broadcast results if we performed a reduction + out = algorithms.take_1d(out._values, ids) if cast: out = self._try_cast(out, self.obj) return Series(out, index=self.obj.index, name=self.obj.name) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 3d4dbd3f8d887..5568aacfd849f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -241,8 +241,9 @@ class providing the base-class of operations. Parameters ---------- -f : function - Function to apply to each group +func : callable or str + Callable to apply to each group OR + name of an aggregation function. Returns ------- @@ -257,6 +258,10 @@ class providing the base-class of operations. Each group is endowed the attribute 'name' in case you need to know which group you are working on. +If `func` is the name of an aggregation, the resulting value for +each group is replicated along the row axis to produce an output +with the same shape as the input. + The current implementation imposes three requirements on f: * f must return a value that either has the same shape as the input diff --git a/pandas/tests/groupby/conftest.py b/pandas/tests/groupby/conftest.py index 72e60c5099304..5c6caa6790153 100644 --- a/pandas/tests/groupby/conftest.py +++ b/pandas/tests/groupby/conftest.py @@ -2,7 +2,7 @@ import pytest from pandas import DataFrame, MultiIndex -from pandas.core.groupby.base import reduction_kernels +from pandas.core.groupby.base import reduction_kernels, transformation_kernels from pandas.util import testing as tm @@ -105,6 +105,13 @@ def three_group(): ) +@pytest.fixture(params=sorted(transformation_kernels)) +def transformation_func(request): + """yields the string names of all groupby reduction functions, one at a time. + """ + return request.param + + @pytest.fixture(params=sorted(reduction_kernels)) def reduction_func(request): """yields the string names of all groupby reduction functions, one at a time. diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index d3972e6ba9008..56622ce24e9bf 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -581,7 +581,7 @@ def test_cython_transform_series(op, args, targop): # print(data.head()) expected = data.groupby(labels).transform(targop) - tm.assert_series_equal(expected, data.groupby(labels).transform(op, *args)) + tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args)) tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args)) @@ -632,7 +632,7 @@ def test_cython_transform_series(op, args, targop): ) def test_groupby_cum_skipna(op, skipna, input, exp): df = pd.DataFrame(input) - result = df.groupby("key")["value"].transform(op, skipna=skipna) + result = getattr(df.groupby("key")["value"], op)(skipna=skipna) if isinstance(exp, dict): expected = exp[(op, skipna)] else: @@ -710,20 +710,17 @@ def test_cython_transform_frame(op, args, targop): expected = gb.apply(targop) expected = expected.sort_index(axis=1) - tm.assert_frame_equal(expected, gb.transform(op, *args).sort_index(axis=1)) tm.assert_frame_equal(expected, getattr(gb, op)(*args).sort_index(axis=1)) # individual columns for c in df: if c not in ["float", "int", "float_missing"] and op != "shift": msg = "No numeric types to aggregate" - with pytest.raises(DataError, match=msg): - gb[c].transform(op) with pytest.raises(DataError, match=msg): getattr(gb[c], op)() else: expected = gb[c].apply(targop) expected.name = c - tm.assert_series_equal(expected, gb[c].transform(op, *args)) + tm.assert_series_equal(expected, getattr(gb[c], op)(*args)) tm.assert_series_equal(expected, getattr(gb[c], op)(*args)) @@ -765,7 +762,7 @@ def test_transform_with_non_scalar_group(): ), ], ) -@pytest.mark.parametrize("agg_func", ["count", "rank", "size"]) +@pytest.mark.parametrize("agg_func", ["count", "size"]) def test_transform_numeric_ret(cols, exp, comp_func, agg_func): if agg_func == "size" and isinstance(cols, list): pytest.xfail("'size' transformation not supported with NDFrameGroupy") @@ -1007,17 +1004,19 @@ def test_transform_invalid_name_raises(): # GH#27486 df = DataFrame(dict(a=[0, 1, 1, 2])) g = df.groupby(["a", "b", "b", "c"]) - with pytest.raises(ValueError, match="not a valid function name"): + with pytest.raises(ValueError, match="exclusively"): g.transform("some_arbitrary_name") # method exists on the object, but is not a valid transformation/agg + # make sure the error suggests using the method directly. assert hasattr(g, "aggregate") # make sure the method exists - with pytest.raises(ValueError, match="not a valid function name"): + with pytest.raises(ValueError, match="exclusively.+you should try"): g.transform("aggregate") # Test SeriesGroupBy - g = df["a"].groupby(["a", "b", "b", "c"]) - with pytest.raises(ValueError, match="not a valid function name"): + ser = Series(range(4)) + g = ser.groupby(["a", "b", "b", "c"]) + with pytest.raises(ValueError, match="exclusively"): g.transform("some_arbitrary_name") @@ -1052,6 +1051,20 @@ def test_transform_agg_by_name(reduction_func, obj): assert len(set(DataFrame(result).iloc[-3:, -1])) == 1 +def test_transform_transformation_by_name(transformation_func): + """Make sure g.transform('name') raises a helpful error for non-agg + """ + func = transformation_func + obj = DataFrame( + dict(a=[0, 0, 0, 1, 1, 1], b=range(6)), index=["A", "B", "C", "D", "E", "F"] + ) + g = obj.groupby(np.repeat([0, 1], 3)) + + match = "exclusively for.+you should try" + with pytest.raises(ValueError, match=match): + g.transform(func) + + def test_transform_lambda_with_datetimetz(): # GH 27496 df = DataFrame(