Commit fb0cf7b

dcherian and max-sixty authored
Another groupby.reduce bugfix. (#3403)
* Another groupby.reduce bugfix. Fixes #3402
* Add whats-new.
* Use is_scalar instead
* bugfix
* fix whats-new
* Update xarray/core/groupby.py

Co-Authored-By: Maximilian Roos <[email protected]>
1 parent 63cc857 commit fb0cf7b

File tree

4 files changed: +61 −35 lines


doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,10 @@ Bug fixes
 - Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
   By `Anderson Banihirwe <https://github.com/andersy005>`_.
 
+- Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and
+  :py:meth:`xarray.core.groupby.DatasetGroupBy.reduce` when reducing over multiple dimensions.
+  (:issue:`3402`). By `Deepak Cherian <https://github.com/dcherian/>`_
+
 
 Documentation
 ~~~~~~~~~~~~~
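For context, a minimal sketch of the behaviour this entry describes; the data and names below are illustrative (they mirror the test fixture added in this commit), not part of the docs:

import numpy as np
import xarray as xr

# Illustrative DataArray; dims mirror the fixture added in xarray/tests/test_groupby.py.
da = xr.DataArray(
    np.random.randn(3, 4, 2),
    coords={"x": ["a", "b", "c"], "y": [1, 2, 3, 4], "z": [1, 2]},
    dims=("x", "y", "z"),
    name="foo",
)

grouped = da.groupby("y", squeeze=False)

# Reducing over a list of dimensions previously failed ("cannot reduce over
# dimension ...") because the old check compared the whole list against
# self.dims instead of checking each element; with this fix it succeeds.
result = grouped.reduce(np.mean, ["x", "z"])
print(result.dims)  # ("y",)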

xarray/core/groupby.py

Lines changed: 16 additions & 11 deletions
@@ -15,13 +15,26 @@
 from .utils import (
     either_dict_or_kwargs,
     hashable,
+    is_scalar,
     maybe_wrap_array,
     peek_at,
     safe_cast_to_index,
 )
 from .variable import IndexVariable, Variable, as_variable
 
 
+def check_reduce_dims(reduce_dims, dimensions):
+
+    if reduce_dims is not ...:
+        if is_scalar(reduce_dims):
+            reduce_dims = [reduce_dims]
+        if any([dim not in dimensions for dim in reduce_dims]):
+            raise ValueError(
+                "cannot reduce over dimensions %r. expected either '...' to reduce over all dimensions or one or more of %r."
+                % (reduce_dims, dimensions)
+            )
+
+
 def unique_value_groups(ar, sort=True):
     """Group an array by its unique values.
 
@@ -794,15 +807,11 @@ def reduce(
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=False)
 
-        if dim is not ... and dim not in self.dims:
-            raise ValueError(
-                "cannot reduce over dimension %r. expected either '...' to reduce over all dimensions or one or more of %r."
-                % (dim, self.dims)
-            )
-
         def reduce_array(ar):
             return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)
 
+        check_reduce_dims(dim, self.dims)
+
         return self.apply(reduce_array, shortcut=shortcut)
 
 
@@ -895,11 +904,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
         def reduce_dataset(ds):
             return ds.reduce(func, dim, keep_attrs, **kwargs)
 
-        if dim is not ... and dim not in self.dims:
-            raise ValueError(
-                "cannot reduce over dimension %r. expected either '...' to reduce over all dimensions or one or more of %r."
-                % (dim, self.dims)
-            )
+        check_reduce_dims(dim, self.dims)
 
         return self.apply(reduce_dataset)
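
To show what the new validation does at runtime, here is a rough standalone sketch; check_reduce_dims is an internal helper (not public API), and the array below is illustrative rather than part of the patch:

import numpy as np
import xarray as xr
from xarray.core.groupby import check_reduce_dims  # helper added in this commit

arr = xr.DataArray(np.random.randn(3, 4, 2), dims=("x", "y", "z"))

check_reduce_dims(..., arr.dims)         # Ellipsis means "all dimensions": always accepted
check_reduce_dims("x", arr.dims)         # a scalar dim is wrapped in a list, then checked
check_reduce_dims(("x", "z"), arr.dims)  # every requested dim exists: no error

try:
    check_reduce_dims(("x", "huh"), arr.dims)
except ValueError as err:
    # "cannot reduce over dimensions ('x', 'huh'). expected either '...' to
    # reduce over all dimensions or one or more of ('x', 'y', 'z')."
    print(err)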

xarray/tests/test_dataarray.py

Lines changed: 0 additions & 9 deletions
@@ -2560,15 +2560,6 @@ def change_metadata(x):
         expected = change_metadata(expected)
         assert_equal(expected, actual)
 
-    def test_groupby_reduce_dimension_error(self):
-        array = self.make_groupby_example_array()
-        grouped = array.groupby("y")
-        with raises_regex(ValueError, "cannot reduce over dimension 'y'"):
-            grouped.mean()
-
-        grouped = array.groupby("y", squeeze=False)
-        assert_identical(array, grouped.mean())
-
     def test_groupby_math(self):
         array = self.make_groupby_example_array()
         for squeeze in [True, False]:

xarray/tests/test_groupby.py

Lines changed: 41 additions & 15 deletions
@@ -5,7 +5,23 @@
 import xarray as xr
 from xarray.core.groupby import _consolidate_slices
 
-from . import assert_identical, raises_regex
+from . import assert_allclose, assert_identical, raises_regex
+
+
+@pytest.fixture
+def dataset():
+    ds = xr.Dataset(
+        {"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))},
+        {"x": ["a", "b", "c"], "y": [1, 2, 3, 4], "z": [1, 2]},
+    )
+    ds["boo"] = (("z", "y"), [["f", "g", "h", "j"]] * 2)
+
+    return ds
+
+
+@pytest.fixture
+def array(dataset):
+    return dataset["foo"]
 
 
 def test_consolidate_slices():
@@ -21,25 +37,17 @@ def test_consolidate_slices():
         _consolidate_slices([slice(3), 4])
 
 
-def test_groupby_dims_property():
-    ds = xr.Dataset(
-        {"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))},
-        {"x": ["a", "bcd", "c"], "y": [1, 2, 3, 4], "z": [1, 2]},
-    )
+def test_groupby_dims_property(dataset):
+    assert dataset.groupby("x").dims == dataset.isel(x=1).dims
+    assert dataset.groupby("y").dims == dataset.isel(y=1).dims
 
-    assert ds.groupby("x").dims == ds.isel(x=1).dims
-    assert ds.groupby("y").dims == ds.isel(y=1).dims
-
-    stacked = ds.stack({"xy": ("x", "y")})
+    stacked = dataset.stack({"xy": ("x", "y")})
     assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims
 
 
-def test_multi_index_groupby_apply():
+def test_multi_index_groupby_apply(dataset):
     # regression test for GH873
-    ds = xr.Dataset(
-        {"foo": (("x", "y"), np.random.randn(3, 4))},
-        {"x": ["a", "b", "c"], "y": [1, 2, 3, 4]},
-    )
+    ds = dataset.isel(z=1, drop=True)[["foo"]]
     doubled = 2 * ds
     group_doubled = (
         ds.stack(space=["x", "y"])
@@ -276,6 +284,24 @@ def test_groupby_grouping_errors():
         dataset.to_array().groupby(dataset.foo * np.nan)
 
 
+def test_groupby_reduce_dimension_error(array):
+    grouped = array.groupby("y")
+    with raises_regex(ValueError, "cannot reduce over dimensions"):
+        grouped.mean()
+
+    with raises_regex(ValueError, "cannot reduce over dimensions"):
+        grouped.mean("huh")
+
+    with raises_regex(ValueError, "cannot reduce over dimensions"):
+        grouped.mean(("x", "y", "asd"))
+
+    grouped = array.groupby("y", squeeze=False)
+    assert_identical(array, grouped.mean())
+
+    assert_identical(array.mean("x"), grouped.reduce(np.mean, "x"))
+    assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"]))
+
+
 def test_groupby_bins_timeseries():
     ds = xr.Dataset()
     ds["time"] = xr.DataArray(
