Skip to content

BUG: groupby.transform(name) validates name is an aggregation #27597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
55 changes: 53 additions & 2 deletions doc/source/whatsnew/v1.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,58 @@ Backwards incompatible API changes

.. _whatsnew_1000.api.other:

- :class:`pandas.core.groupby.GroupBy.transform` now raises on invalid operation names (:issue:`27489`).
-
Groupby.transform(str) validates name is an aggregation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In previous releases, :meth:`DataFrameGroupBy.transform` and
:meth:`SeriesGroupBy.transform` did not validate that the function name
passed was actually the name of an aggregation. As a result, users might get a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the: As a result.... sentence.

rised -> raise

cryptic error or worse, erroneous results. Starting with this release, these
methods will rised if the name of a non-aggregation is passed to them. There
is no change in the behavior associated with passing a callable.

Users who relied on :meth:`DataFrameGroupBy.transform` or :meth:`SeriesGroupBy.transform`
for transformations such as :meth:`DataFrameGroupBy.rank`, :meth:`DataFrameGroupBy.ffill`,
etc, should instead call these method directly
(:issue:`27597`) (:issue:`14274`) (:issue:`19354`) (:issue:`22509`).

.. ipython:: python

df = pd.DataFrame([0, 1, 100, 99])
labels = [0, 0, 1, 1]
g = df.groupby(labels)

*Previous behavior*:

.. code-block:: ipython

In [1]: g.transform('ers >= Decepticons')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just pass it a name like 'foo'

AttributeError: 'DataFrameGroupBy' object has no attribute 'ers >= Decepticons'

g.transform('rank')
Out[14]:
0
0 1.0
1 1.0
2 2.0
3 2.0

g.rank()
Out[15]:
0
0 1.0
1 2.0
2 2.0
3 1.0

*New behavior*:

.. ipython:: python
:okexcept:

g.transform('ers >= Decepticons')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use foo, & make this a code-block (so we don't have the long traceback)

put the 'rank' in its own ipython block; I would also show .rank() or at least indicate that they are now the same.

g.transform('rank')


Other API changes
^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -78,6 +128,7 @@ Performance improvements
Bug fixes
~~~~~~~~~

-

Categorical
^^^^^^^^^^^
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def _gotitem(self, key, ndim, subset=None):
[
"backfill",
"bfill",
"corrwith",
"cumcount",
"cummax",
"cummin",
Expand Down Expand Up @@ -173,6 +172,8 @@ def _gotitem(self, key, ndim, subset=None):
# are neither a transformation nor a reduction
"corr",
"cov",
# corrwith does not preserve shape, depending on `other`
"corrwith",
"describe",
"dtypes",
"expanding",
Expand All @@ -197,4 +198,4 @@ def _gotitem(self, key, ndim, subset=None):
# Valid values of `name` for `groupby.transform(name)`
# NOTE: do NOT edit this directly. New additions should be inserted
# into the appropriate list above.
transform_kernel_whitelist = reduction_kernels | transformation_kernels
groupby_transform_whitelist = reduction_kernels
37 changes: 31 additions & 6 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,17 @@ def transform(self, func, *args, **kwargs):
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
if not (func in base.groupby_transform_whitelist):
msg = (
"`g.transform('func')` is used exclusively for "
"computing aggregations and broadcasting "
"the results across groups. "
"'{func}' is not a valid aggregation name. "
)

if func in base.transformation_kernels | base.groupby_other_methods:
msg += "Perhaps you should try <your grouper>.{func}() instead?"
msg = msg.format(func=func)
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
# cythonized transformation or canned "reduction+broadcast"
Expand Down Expand Up @@ -615,7 +624,11 @@ def _transform_fast(self, result, obj, func_nm):
ids, _, ngroup = self.grouper.group_info
output = []
for i, _ in enumerate(result.columns):
res = algorithms.take_1d(result.iloc[:, i].values, ids)
if func_nm in base.reduction_kernels:
# only broadcast results if we performed a reduction
res = algorithms.take_1d(result.iloc[:, i]._values, ids)
else:
res = result.iloc[:, i].values
if cast:
res = self._try_cast(res, obj.iloc[:, i])
output.append(res)
Expand Down Expand Up @@ -1014,8 +1027,17 @@ def transform(self, func, *args, **kwargs):
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
if not (func in base.groupby_transform_whitelist):
msg = (
"`g.transform('func')` is used exclusively for "
"computing aggregations and broadcasting "
"the results across groups. "
"'{func}' is not a valid aggregation name. "
)

if func in base.transformation_kernels | base.groupby_other_methods:
msg += "Perhaps you should try <your grouper>.{func}() instead?"
msg = msg.format(func=func)
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
# cythonized transform or canned "agg+broadcast"
Expand Down Expand Up @@ -1072,7 +1094,10 @@ def _transform_fast(self, func, func_nm):

ids, _, ngroup = self.grouper.group_info
cast = self._transform_should_cast(func_nm)
out = algorithms.take_1d(func()._values, ids)
out = func()
if func_nm in base.reduction_kernels:
# only broadcast results if we performed a reduction
out = algorithms.take_1d(out._values, ids)
if cast:
out = self._try_cast(out, self.obj)
return Series(out, index=self.obj.index, name=self.obj.name)
Expand Down
9 changes: 7 additions & 2 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ class providing the base-class of operations.

Parameters
----------
f : function
Function to apply to each group
func : callable or str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave this as f, otherwise this is an api change

Callable to apply to each group OR
name of an aggregation function.

Returns
-------
Expand All @@ -257,6 +258,10 @@ class providing the base-class of operations.
Each group is endowed the attribute 'name' in case you need to know
which group you are working on.

If `func` is the name of an aggregation, the resulting value for
each group is replicated along the row axis to produce an output
with the same shape as the input.

The current implementation imposes three requirements on f:

* f must return a value that either has the same shape as the input
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/groupby/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from pandas import DataFrame, MultiIndex
from pandas.core.groupby.base import reduction_kernels
from pandas.core.groupby.base import reduction_kernels, transformation_kernels
from pandas.util import testing as tm


Expand Down Expand Up @@ -105,6 +105,13 @@ def three_group():
)


@pytest.fixture(params=sorted(transformation_kernels))
def transformation_func(request):
"""yields the string names of all groupby reduction functions, one at a time.
"""
return request.param


@pytest.fixture(params=sorted(reduction_kernels))
def reduction_func(request):
"""yields the string names of all groupby reduction functions, one at a time.
Expand Down
35 changes: 24 additions & 11 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def test_cython_transform_series(op, args, targop):
# print(data.head())
expected = data.groupby(labels).transform(targop)

tm.assert_series_equal(expected, data.groupby(labels).transform(op, *args))
tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args))
tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args))


Expand Down Expand Up @@ -632,7 +632,7 @@ def test_cython_transform_series(op, args, targop):
)
def test_groupby_cum_skipna(op, skipna, input, exp):
df = pd.DataFrame(input)
result = df.groupby("key")["value"].transform(op, skipna=skipna)
result = getattr(df.groupby("key")["value"], op)(skipna=skipna)
if isinstance(exp, dict):
expected = exp[(op, skipna)]
else:
Expand Down Expand Up @@ -710,20 +710,17 @@ def test_cython_transform_frame(op, args, targop):
expected = gb.apply(targop)

expected = expected.sort_index(axis=1)
tm.assert_frame_equal(expected, gb.transform(op, *args).sort_index(axis=1))
tm.assert_frame_equal(expected, getattr(gb, op)(*args).sort_index(axis=1))
# individual columns
for c in df:
if c not in ["float", "int", "float_missing"] and op != "shift":
msg = "No numeric types to aggregate"
with pytest.raises(DataError, match=msg):
gb[c].transform(op)
with pytest.raises(DataError, match=msg):
getattr(gb[c], op)()
else:
expected = gb[c].apply(targop)
expected.name = c
tm.assert_series_equal(expected, gb[c].transform(op, *args))
tm.assert_series_equal(expected, getattr(gb[c], op)(*args))
tm.assert_series_equal(expected, getattr(gb[c], op)(*args))


Expand Down Expand Up @@ -765,7 +762,7 @@ def test_transform_with_non_scalar_group():
),
],
)
@pytest.mark.parametrize("agg_func", ["count", "rank", "size"])
@pytest.mark.parametrize("agg_func", ["count", "size"])
def test_transform_numeric_ret(cols, exp, comp_func, agg_func):
if agg_func == "size" and isinstance(cols, list):
pytest.xfail("'size' transformation not supported with NDFrameGroupy")
Expand Down Expand Up @@ -1007,17 +1004,19 @@ def test_transform_invalid_name_raises():
# GH#27486
df = DataFrame(dict(a=[0, 1, 1, 2]))
g = df.groupby(["a", "b", "b", "c"])
with pytest.raises(ValueError, match="not a valid function name"):
with pytest.raises(ValueError, match="exclusively"):
g.transform("some_arbitrary_name")

# method exists on the object, but is not a valid transformation/agg
# make sure the error suggests using the method directly.
assert hasattr(g, "aggregate") # make sure the method exists
with pytest.raises(ValueError, match="not a valid function name"):
with pytest.raises(ValueError, match="exclusively.+you should try"):
g.transform("aggregate")

# Test SeriesGroupBy
g = df["a"].groupby(["a", "b", "b", "c"])
with pytest.raises(ValueError, match="not a valid function name"):
ser = Series(range(4))
g = ser.groupby(["a", "b", "b", "c"])
with pytest.raises(ValueError, match="exclusively"):
g.transform("some_arbitrary_name")


Expand Down Expand Up @@ -1052,6 +1051,20 @@ def test_transform_agg_by_name(reduction_func, obj):
assert len(set(DataFrame(result).iloc[-3:, -1])) == 1


def test_transform_transformation_by_name(transformation_func):
"""Make sure g.transform('name') raises a helpful error for non-agg
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add the issue refences numbers as a comment

"""
func = transformation_func
obj = DataFrame(
dict(a=[0, 0, 0, 1, 1, 1], b=range(6)), index=["A", "B", "C", "D", "E", "F"]
)
g = obj.groupby(np.repeat([0, 1], 3))

match = "exclusively for.+you should try"
with pytest.raises(ValueError, match=match):
g.transform(func)


def test_transform_lambda_with_datetimetz():
# GH 27496
df = DataFrame(
Expand Down