Skip to content

Commit 31ec4e4

Browse files
committed
added support for selecting multiple nth values
1 parent 912b138 commit 31ec4e4

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

doc/source/groupby.rst

+11-2
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ This shows the first or last n rows from each group.
869869
Taking the nth row of each group
870870
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
871871

872-
To select from a DataFrame or Series the nth item, use the nth method. This is a reduction method, and will return a single row (or no row) per group:
872+
To select the nth item from a DataFrame or Series, use the nth method. This is a reduction method, and will return a single row (or no row) per group if you pass an int for n:
873873

874874
.. ipython:: python
875875
@@ -880,7 +880,7 @@ To select from a DataFrame or Series the nth item, use the nth method. This is a
880880
g.nth(-1)
881881
g.nth(1)
882882
883-
If you want to select the nth not-null method, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna, for a Series this just needs to be truthy.
883+
If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'``, just like you would pass to dropna; for a Series this just needs to be truthy.
884884

885885
.. ipython:: python
886886
@@ -904,6 +904,15 @@ As with other methods, passing ``as_index=False``, will achieve a filtration, wh
904904
g.nth(0)
905905
g.nth(-1)
906906
907+
You can also select multiple rows from each group by specifying multiple nth values as a list of ints.
908+
909+
.. ipython:: python
910+
911+
business_dates = date_range(start='4/1/2014', end='6/30/2014', freq='B')
912+
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
913+
# get the first, fourth, and last business day of each month
914+
df.groupby((df.index.year, df.index.month)).nth([0, 3, -1])
915+
907916
Enumerate group items
908917
~~~~~~~~~~~~~~~~~~~~~
909918

pandas/core/groupby.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -756,12 +756,21 @@ def ohlc(self):
756756

757757
def nth(self, n, dropna=None):
758758
"""
759-
Take the nth row from each group.
759+
Take the nth row from each group if n is an int, or a subset of rows
760+
if n is a list of ints.
760761
761-
If dropna, will not show nth non-null row, dropna is either
762+
If dropna, will take the nth non-null row, dropna is either
762763
Truthy (if a Series) or 'all', 'any' (if a DataFrame); this is equivalent
763764
to calling dropna(how=dropna) before the groupby.
764765
766+
Parameters
767+
----------
768+
n : int or list of ints
769+
a single nth value for the row or a list of nth values
770+
dropna : None or str, optional
771+
apply the specified dropna operation before counting which row is
772+
the nth row. Needs to be None, 'any' or 'all' for a DataFrame (any truthy value for a Series)
773+
765774
Examples
766775
--------
767776
>>> df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
@@ -789,19 +798,36 @@ def nth(self, n, dropna=None):
789798
5 NaN
790799
791800
"""
801+
if isinstance(n, int):
802+
nth_values = [n]
803+
elif isinstance(n, (set, list, tuple)):
804+
nth_values = list(set(n))
805+
if dropna is not None:
806+
raise ValueError("dropna option with a list of nth values is not supported")
807+
else:
808+
raise TypeError("n needs to be an int or a list/set/tuple of ints")
809+
810+
m = self.grouper._max_groupsize
811+
# filter out values that are outside [-m, m)
812+
pos_nth_values = [i for i in nth_values if i >= 0 and i < m]
813+
neg_nth_values = [i for i in nth_values if i < 0 and i >= -m]
792814

793815
self._set_selection_from_grouper()
794816
if not dropna: # good choice
795-
m = self.grouper._max_groupsize
796-
if n >= m or n < -m:
817+
if not pos_nth_values and not neg_nth_values:
818+
# no valid nth values
797819
return self._selected_obj.loc[[]]
820+
798821
rng = np.zeros(m, dtype=bool)
799-
if n >= 0:
800-
rng[n] = True
801-
is_nth = self._cumcount_array(rng)
802-
else:
803-
rng[- n - 1] = True
804-
is_nth = self._cumcount_array(rng, ascending=False)
822+
for i in pos_nth_values:
823+
rng[i] = True
824+
is_nth = self._cumcount_array(rng)
825+
826+
if neg_nth_values:
827+
rng = np.zeros(m, dtype=bool)
828+
for i in neg_nth_values:
829+
rng[- i - 1] = True
830+
is_nth |= self._cumcount_array(rng, ascending=False)
805831

806832
result = self._selected_obj[is_nth]
807833

pandas/tests/test_groupby.py

+24
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,30 @@ def test_nth(self):
312312
expected = g.B.first()
313313
assert_series_equal(result,expected)
314314

315+
# test multiple nth values
316+
df = DataFrame([[1, np.nan], [1, 3], [1, 4], [5, 6], [5, 7]],
317+
columns=['A', 'B'])
318+
g = df.groupby('A')
319+
320+
assert_frame_equal(g.nth(0), df.iloc[[0, 3]].set_index('A'))
321+
assert_frame_equal(g.nth([0]), df.iloc[[0, 3]].set_index('A'))
322+
assert_frame_equal(g.nth([0, 1]), df.iloc[[0, 1, 3, 4]].set_index('A'))
323+
assert_frame_equal(g.nth([0, -1]), df.iloc[[0, 2, 3, 4]].set_index('A'))
324+
assert_frame_equal(g.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
325+
assert_frame_equal(g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
326+
assert_frame_equal(g.nth([2]), df.iloc[[2]].set_index('A'))
327+
assert_frame_equal(g.nth([3, 4]), df.loc[[],['B']])
328+
329+
business_dates = pd.date_range(start='4/1/2014', end='6/30/2014', freq='B')
330+
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
331+
# get the first, fourth and last two business days for each month
332+
result = df.groupby((df.index.year, df.index.month)).nth([0, 3, -2, -1])
333+
expected_dates = pd.to_datetime(['2014/4/1', '2014/4/4', '2014/4/29', '2014/4/30',
334+
'2014/5/1', '2014/5/6', '2014/5/29', '2014/5/30',
335+
'2014/6/2', '2014/6/5', '2014/6/27', '2014/6/30'])
336+
expected = DataFrame(1, columns=['a', 'b'], index=expected_dates)
337+
assert_frame_equal(result, expected)
338+
315339
def test_grouper_index_types(self):
316340
# related GH5375
317341
# groupby misbehaving when using a Floatlike index

0 commit comments

Comments
 (0)