Skip to content

Commit 65a5bff

Browse files
ej81dcherian
andauthored
Fix recombination in groupby when changing size along the grouped dimension (#3807)
* Fix recombination in groupby when changing size along the grouped dimension * cleanup tests * minor test rename * minor fix Co-authored-by: dcherian <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent cafab46 commit 65a5bff

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ Bug fixes
5757
- Fix :py:meth:`Dataset.interp` when indexing array shares coordinates with the
5858
indexed variable (:issue:`3252`).
5959
By `David Huard <https://github.com/huard>`_.
60-
61-
60+
- Fix recombination of groups in :py:meth:`Dataset.groupby` and
61+
:py:meth:`DataArray.groupby` when performing an operation that changes the
62+
size of the groups along the grouped dimension. By `Eric Jansen
63+
<https://github.com/ej81>`_.
6264
- Fix use of multi-index with categorical values (:issue:`3674`).
6365
By `Matthieu Ancellin <https://github.com/mancellin>`_.
6466
- Fix alignment with ``join="override"`` when some dimensions are unindexed. (:issue:`3681`).

xarray/core/groupby.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def assign_coords(self, coords=None, **coords_kwargs):
720720
def _maybe_reorder(xarray_obj, dim, positions):
721721
order = _inverse_permutation_indices(positions)
722722

723-
if order is None:
723+
if order is None or len(order) != xarray_obj.sizes[dim]:
724724
return xarray_obj
725725
else:
726726
return xarray_obj[{dim: order}]
@@ -838,7 +838,8 @@ def _combine(self, applied, restore_coord_dims=False, shortcut=False):
838838
if isinstance(combined, type(self._obj)):
839839
# only restore dimension order for arrays
840840
combined = self._restore_dim_order(combined)
841-
if coord is not None:
841+
# assign coord when the applied function does not return that coord
842+
if coord is not None and dim not in applied_example.dims:
842843
if shortcut:
843844
coord_var = as_variable(coord)
844845
combined._coords[coord.name] = coord_var
@@ -954,7 +955,8 @@ def _combine(self, applied):
954955
coord, dim, positions = self._infer_concat_args(applied_example)
955956
combined = concat(applied, dim)
956957
combined = _maybe_reorder(combined, dim, positions)
957-
if coord is not None:
958+
# assign coord when the applied function does not return that coord
959+
if coord is not None and dim not in applied_example.dims:
958960
combined[coord.name] = coord
959961
combined = self._maybe_restore_empty_groups(combined)
960962
combined = self._maybe_unstack(combined)

xarray/tests/test_groupby.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,39 @@ def test_groupby_input_mutation():
107107
assert_identical(array, array_copy) # should not modify inputs
108108

109109

110+
@pytest.mark.parametrize(
111+
"obj",
112+
[
113+
xr.DataArray([1, 2, 3, 4, 5, 6], [("x", [1, 1, 1, 2, 2, 2])]),
114+
xr.Dataset({"foo": ("x", [1, 2, 3, 4, 5, 6])}, {"x": [1, 1, 1, 2, 2, 2]}),
115+
],
116+
)
117+
def test_groupby_map_shrink_groups(obj):
118+
expected = obj.isel(x=[0, 1, 3, 4])
119+
actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1]))
120+
assert_identical(expected, actual)
121+
122+
123+
@pytest.mark.parametrize(
124+
"obj",
125+
[
126+
xr.DataArray([1, 2, 3], [("x", [1, 2, 2])]),
127+
xr.Dataset({"foo": ("x", [1, 2, 3])}, {"x": [1, 2, 2]}),
128+
],
129+
)
130+
def test_groupby_map_change_group_size(obj):
131+
def func(group):
132+
if group.sizes["x"] == 1:
133+
result = group.isel(x=[0, 0])
134+
else:
135+
result = group.isel(x=[0])
136+
return result
137+
138+
expected = obj.isel(x=[0, 0, 1])
139+
actual = obj.groupby("x").map(func)
140+
assert_identical(expected, actual)
141+
142+
110143
def test_da_groupby_map_func_args():
111144
def func(arg1, arg2, arg3=0):
112145
return arg1 + arg2 + arg3

0 commit comments

Comments
 (0)