diff --git a/doc/source/groupby.rst b/doc/source/groupby.rst index eaccbfddc1f86..fb1004edca785 100644 --- a/doc/source/groupby.rst +++ b/doc/source/groupby.rst @@ -869,7 +869,7 @@ This shows the first or last n rows from each group. Taking the nth row of each group ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To select from a DataFrame or Series the nth item, use the nth method. This is a reduction method, and will return a single row (or no row) per group: +To select from a DataFrame or Series the nth item, use the nth method. This is a reduction method, and will return a single row (or no row) per group if you pass an int for n: .. ipython:: python @@ -880,7 +880,7 @@ To select from a DataFrame or Series the nth item, use the nth method. This is a g.nth(-1) g.nth(1) -If you want to select the nth not-null method, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna, for a Series this just needs to be truthy. +If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna, for a Series this just needs to be truthy. .. ipython:: python @@ -904,6 +904,15 @@ As with other methods, passing ``as_index=False``, will achieve a filtration, wh g.nth(0) g.nth(-1) +You can also select multiple rows from each group by specifying multiple nth values as a list of ints. + +.. ipython:: python + + business_dates = date_range(start='4/1/2014', end='6/30/2014', freq='B') + df = DataFrame(1, index=business_dates, columns=['a', 'b']) + # get the first, 4th, and last date index for each month + df.groupby((df.index.year, df.index.month)).nth([0, 3, -1]) + Enumerate group items ~~~~~~~~~~~~~~~~~~~~~ diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 8cfa0e25b789f..18a16b3262236 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -756,12 +756,21 @@ def ohlc(self): def nth(self, n, dropna=None): """ - Take the nth row from each group. + Take the nth row from each group if n is an int, or a subset of rows + if n is a list of ints. - If dropna, will not show nth non-null row, dropna is either + If dropna, will take the nth non-null row, dropna is either Truthy (if a Series) or 'all', 'any' (if a DataFrame); this is equivalent to calling dropna(how=dropna) before the groupby. + Parameters + ---------- + n : int or list of ints + a single nth value for the row or a list of nth values + dropna : None or str, optional + apply the specified dropna operation before counting which row is + the nth row. Needs to be None, 'any' or 'all' + Examples -------- >>> df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B']) @@ -789,19 +798,36 @@ def nth(self, n, dropna=None): 5 NaN """ + if isinstance(n, int): + nth_values = [n] + elif isinstance(n, (set, list, tuple)): + nth_values = list(set(n)) + if dropna is not None: + raise ValueError("dropna option with a list of nth values is not supported") + else: + raise TypeError("n needs to be an int or a list/set/tuple of ints") + + m = self.grouper._max_groupsize + # filter out values that are outside [-m, m) + pos_nth_values = [i for i in nth_values if i >= 0 and i < m] + neg_nth_values = [i for i in nth_values if i < 0 and i >= -m] self._set_selection_from_grouper() if not dropna: # good choice - m = self.grouper._max_groupsize - if n >= m or n < -m: + if not pos_nth_values and not neg_nth_values: + # no valid nth values return self._selected_obj.loc[[]] + rng = np.zeros(m, dtype=bool) - if n >= 0: - rng[n] = True - is_nth = self._cumcount_array(rng) - else: - rng[- n - 1] = True - is_nth = self._cumcount_array(rng, ascending=False) + for i in pos_nth_values: + rng[i] = True + is_nth = self._cumcount_array(rng) + + if neg_nth_values: + rng = np.zeros(m, dtype=bool) + for i in neg_nth_values: + rng[- i - 1] = True + is_nth |= self._cumcount_array(rng, ascending=False) result = self._selected_obj[is_nth] diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index f958d5481ad33..4c9caecfb99ed 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -312,6 +312,30 @@ def test_nth(self): expected = g.B.first() assert_series_equal(result,expected) + # test multiple nth values + df = DataFrame([[1, np.nan], [1, 3], [1, 4], [5, 6], [5, 7]], + columns=['A', 'B']) + g = df.groupby('A') + + assert_frame_equal(g.nth(0), df.iloc[[0, 3]].set_index('A')) + assert_frame_equal(g.nth([0]), df.iloc[[0, 3]].set_index('A')) + assert_frame_equal(g.nth([0, 1]), df.iloc[[0, 1, 3, 4]].set_index('A')) + assert_frame_equal(g.nth([0, -1]), df.iloc[[0, 2, 3, 4]].set_index('A')) + assert_frame_equal(g.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]].set_index('A')) + assert_frame_equal(g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]].set_index('A')) + assert_frame_equal(g.nth([2]), df.iloc[[2]].set_index('A')) + assert_frame_equal(g.nth([3, 4]), df.loc[[],['B']]) + + business_dates = pd.date_range(start='4/1/2014', end='6/30/2014', freq='B') + df = DataFrame(1, index=business_dates, columns=['a', 'b']) + # get the first, fourth and last two business days for each month + result = df.groupby((df.index.year, df.index.month)).nth([0, 3, -2, -1]) + expected_dates = pd.to_datetime(['2014/4/1', '2014/4/4', '2014/4/29', '2014/4/30', + '2014/5/1', '2014/5/6', '2014/5/29', '2014/5/30', + '2014/6/2', '2014/6/5', '2014/6/27', '2014/6/30']) + expected = DataFrame(1, columns=['a', 'b'], index=expected_dates) + assert_frame_equal(result, expected) + def test_grouper_index_types(self): # related GH5375 # groupby misbehaving when using a Floatlike index