diff --git a/doc/source/user_guide/groupby.rst b/doc/source/user_guide/groupby.rst index f2d83885df2d0..ba3fb17cc8764 100644 --- a/doc/source/user_guide/groupby.rst +++ b/doc/source/user_guide/groupby.rst @@ -761,7 +761,7 @@ different dtypes, then a common dtype will be determined in the same way as ``Da Transformation -------------- -The ``transform`` method returns an object that is indexed the same (same size) +The ``transform`` method returns an object that is indexed the same as the one being grouped. The transform function must: * Return a result that is either the same size as the group chunk or @@ -776,6 +776,14 @@ as the one being grouped. The transform function must: * (Optionally) operates on the entire group chunk. If this is supported, a fast path is used starting from the *second* chunk. +.. deprecated:: 1.5.0 + + When using ``.transform`` on a grouped DataFrame and the transformation function + returns a DataFrame, currently pandas does not align the result's index + with the input's index. This behavior is deprecated and alignment will + be performed in a future version of pandas. You can apply ``.to_numpy()`` to the + result of the transformation function to avoid alignment. + Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the transformation function. If the results from different groups have different dtypes, then a common dtype will be determined in the same way as ``DataFrame`` construction. diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 2b6e621df211e..55bfb044fb31d 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -691,6 +691,7 @@ Other Deprecations - Deprecated the ``closed`` argument in :class:`ArrowInterval` in favor of ``inclusive`` argument; In a future version passing ``closed`` will raise (:issue:`40245`) - Deprecated allowing ``unit="M"`` or ``unit="Y"`` in :class:`Timestamp` constructor with a non-round float value (:issue:`47267`) - Deprecated the ``display.column_space`` global configuration option (:issue:`7576`) +- Deprecated :meth:`DataFrameGroupBy.transform` not aligning the result when the UDF returned DataFrame (:issue:`45648`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index a469372d85967..38b93c6be60f8 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1196,14 +1196,33 @@ def _transform_general(self, func, *args, **kwargs): applied.append(res) # Compute and process with the remaining groups + emit_alignment_warning = False for name, group in gen: if group.size == 0: continue object.__setattr__(group, "name", name) res = path(group) + if ( + not emit_alignment_warning + and res.ndim == 2 + and not res.index.equals(group.index) + ): + emit_alignment_warning = True + res = _wrap_transform_general_frame(self.obj, group, res) applied.append(res) + if emit_alignment_warning: + # GH#45648 + warnings.warn( + "In a future version of pandas, returning a DataFrame in " + "groupby.transform will align with the input's index. Apply " + "`.to_numpy()` to the result in the transform function to keep " + "the current behavior and silence this warning.", + FutureWarning, + stacklevel=find_stack_level(), + ) + concat_index = obj.columns if self.axis == 0 else obj.index other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1 concatenated = concat(applied, axis=self.axis, verify_integrity=False) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 81f4018ef8fc6..c294082edce71 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -373,14 +373,14 @@ class providing the base-class of operations. """ _transform_template = """ -Call function producing a like-indexed %(klass)s on each group and +Call function producing a same-indexed %(klass)s on each group and return a %(klass)s having the same indexes as the original object filled with the transformed values. Parameters ---------- f : function - Function to apply to each group. + Function to apply to each group. See the Notes section below for requirements. Can also accept a Numba JIT function with ``engine='numba'`` specified. @@ -451,6 +451,14 @@ class providing the base-class of operations. The resulting dtype will reflect the return value of the passed ``func``, see the examples below. +.. deprecated:: 1.5.0 + + When using ``.transform`` on a grouped DataFrame and the transformation function + returns a DataFrame, currently pandas does not align the result's index + with the input's index. This behavior is deprecated and alignment will + be performed in a future version of pandas. You can apply ``.to_numpy()`` to the + result of the transformation function to avoid alignment. + Examples -------- diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index b325edaf2b1ea..5c64ba3d9e266 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -1531,3 +1531,42 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func) result = gb.transform(transformation_func, *args) tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "func, series, expected_values", + [ + (Series.sort_values, False, [4, 5, 3, 1, 2]), + (lambda x: x.head(1), False, ValueError), + # SeriesGroupBy already has correct behavior + (Series.sort_values, True, [5, 4, 3, 2, 1]), + (lambda x: x.head(1), True, [5.0, np.nan, 3.0, 2.0, np.nan]), + ], +) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +@pytest.mark.parametrize("keys_in_index", [True, False]) +def test_transform_aligns_depr(func, series, expected_values, keys, keys_in_index): + # GH#45648 - transform should align with the input's index + df = DataFrame({"a1": [1, 1, 3, 2, 2], "b": [5, 4, 3, 2, 1]}) + if "a2" in keys: + df["a2"] = df["a1"] + if keys_in_index: + df = df.set_index(keys, append=True) + + gb = df.groupby(keys) + if series: + gb = gb["b"] + + warn = None if series else FutureWarning + msg = "returning a DataFrame in groupby.transform will align" + if expected_values is ValueError: + with tm.assert_produces_warning(warn, match=msg): + with pytest.raises(ValueError, match="Length mismatch"): + gb.transform(func) + else: + with tm.assert_produces_warning(warn, match=msg): + result = gb.transform(func) + expected = DataFrame({"b": expected_values}, index=df.index) + if series: + expected = expected["b"] + tm.assert_equal(result, expected)