Skip to content

Commit 291cb80

Browse files
dcherian
authored and
Joe Hamman
committed
Add groupby.dims & Fix groupby reduce for DataArray (#3338)
* Fix groupby reduce for DataArray * bugfix. * another bugfix. * bugfix unique_and_monotonic for object indexes (uniqueness is enough) * Add groupby.dims property. * update reduce docstring to point to xarray.ALL_DIMS * test for object index dims. * test reduce dimensions error. * Add whats-new * fix docs build * sq whats-new * one more test. * fix test. * undo monotonic change. * Add dimensions to repr. * Raise error if no bins. * Raise nice error if no groups were formed. * Some more error raising and testing. * Add dataset tests. * update whats-new. * fix tests. * make dims a cached lazy property. * fix whats-new. * whitespace * fix whats-new
1 parent 3f0049f commit 291cb80

File tree

5 files changed

+105
-15
lines changed

5 files changed

+105
-15
lines changed

doc/groupby.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,4 @@ applying your function, and then unstacking the result:
213213
.. ipython:: python
214214
215215
stacked = da.stack(gridcell=['ny', 'nx'])
216-
stacked.groupby('gridcell').sum().unstack('gridcell')
216+
stacked.groupby('gridcell').sum(xr.ALL_DIMS).unstack('gridcell')

doc/whats-new.rst

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,17 @@ New functions/methods
5252
Enhancements
5353
~~~~~~~~~~~~
5454

55-
- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
56-
Example::
55+
- :py:class:`~xarray.core.GroupBy` enhancements. By `Deepak Cherian <https://github.com/dcherian>`_.
56+
57+
- Added a repr. Example::
5758

5859
>>> da.groupby("time.season")
5960
DataArrayGroupBy, grouped over 'season'
6061
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'
6162

62-
(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
63+
- Added a ``GroupBy.dims`` property that mirrors the dimensions
64+
of each group. (:issue:`3344`)
65+
6366
- Speed up :meth:`Dataset.isel` up to 33% and :meth:`DataArray.isel` up to 25% for small
6467
arrays (:issue:`2799`, :pull:`3375`) by
6568
`Guido Imperiale <https://github.com/crusaderky>`_.
@@ -73,6 +76,12 @@ Bug fixes
7376
- Line plots with the ``x`` or ``y`` argument set to a 1D non-dimensional coord
7477
now plot the correct data for 2D DataArrays
7578
(:issue:`3334`). By `Tom Nicholas <http://github.com/TomNicholas>`_.
79+
- The default behaviour of reducing across all dimensions for
80+
:py:class:`~xarray.core.groupby.DataArrayGroupBy` objects has now been properly removed
81+
as was done for :py:class:`~xarray.core.groupby.DatasetGroupBy` in 0.13.0 (:issue:`3337`).
82+
Use `xarray.ALL_DIMS` if you need to replicate previous behaviour.
83+
Also raise nicer error message when no groups are created (:issue:`1764`).
84+
By `Deepak Cherian <https://github.com/dcherian>`_.
7685
- Fix error in concatenating unlabeled dimensions (:pull:`3362`).
7786
By `Deepak Cherian <https://github.com/dcherian/>`_.
7887

xarray/core/groupby.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import dtypes, duck_array_ops, nputils, ops
99
from .arithmetic import SupportsArithmetic
10-
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
10+
from .common import ALL_DIMS, ImplementsArrayReduce, ImplementsDatasetReduce
1111
from .concat import concat
1212
from .formatting import format_array_flat
1313
from .options import _get_keep_attrs
@@ -248,6 +248,7 @@ class GroupBy(SupportsArithmetic):
248248
"_restore_coord_dims",
249249
"_stacked_dim",
250250
"_unique_coord",
251+
"_dims",
251252
)
252253

253254
def __init__(
@@ -320,6 +321,8 @@ def __init__(
320321
full_index = None
321322

322323
if bins is not None:
324+
if np.isnan(bins).all():
325+
raise ValueError("All bin edges are NaN.")
323326
binned = pd.cut(group.values, bins, **cut_kwargs)
324327
new_dim_name = group.name + "_bins"
325328
group = DataArray(binned, group.coords, name=new_dim_name)
@@ -351,6 +354,16 @@ def __init__(
351354
)
352355
unique_coord = IndexVariable(group.name, unique_values)
353356

357+
if len(group_indices) == 0:
358+
if bins is not None:
359+
raise ValueError(
360+
"None of the data falls within bins with edges %r" % bins
361+
)
362+
else:
363+
raise ValueError(
364+
"Failed to group data. Are you grouping by a variable that is all NaN?"
365+
)
366+
354367
if (
355368
isinstance(obj, DataArray)
356369
and restore_coord_dims is None
@@ -379,6 +392,16 @@ def __init__(
379392

380393
# cached attributes
381394
self._groups = None
395+
self._dims = None
396+
397+
@property
398+
def dims(self):
399+
if self._dims is None:
400+
self._dims = self._obj.isel(
401+
**{self._group_dim: self._group_indices[0]}
402+
).dims
403+
404+
return self._dims
382405

383406
@property
384407
def groups(self):
@@ -394,7 +417,7 @@ def __iter__(self):
394417
return zip(self._unique_coord.values, self._iter_grouped())
395418

396419
def __repr__(self):
397-
return "%s, grouped over %r \n%r groups with labels %s" % (
420+
return "%s, grouped over %r \n%r groups with labels %s." % (
398421
self.__class__.__name__,
399422
self._unique_coord.name,
400423
self._unique_coord.size,
@@ -689,7 +712,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
689712
q : float in range of [0,1] (or sequence of floats)
690713
Quantile to compute, which must be between 0 and 1
691714
inclusive.
692-
dim : str or sequence of str, optional
715+
dim : xarray.ALL_DIMS, str or sequence of str, optional
693716
Dimension(s) over which to apply quantile.
694717
Defaults to the grouped dimension.
695718
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
@@ -746,7 +769,7 @@ def reduce(
746769
Function which can be called in the form
747770
`func(x, axis=axis, **kwargs)` to return the result of collapsing
748771
an np.ndarray over an integer valued axis.
749-
dim : str or sequence of str, optional
772+
dim : xarray.ALL_DIMS, str or sequence of str, optional
750773
Dimension(s) over which to apply `func`.
751774
axis : int or sequence of int, optional
752775
Axis(es) over which to apply `func`. Only one of the 'dimension'
@@ -765,9 +788,18 @@ def reduce(
765788
Array with summarized data and the indicated dimension(s)
766789
removed.
767790
"""
791+
if dim is None:
792+
dim = self._group_dim
793+
768794
if keep_attrs is None:
769795
keep_attrs = _get_keep_attrs(default=False)
770796

797+
if dim is not ALL_DIMS and dim not in self.dims:
798+
raise ValueError(
799+
"cannot reduce over dimension %r. expected either xarray.ALL_DIMS to reduce over all dimensions or one or more of %r."
800+
% (dim, self.dims)
801+
)
802+
771803
def reduce_array(ar):
772804
return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)
773805

@@ -835,7 +867,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
835867
Function which can be called in the form
836868
`func(x, axis=axis, **kwargs)` to return the result of collapsing
837869
an np.ndarray over an integer valued axis.
838-
dim : str or sequence of str, optional
870+
dim : xarray.ALL_DIMS, str or sequence of str, optional
839871
Dimension(s) over which to apply `func`.
840872
axis : int or sequence of int, optional
841873
Axis(es) over which to apply `func`. Only one of the 'dimension'
@@ -863,6 +895,12 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
863895
def reduce_dataset(ds):
864896
return ds.reduce(func, dim, keep_attrs, **kwargs)
865897

898+
if dim is not ALL_DIMS and dim not in self.dims:
899+
raise ValueError(
900+
"cannot reduce over dimension %r. expected either xarray.ALL_DIMS to reduce over all dimensions or one or more of %r."
901+
% (dim, self.dims)
902+
)
903+
866904
return self.apply(reduce_dataset)
867905

868906
def assign(self, **kwargs):

xarray/tests/test_dataarray.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,15 @@ def change_metadata(x):
25792579
expected = change_metadata(expected)
25802580
assert_equal(expected, actual)
25812581

2582+
def test_groupby_reduce_dimension_error(self):
2583+
array = self.make_groupby_example_array()
2584+
grouped = array.groupby("y")
2585+
with raises_regex(ValueError, "cannot reduce over dimension 'y'"):
2586+
grouped.mean()
2587+
2588+
grouped = array.groupby("y", squeeze=False)
2589+
assert_identical(array, grouped.mean())
2590+
25822591
def test_groupby_math(self):
25832592
array = self.make_groupby_example_array()
25842593
for squeeze in [True, False]:

xarray/tests/test_groupby.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import xarray as xr
66
from xarray.core.groupby import _consolidate_slices
77

8-
from . import assert_identical
8+
from . import assert_identical, raises_regex
99

1010

1111
def test_consolidate_slices():
@@ -21,6 +21,19 @@ def test_consolidate_slices():
2121
_consolidate_slices([slice(3), 4])
2222

2323

24+
def test_groupby_dims_property():
25+
ds = xr.Dataset(
26+
{"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))},
27+
{"x": ["a", "bcd", "c"], "y": [1, 2, 3, 4], "z": [1, 2]},
28+
)
29+
30+
assert ds.groupby("x").dims == ds.isel(x=1).dims
31+
assert ds.groupby("y").dims == ds.isel(y=1).dims
32+
33+
stacked = ds.stack({"xy": ("x", "y")})
34+
assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims
35+
36+
2437
def test_multi_index_groupby_apply():
2538
# regression test for GH873
2639
ds = xr.Dataset(
@@ -222,13 +235,13 @@ def test_groupby_repr(obj, dim):
222235
expected += ", grouped over %r " % dim
223236
expected += "\n%r groups with labels " % (len(np.unique(obj[dim])))
224237
if dim == "x":
225-
expected += "1, 2, 3, 4, 5"
238+
expected += "1, 2, 3, 4, 5."
226239
elif dim == "y":
227-
expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19"
240+
expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19."
228241
elif dim == "z":
229-
expected += "'a', 'b', 'c'"
242+
expected += "'a', 'b', 'c'."
230243
elif dim == "month":
231-
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12"
244+
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12."
232245
assert actual == expected
233246

234247

@@ -238,8 +251,29 @@ def test_groupby_repr_datetime(obj):
238251
expected = "%sGroupBy" % obj.__class__.__name__
239252
expected += ", grouped over 'month' "
240253
expected += "\n%r groups with labels " % (len(np.unique(obj.t.dt.month)))
241-
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12"
254+
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12."
242255
assert actual == expected
243256

244257

258+
def test_groupby_grouping_errors():
259+
dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
260+
with raises_regex(ValueError, "None of the data falls within bins with edges"):
261+
dataset.groupby_bins("x", bins=[0.1, 0.2, 0.3])
262+
263+
with raises_regex(ValueError, "None of the data falls within bins with edges"):
264+
dataset.to_array().groupby_bins("x", bins=[0.1, 0.2, 0.3])
265+
266+
with raises_regex(ValueError, "All bin edges are NaN."):
267+
dataset.groupby_bins("x", bins=[np.nan, np.nan, np.nan])
268+
269+
with raises_regex(ValueError, "All bin edges are NaN."):
270+
dataset.to_array().groupby_bins("x", bins=[np.nan, np.nan, np.nan])
271+
272+
with raises_regex(ValueError, "Failed to group data."):
273+
dataset.groupby(dataset.foo * np.nan)
274+
275+
with raises_regex(ValueError, "Failed to group data."):
276+
dataset.to_array().groupby(dataset.foo * np.nan)
277+
278+
245279
# TODO: move other groupby tests from test_dataset and test_dataarray over here

0 commit comments

Comments
 (0)