diff --git a/doc/source/whatsnew/v1.1.5.rst b/doc/source/whatsnew/v1.1.5.rst index cf728d94b2a55..a122154904996 100644 --- a/doc/source/whatsnew/v1.1.5.rst +++ b/doc/source/whatsnew/v1.1.5.rst @@ -23,7 +23,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ -- +- Bug in metadata propagation for ``groupby`` iterator (:issue:`37343`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 3aaeef3b63760..5ea4f0cbb6a0c 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -131,9 +131,16 @@ def get_iterator(self, data: FrameOrSeries, axis: int = 0): splitter = self._get_splitter(data, axis=axis) keys = self._get_group_keys() for key, (i, group) in zip(keys, splitter): - yield key, group + yield key, group.__finalize__(data, method="groupby") def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter": + """ + Returns + ------- + Generator yielding subsetted objects + + __finalize__ has not been called for the the subsetted objects returned. + """ comp_ids, _, ngroups = self.group_info return get_splitter(data, comp_ids, ngroups, axis=axis) @@ -955,7 +962,8 @@ class SeriesSplitter(DataSplitter): def _chop(self, sdata: Series, slice_obj: slice) -> Series: # fastpath equivalent to `sdata.iloc[slice_obj]` mgr = sdata._mgr.get_slice(slice_obj) - return type(sdata)(mgr, name=sdata.name, fastpath=True) + # __finalize__ not called here, must be applied by caller if applicable + return sdata._constructor(mgr, name=sdata.name, fastpath=True) class FrameSplitter(DataSplitter): @@ -971,7 +979,8 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: # else: # return sdata.iloc[:, slice_obj] mgr = sdata._mgr.get_slice(slice_obj, axis=1 - self.axis) - return type(sdata)(mgr) + # __finalize__ not called here, must be applied by caller if applicable + return sdata._constructor(mgr) def get_splitter( diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index 7271911c5f80f..a97f4d95a677d 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -51,6 +51,15 @@ def test_groupby_preserves_subclass(obj, groupby_func): tm.assert_series_equal(result1, result2) +def test_groupby_preserves_metadata(): + # GH-37343 + custom_df = tm.SubclassedDataFrame({"a": [1, 2, 3], "b": [1, 1, 2], "c": [7, 8, 9]}) + assert "testattr" in custom_df._metadata + custom_df.testattr = "hello" + for _, group_df in custom_df.groupby("c"): + assert group_df.testattr == "hello" + + @pytest.mark.parametrize( "obj", [DataFrame, tm.SubclassedDataFrame], )