Skip to content

Commit 562235d

Browse files
authored
CLN: groupby assorted (#41379)
1 parent 58accd7 commit 562235d

File tree

3 files changed

+32
-26
lines changed

3 files changed

+32
-26
lines changed

pandas/core/groupby/generic.py

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,6 @@
8989
MultiIndex,
9090
all_indexes_same,
9191
)
92-
import pandas.core.indexes.base as ibase
9392
from pandas.core.series import Series
9493
from pandas.core.util.numba_ import maybe_use_numba
9594

@@ -481,14 +480,13 @@ def _get_index() -> Index:
481480
if isinstance(values[0], dict):
482481
# GH #823 #24880
483482
index = _get_index()
484-
result: FrameOrSeriesUnion = self._reindex_output(
485-
self.obj._constructor_expanddim(values, index=index)
486-
)
483+
res_df = self.obj._constructor_expanddim(values, index=index)
484+
res_df = self._reindex_output(res_df)
487485
# if self.observed is False,
488486
# keep all-NaN rows created while re-indexing
489-
result = result.stack(dropna=self.observed)
490-
result.name = self._selection_name
491-
return result
487+
res_ser = res_df.stack(dropna=self.observed)
488+
res_ser.name = self._selection_name
489+
return res_ser
492490
elif isinstance(values[0], (Series, DataFrame)):
493491
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
494492
else:
@@ -1019,13 +1017,18 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
10191017

10201018
# grouper specific aggregations
10211019
if self.grouper.nkeys > 1:
1020+
# test_groupby_as_index_series_scalar gets here with 'not self.as_index'
10221021
return self._python_agg_general(func, *args, **kwargs)
10231022
elif args or kwargs:
1023+
# test_pass_args_kwargs gets here (with and without as_index)
1024+
# can't return early
10241025
result = self._aggregate_frame(func, *args, **kwargs)
10251026

10261027
elif self.axis == 1:
10271028
# _aggregate_multiple_funcs does not allow self.axis == 1
1029+
# Note: axis == 1 precludes 'not self.as_index', see __init__
10281030
result = self._aggregate_frame(func)
1031+
return result
10291032

10301033
else:
10311034

@@ -1055,7 +1058,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
10551058

10561059
if not self.as_index:
10571060
self._insert_inaxis_grouper_inplace(result)
1058-
result.index = np.arange(len(result))
1061+
result.index = Index(range(len(result)))
10591062

10601063
return result._convert(datetime=True)
10611064

@@ -1181,7 +1184,9 @@ def _wrap_applied_output(self, data, keys, values, not_indexed_same=False):
11811184
if self.as_index:
11821185
return self.obj._constructor_sliced(values, index=key_index)
11831186
else:
1184-
result = DataFrame(values, index=key_index, columns=[self._selection])
1187+
result = self.obj._constructor(
1188+
values, index=key_index, columns=[self._selection]
1189+
)
11851190
self._insert_inaxis_grouper_inplace(result)
11861191
return result
11871192
else:
@@ -1664,8 +1669,8 @@ def _wrap_transformed_output(
16641669

16651670
def _wrap_agged_manager(self, mgr: Manager2D) -> DataFrame:
16661671
if not self.as_index:
1667-
index = np.arange(mgr.shape[1])
1668-
mgr.set_axis(1, ibase.Index(index))
1672+
index = Index(range(mgr.shape[1]))
1673+
mgr.set_axis(1, index)
16691674
result = self.obj._constructor(mgr)
16701675

16711676
self._insert_inaxis_grouper_inplace(result)
@@ -1793,7 +1798,7 @@ def nunique(self, dropna: bool = True) -> DataFrame:
17931798
results.columns.names = obj.columns.names # TODO: do at higher level?
17941799

17951800
if not self.as_index:
1796-
results.index = ibase.default_index(len(results))
1801+
results.index = Index(range(len(results)))
17971802
self._insert_inaxis_grouper_inplace(results)
17981803

17991804
return results

pandas/core/groupby/ops.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -889,9 +889,8 @@ def codes_info(self) -> np.ndarray:
889889

890890
@final
891891
def _get_compressed_codes(self) -> tuple[np.ndarray, np.ndarray]:
892-
all_codes = self.codes
893-
if len(all_codes) > 1:
894-
group_index = get_group_index(all_codes, self.shape, sort=True, xnull=True)
892+
if len(self.groupings) > 1:
893+
group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
895894
return compress_group_index(group_index, sort=self.sort)
896895

897896
ping = self.groupings[0]
@@ -1111,6 +1110,7 @@ def groups(self):
11111110

11121111
@property
11131112
def nkeys(self) -> int:
1113+
# still matches len(self.groupings), but we can hard-code
11141114
return 1
11151115

11161116
def _get_grouper(self):

pandas/tests/groupby/test_groupby.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -234,17 +234,18 @@ def f(x, q=None, axis=0):
234234
tm.assert_series_equal(trans_result, trans_expected)
235235

236236
# DataFrame
237-
df_grouped = tsframe.groupby(lambda x: x.month)
238-
agg_result = df_grouped.agg(np.percentile, 80, axis=0)
239-
apply_result = df_grouped.apply(DataFrame.quantile, 0.8)
240-
expected = df_grouped.quantile(0.8)
241-
tm.assert_frame_equal(apply_result, expected, check_names=False)
242-
tm.assert_frame_equal(agg_result, expected)
243-
244-
agg_result = df_grouped.agg(f, q=80)
245-
apply_result = df_grouped.apply(DataFrame.quantile, q=0.8)
246-
tm.assert_frame_equal(agg_result, expected)
247-
tm.assert_frame_equal(apply_result, expected, check_names=False)
237+
for as_index in [True, False]:
238+
df_grouped = tsframe.groupby(lambda x: x.month, as_index=as_index)
239+
agg_result = df_grouped.agg(np.percentile, 80, axis=0)
240+
apply_result = df_grouped.apply(DataFrame.quantile, 0.8)
241+
expected = df_grouped.quantile(0.8)
242+
tm.assert_frame_equal(apply_result, expected, check_names=False)
243+
tm.assert_frame_equal(agg_result, expected)
244+
245+
agg_result = df_grouped.agg(f, q=80)
246+
apply_result = df_grouped.apply(DataFrame.quantile, q=0.8)
247+
tm.assert_frame_equal(agg_result, expected)
248+
tm.assert_frame_equal(apply_result, expected, check_names=False)
248249

249250

250251
def test_len():

0 commit comments

Comments
 (0)