37
37
from xarray .core .types import Dims , QuantileMethods , T_DataArray , T_Xarray
38
38
from xarray .core .utils import (
39
39
either_dict_or_kwargs ,
40
+ emit_user_level_warning ,
40
41
hashable ,
41
42
is_scalar ,
42
43
maybe_wrap_array ,
@@ -73,6 +74,21 @@ def check_reduce_dims(reduce_dims, dimensions):
73
74
)
74
75
75
76
77
+ def _maybe_squeeze_indices (
78
+ indices , squeeze : bool | None , grouper : ResolvedGrouper , warn : bool
79
+ ):
80
+ if squeeze in [None , True ] and grouper .can_squeeze :
81
+ if squeeze is None and warn :
82
+ emit_user_level_warning (
83
+ "The `squeeze` kwarg to GroupBy is being removed."
84
+ "Pass .groupby(..., squeeze=False) to silence this warning."
85
+ )
86
+ if isinstance (indices , slice ):
87
+ assert indices .stop - indices .start == 1
88
+ indices = indices .start
89
+ return indices
90
+
91
+
76
92
def unique_value_groups (
77
93
ar , sort : bool = True
78
94
) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
@@ -366,10 +382,10 @@ def dims(self):
366
382
return self .group1d .dims
367
383
368
384
@abstractmethod
369
- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
385
+ def factorize (self ) -> T_FactorizeOut :
370
386
raise NotImplementedError
371
387
372
- def factorize (self , squeeze : bool ) -> None :
388
+ def _factorize (self ) -> None :
373
389
# This design makes it clear to mypy that
374
390
# codes, group_indices, unique_coord, and full_index
375
391
# are set by the factorize method on the derived class.
@@ -378,7 +394,7 @@ def factorize(self, squeeze: bool) -> None:
378
394
self .group_indices ,
379
395
self .unique_coord ,
380
396
self .full_index ,
381
- ) = self ._factorize ( squeeze )
397
+ ) = self .factorize ( )
382
398
383
399
@property
384
400
def is_unique_and_monotonic (self ) -> bool :
@@ -393,15 +409,19 @@ def group_as_index(self) -> pd.Index:
393
409
self ._group_as_index = self .group1d .to_index ()
394
410
return self ._group_as_index
395
411
412
+ @property
413
+ def can_squeeze (self ):
414
+ is_dimension = self .group .dims == (self .group .name ,)
415
+ return is_dimension and self .is_unique_and_monotonic
416
+
396
417
397
418
@dataclass
398
419
class ResolvedUniqueGrouper (ResolvedGrouper ):
399
420
grouper : UniqueGrouper
400
421
401
- def _factorize (self , squeeze ) -> T_FactorizeOut :
402
- is_dimension = self .group .dims == (self .group .name ,)
403
- if is_dimension and self .is_unique_and_monotonic :
404
- return self ._factorize_dummy (squeeze )
422
+ def factorize (self ) -> T_FactorizeOut :
423
+ if self .can_squeeze :
424
+ return self ._factorize_dummy ()
405
425
else :
406
426
return self ._factorize_unique ()
407
427
@@ -424,15 +444,12 @@ def _factorize_unique(self) -> T_FactorizeOut:
424
444
425
445
return codes , group_indices , unique_coord , full_index
426
446
427
- def _factorize_dummy (self , squeeze ) -> T_FactorizeOut :
447
+ def _factorize_dummy (self ) -> T_FactorizeOut :
428
448
size = self .group .size
429
449
# no need to factorize
430
- if not squeeze :
431
- # use slices to do views instead of fancy indexing
432
- # equivalent to: group_indices = group_indices.reshape(-1, 1)
433
- group_indices : T_GroupIndices = [slice (i , i + 1 ) for i in range (size )]
434
- else :
435
- group_indices = list (range (size ))
450
+ # use slices to do views instead of fancy indexing
451
+ # equivalent to: group_indices = group_indices.reshape(-1, 1)
452
+ group_indices : T_GroupIndices = [slice (i , i + 1 ) for i in range (size )]
436
453
size_range = np .arange (size )
437
454
if isinstance (self .group , _DummyGroup ):
438
455
codes = self .group .to_dataarray ().copy (data = size_range )
@@ -448,7 +465,7 @@ def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
448
465
class ResolvedBinGrouper (ResolvedGrouper ):
449
466
grouper : BinGrouper
450
467
451
- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
468
+ def factorize (self ) -> T_FactorizeOut :
452
469
from xarray .core .dataarray import DataArray
453
470
454
471
data = self .group1d .values
@@ -546,7 +563,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
546
563
_apply_loffset (self .grouper .loffset , first_items )
547
564
return first_items , codes
548
565
549
- def _factorize (self , squeeze : bool ) -> T_FactorizeOut :
566
+ def factorize (self ) -> T_FactorizeOut :
550
567
full_index , first_items , codes_ = self ._get_index_and_items ()
551
568
sbins = first_items .values .astype (np .int64 )
552
569
group_indices : T_GroupIndices = [
@@ -591,14 +608,14 @@ class TimeResampleGrouper(Grouper):
591
608
loffset : datetime .timedelta | str | None
592
609
593
610
594
- def _validate_groupby_squeeze (squeeze : bool ) -> None :
611
+ def _validate_groupby_squeeze (squeeze : bool | None ) -> None :
595
612
# While we don't generally check the type of every arg, passing
596
613
# multiple dimensions as multiple arguments is common enough, and the
597
614
# consequences hidden enough (strings evaluate as true) to warrant
598
615
# checking here.
599
616
# A future version could make squeeze kwarg only, but would face
600
617
# backward-compat issues.
601
- if not isinstance (squeeze , bool ):
618
+ if squeeze is not None and not isinstance (squeeze , bool ):
602
619
raise TypeError (f"`squeeze` must be True or False, but { squeeze } was supplied" )
603
620
604
621
@@ -730,7 +747,7 @@ def __init__(
730
747
self ._original_obj = obj
731
748
732
749
for grouper_ in self .groupers :
733
- grouper_ .factorize ( squeeze )
750
+ grouper_ ._factorize ( )
734
751
735
752
(grouper ,) = self .groupers
736
753
self ._original_group = grouper .group
@@ -762,9 +779,14 @@ def sizes(self) -> Mapping[Hashable, int]:
762
779
Dataset.sizes
763
780
"""
764
781
if self ._sizes is None :
765
- self ._sizes = self ._obj .isel (
766
- {self ._group_dim : self ._group_indices [0 ]}
767
- ).sizes
782
+ (grouper ,) = self .groupers
783
+ index = _maybe_squeeze_indices (
784
+ self ._group_indices [0 ],
785
+ self ._squeeze ,
786
+ grouper ,
787
+ warn = True ,
788
+ )
789
+ self ._sizes = self ._obj .isel ({self ._group_dim : index }).sizes
768
790
769
791
return self ._sizes
770
792
@@ -798,14 +820,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]:
798
820
# provided to mimic pandas.groupby
799
821
if self ._groups is None :
800
822
(grouper ,) = self .groupers
801
- self ._groups = dict (zip (grouper .unique_coord .values , self ._group_indices ))
823
+ squeezed_indices = (
824
+ _maybe_squeeze_indices (ind , self ._squeeze , grouper , warn = idx > 0 )
825
+ for idx , ind in enumerate (self ._group_indices )
826
+ )
827
+ self ._groups = dict (zip (grouper .unique_coord .values , squeezed_indices ))
802
828
return self ._groups
803
829
804
830
def __getitem__ (self , key : GroupKey ) -> T_Xarray :
805
831
"""
806
832
Get DataArray or Dataset corresponding to a particular group label.
807
833
"""
808
- return self ._obj .isel ({self ._group_dim : self .groups [key ]})
834
+ (grouper ,) = self .groupers
835
+ index = _maybe_squeeze_indices (
836
+ self .groups [key ], self ._squeeze , grouper , warn = True
837
+ )
838
+ return self ._obj .isel ({self ._group_dim : index })
809
839
810
840
def __len__ (self ) -> int :
811
841
(grouper ,) = self .groupers
@@ -826,7 +856,11 @@ def __repr__(self) -> str:
826
856
827
857
def _iter_grouped (self ) -> Iterator [T_Xarray ]:
828
858
"""Iterate over each element in this group"""
829
- for indices in self ._group_indices :
859
+ (grouper ,) = self .groupers
860
+ for idx , indices in enumerate (self ._group_indices ):
861
+ indices = _maybe_squeeze_indices (
862
+ indices , self ._squeeze , grouper , warn = idx > 0
863
+ )
830
864
yield self ._obj .isel ({self ._group_dim : indices })
831
865
832
866
def _infer_concat_args (self , applied_example ):
@@ -1309,7 +1343,11 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
1309
1343
@property
1310
1344
def dims (self ) -> tuple [Hashable , ...]:
1311
1345
if self ._dims is None :
1312
- self ._dims = self ._obj .isel ({self ._group_dim : self ._group_indices [0 ]}).dims
1346
+ (grouper ,) = self .groupers
1347
+ index = _maybe_squeeze_indices (
1348
+ self ._group_indices [0 ], self ._squeeze , grouper , warn = True
1349
+ )
1350
+ self ._dims = self ._obj .isel ({self ._group_dim : index }).dims
1313
1351
1314
1352
return self ._dims
1315
1353
@@ -1318,7 +1356,11 @@ def _iter_grouped_shortcut(self):
1318
1356
metadata
1319
1357
"""
1320
1358
var = self ._obj .variable
1321
- for indices in self ._group_indices :
1359
+ (grouper ,) = self .groupers
1360
+ for idx , indices in enumerate (self ._group_indices ):
1361
+ indices = _maybe_squeeze_indices (
1362
+ indices , self ._squeeze , grouper , warn = idx > 0
1363
+ )
1322
1364
yield var [{self ._group_dim : indices }]
1323
1365
1324
1366
def _concat_shortcut (self , applied , dim , positions = None ):
@@ -1517,7 +1559,14 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
1517
1559
@property
1518
1560
def dims (self ) -> Frozen [Hashable , int ]:
1519
1561
if self ._dims is None :
1520
- self ._dims = self ._obj .isel ({self ._group_dim : self ._group_indices [0 ]}).dims
1562
+ (grouper ,) = self .groupers
1563
+ index = _maybe_squeeze_indices (
1564
+ self ._group_indices [0 ],
1565
+ self ._squeeze ,
1566
+ grouper ,
1567
+ warn = True ,
1568
+ )
1569
+ self ._dims = self ._obj .isel ({self ._group_dim : index }).dims
1521
1570
1522
1571
return self ._dims
1523
1572
0 commit comments