diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index f9f17a4f0dc31..272108ecfb45a 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -445,6 +445,7 @@ Groupby/resample/rolling - Bug in :class:`SeriesGroupBy` and :class:`DataFrameGroupBy` on an empty ``Series`` or ``DataFrame`` would lose index, columns, and/or data types when directly using the methods ``idxmax``, ``idxmin``, ``mad``, ``min``, ``max``, ``sum``, ``prod``, and ``skew`` or using them through ``apply``, ``aggregate``, or ``resample`` (:issue:`26411`) - Bug in :meth:`DataFrameGroupBy.sample` where error was raised when ``weights`` was specified and the index was an :class:`Int64Index` (:issue:`39927`) - Bug in :meth:`DataFrameGroupBy.aggregate` and :meth:`.Resampler.aggregate` would sometimes raise ``SpecificationError`` when passed a dictionary and columns were missing; will now always raise a ``KeyError`` instead (:issue:`40004`) +- Bug in :meth:`DataFrameGroupBy.sample` where column selection was not applied to sample result (:issue:`39928`) - Reshaping diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 7bcdb348b8a1e..59d6323ad8b8f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3083,11 +3083,12 @@ def sample( if random_state is not None: random_state = com.random_state(random_state) + group_iterator = self.grouper.get_iterator(self._selected_obj, self.axis) samples = [ obj.sample( n=n, frac=frac, replace=replace, weights=w, random_state=random_state ) - for (_, obj), w in zip(self, ws) + for (_, obj), w in zip(group_iterator, ws) ] return concat(samples, axis=self.axis) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 4b8b0173789ae..652a5fc1a3c34 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -132,3 +132,13 @@ def test_groupby_sample_with_weights(index, expected_index): result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) expected = Series(values, name="b", index=Index(expected_index)) tm.assert_series_equal(result, expected) + + +def test_groupby_sample_with_selections(): + # GH 39928 + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values, "c": values}) + + result = df.groupby("a")[["b", "c"]].sample(n=None, frac=None) + expected = DataFrame({"b": [1, 2], "c": [1, 2]}, index=result.index) + tm.assert_frame_equal(result, expected)