@@ -141,9 +141,9 @@ def count(self):
141
141
142
142
143
143
class DataArrayRolling (Rolling ):
144
- __slots__ = ("window_labels" ,)
144
+ __slots__ = ("window_labels" , "stride" )
145
145
146
- def __init__ (self , obj , windows , min_periods = None , center = False ):
146
+ def __init__ (self , obj , windows , min_periods = None , center = False , stride = 1 ):
147
147
"""
148
148
Moving window object for DataArray.
149
149
You should use DataArray.rolling() method to construct this object
@@ -165,6 +165,8 @@ def __init__(self, obj, windows, min_periods=None, center=False):
165
165
setting min_periods equal to the size of the window.
166
166
center : boolean, default False
167
167
Set the labels at the center of the window.
168
+ stride : int, default 1
169
+ Stride of the moving window
168
170
169
171
Returns
170
172
-------
@@ -179,21 +181,33 @@ def __init__(self, obj, windows, min_periods=None, center=False):
179
181
"""
180
182
super ().__init__ (obj , windows , min_periods = min_periods , center = center )
181
183
182
- self .window_labels = self .obj [self .dim ]
184
+ if stride is None :
185
+ self .stride = 1
186
+ else :
187
+ self .stride = stride
188
+
189
+ window_labels = self .obj [self .dim ]
190
+ self .window_labels = window_labels [:: self .stride ]
183
191
184
192
def __iter__ (self ):
185
- stops = np .arange (1 , len (self .window_labels ) + 1 )
193
+ stops = np .arange (1 , len (self .window_labels ) * self . stride + 1 )
186
194
starts = stops - int (self .window )
187
195
starts [: int (self .window )] = 0
188
- for (label , start , stop ) in zip (self .window_labels , starts , stops ):
196
+
197
+ # apply striding
198
+ stops = stops [:: self .stride ]
199
+ starts = starts [:: self .stride ]
200
+ window_labels = self .window_labels
201
+
202
+ for (label , start , stop ) in zip (window_labels , starts , stops ):
189
203
window = self .obj .isel (** {self .dim : slice (start , stop )})
190
204
191
205
counts = window .count (dim = self .dim )
192
206
window = window .where (counts >= self ._min_periods )
193
207
194
208
yield (label , window )
195
209
196
- def construct (self , window_dim , stride = 1 , fill_value = dtypes .NA ):
210
+ def construct (self , window_dim , stride = None , fill_value = dtypes .NA ):
197
211
"""
198
212
Convert this rolling object to xr.DataArray,
199
213
where the window dimension is stacked as a new dimension
@@ -233,6 +247,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
233
247
234
248
from .dataarray import DataArray
235
249
250
+ if stride is None :
251
+ stride = self .stride
252
+
236
253
window = self .obj .variable .rolling_window (
237
254
self .dim , self .window , window_dim , self .center , fill_value = fill_value
238
255
)
@@ -283,7 +300,7 @@ def reduce(self, func, **kwargs):
283
300
[ 4., 9., 15., 18.]])
284
301
"""
285
302
rolling_dim = utils .get_temp_dimname (self .obj .dims , "_rolling_dim" )
286
- windows = self .construct (rolling_dim )
303
+ windows = self .construct (rolling_dim , stride = self . stride )
287
304
result = windows .reduce (func , dim = rolling_dim , ** kwargs )
288
305
289
306
# Find valid windows based on count.
@@ -301,7 +318,7 @@ def _counts(self):
301
318
counts = (
302
319
self .obj .notnull ()
303
320
.rolling (center = self .center , ** {self .dim : self .window })
304
- .construct (rolling_dim , fill_value = False )
321
+ .construct (rolling_dim , fill_value = False , stride = self . stride )
305
322
.sum (dim = rolling_dim , skipna = False )
306
323
)
307
324
return counts
@@ -347,7 +364,7 @@ def _bottleneck_reduce(self, func, **kwargs):
347
364
values = values [valid ]
348
365
result = DataArray (values , self .obj .coords )
349
366
350
- return result
367
+ return result . isel ( ** { self . dim : slice ( None , None , self . stride )})
351
368
352
369
def _numpy_or_bottleneck_reduce (
353
370
self , array_agg_func , bottleneck_move_func , ** kwargs
@@ -372,9 +389,9 @@ def _numpy_or_bottleneck_reduce(
372
389
373
390
374
391
class DatasetRolling (Rolling ):
375
- __slots__ = ("rollings" ,)
392
+ __slots__ = ("rollings" , "stride" )
376
393
377
- def __init__ (self , obj , windows , min_periods = None , center = False ):
394
+ def __init__ (self , obj , windows , min_periods = None , center = False , stride = 1 ):
378
395
"""
379
396
Moving window object for Dataset.
380
397
You should use Dataset.rolling() method to construct this object
@@ -396,6 +413,8 @@ def __init__(self, obj, windows, min_periods=None, center=False):
396
413
setting min_periods equal to the size of the window.
397
414
center : boolean, default False
398
415
Set the labels at the center of the window.
416
+ stride : int, default 1
417
+ Stride of the moving window
399
418
400
419
Returns
401
420
-------
@@ -411,12 +430,15 @@ def __init__(self, obj, windows, min_periods=None, center=False):
411
430
super ().__init__ (obj , windows , min_periods , center )
412
431
if self .dim not in self .obj .dims :
413
432
raise KeyError (self .dim )
433
+ self .stride = stride
414
434
# Keep each Rolling object as a dictionary
415
435
self .rollings = {}
416
436
for key , da in self .obj .data_vars .items ():
417
437
# keeps rollings only for the dataset depending on slf.dim
418
438
if self .dim in da .dims :
419
- self .rollings [key ] = DataArrayRolling (da , windows , min_periods , center )
439
+ self .rollings [key ] = DataArrayRolling (
440
+ da , windows , min_periods , center , stride = stride
441
+ )
420
442
421
443
def _dataset_implementation (self , func , ** kwargs ):
422
444
from .dataset import Dataset
@@ -427,7 +449,9 @@ def _dataset_implementation(self, func, **kwargs):
427
449
reduced [key ] = func (self .rollings [key ], ** kwargs )
428
450
else :
429
451
reduced [key ] = self .obj [key ]
430
- return Dataset (reduced , coords = self .obj .coords )
452
+ return Dataset (reduced , coords = self .obj .coords ).isel (
453
+ ** {self .dim : slice (None , None , self .stride )}
454
+ )
431
455
432
456
def reduce (self , func , ** kwargs ):
433
457
"""Reduce the items in this group by applying `func` along some
@@ -466,7 +490,7 @@ def _numpy_or_bottleneck_reduce(
466
490
** kwargs ,
467
491
)
468
492
469
- def construct (self , window_dim , stride = 1 , fill_value = dtypes .NA ):
493
+ def construct (self , window_dim , stride = None , fill_value = dtypes .NA ):
470
494
"""
471
495
Convert this rolling object to xr.Dataset,
472
496
where the window dimension is stacked as a new dimension
@@ -487,6 +511,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
487
511
488
512
from .dataset import Dataset
489
513
514
+ if stride is None :
515
+ stride = self .stride
516
+
490
517
dataset = {}
491
518
for key , da in self .obj .data_vars .items ():
492
519
if self .dim in da .dims :
0 commit comments