Skip to content

Commit 6631407

Browse files
committed
Merge pull request #7910 from mortada/nth_values
added support for selecting multiple nth values
2 parents 09a2415 + 31ec4e4 commit 6631407

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

doc/source/groupby.rst

+11-2
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ This shows the first or last n rows from each group.
869869
Taking the nth row of each group
870870
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
871871

872-
To select from a DataFrame or Series the nth item, use the nth method. This is a reduction method, and will return a single row (or no row) per group:
872+
To select the nth item from a DataFrame or Series, use the ``nth`` method. This is a reduction method, and will return a single row (or no row) per group if you pass an int for n:
873873

874874
.. ipython:: python
875875
@@ -880,7 +880,7 @@ To select from a DataFrame or Series the nth item, use the nth method. This is a
880880
g.nth(-1)
881881
g.nth(1)
882882
883-
If you want to select the nth not-null method, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna, for a Series this just needs to be truthy.
883+
If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'``, just like you would pass to dropna; for a Series this just needs to be truthy.
884884

885885
.. ipython:: python
886886
@@ -904,6 +904,15 @@ As with other methods, passing ``as_index=False``, will achieve a filtration, wh
904904
g.nth(0)
905905
g.nth(-1)
906906
907+
You can also select multiple rows from each group by specifying multiple nth values as a list of ints.
908+
909+
.. ipython:: python
910+
911+
business_dates = date_range(start='4/1/2014', end='6/30/2014', freq='B')
912+
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
913+
# get the first, fourth, and last business day of each month
914+
df.groupby((df.index.year, df.index.month)).nth([0, 3, -1])
915+
907916
Enumerate group items
908917
~~~~~~~~~~~~~~~~~~~~~
909918

pandas/core/groupby.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -782,12 +782,21 @@ def ohlc(self):
782782

783783
def nth(self, n, dropna=None):
784784
"""
785-
Take the nth row from each group.
785+
Take the nth row from each group if n is an int, or a subset of rows
786+
if n is a list of ints.
786787
787-
If dropna, will not show nth non-null row, dropna is either
788+
If dropna, will take the nth non-null row; dropna is either
788789
Truthy (if a Series) or 'all', 'any' (if a DataFrame); this is equivalent
789790
to calling dropna(how=dropna) before the groupby.
790791
792+
Parameters
793+
----------
794+
n : int or list of ints
795+
a single nth value for the row or a list of nth values
796+
dropna : None or str, optional
797+
apply the specified dropna operation before counting which row is
798+
the nth row. Needs to be None, 'any' or 'all' (any truthy value for a Series)
799+
791800
Examples
792801
--------
793802
>>> df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
@@ -815,19 +824,36 @@ def nth(self, n, dropna=None):
815824
5 NaN
816825
817826
"""
827+
if isinstance(n, int):
828+
nth_values = [n]
829+
elif isinstance(n, (set, list, tuple)):
830+
nth_values = list(set(n))
831+
if dropna is not None:
832+
raise ValueError("dropna option with a list of nth values is not supported")
833+
else:
834+
raise TypeError("n needs to be an int or a list/set/tuple of ints")
835+
836+
m = self.grouper._max_groupsize
837+
# filter out values that are outside [-m, m)
838+
pos_nth_values = [i for i in nth_values if i >= 0 and i < m]
839+
neg_nth_values = [i for i in nth_values if i < 0 and i >= -m]
818840

819841
self._set_selection_from_grouper()
820842
if not dropna: # good choice
821-
m = self.grouper._max_groupsize
822-
if n >= m or n < -m:
843+
if not pos_nth_values and not neg_nth_values:
844+
# no valid nth values
823845
return self._selected_obj.loc[[]]
846+
824847
rng = np.zeros(m, dtype=bool)
825-
if n >= 0:
826-
rng[n] = True
827-
is_nth = self._cumcount_array(rng)
828-
else:
829-
rng[- n - 1] = True
830-
is_nth = self._cumcount_array(rng, ascending=False)
848+
for i in pos_nth_values:
849+
rng[i] = True
850+
is_nth = self._cumcount_array(rng)
851+
852+
if neg_nth_values:
853+
rng = np.zeros(m, dtype=bool)
854+
for i in neg_nth_values:
855+
rng[- i - 1] = True
856+
is_nth |= self._cumcount_array(rng, ascending=False)
831857

832858
result = self._selected_obj[is_nth]
833859

pandas/tests/test_groupby.py

+24
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,30 @@ def test_nth(self):
313313
expected = g.B.first()
314314
assert_series_equal(result,expected)
315315

316+
# test multiple nth values
317+
df = DataFrame([[1, np.nan], [1, 3], [1, 4], [5, 6], [5, 7]],
318+
columns=['A', 'B'])
319+
g = df.groupby('A')
320+
321+
assert_frame_equal(g.nth(0), df.iloc[[0, 3]].set_index('A'))
322+
assert_frame_equal(g.nth([0]), df.iloc[[0, 3]].set_index('A'))
323+
assert_frame_equal(g.nth([0, 1]), df.iloc[[0, 1, 3, 4]].set_index('A'))
324+
assert_frame_equal(g.nth([0, -1]), df.iloc[[0, 2, 3, 4]].set_index('A'))
325+
assert_frame_equal(g.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
326+
assert_frame_equal(g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
327+
assert_frame_equal(g.nth([2]), df.iloc[[2]].set_index('A'))
328+
assert_frame_equal(g.nth([3, 4]), df.loc[[],['B']])
329+
330+
business_dates = pd.date_range(start='4/1/2014', end='6/30/2014', freq='B')
331+
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
332+
# get the first, fourth and last two business days for each month
333+
result = df.groupby((df.index.year, df.index.month)).nth([0, 3, -2, -1])
334+
expected_dates = pd.to_datetime(['2014/4/1', '2014/4/4', '2014/4/29', '2014/4/30',
335+
'2014/5/1', '2014/5/6', '2014/5/29', '2014/5/30',
336+
'2014/6/2', '2014/6/5', '2014/6/27', '2014/6/30'])
337+
expected = DataFrame(1, columns=['a', 'b'], index=expected_dates)
338+
assert_frame_equal(result, expected)
339+
316340
def test_grouper_index_types(self):
317341
# related GH5375
318342
# groupby misbehaving when using a Floatlike index

0 commit comments

Comments
 (0)