Skip to content

Commit 88d4de8

Browse files
committed
Merge branch 'master' into fix_zarr_append_with_groups
2 parents 20ce63f + f56f92b commit 88d4de8

File tree

5 files changed

+171
-15
lines changed

5 files changed

+171
-15
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ New Features
4444
and support `.dt` accessor for timedelta
4545
via :py:class:`core.accessor_dt.TimedeltaAccessor` (:pull:`3612`)
4646
By `Anderson Banihirwe <https://github.com/andersy005>`_.
47+
- :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` now accept a ``stride`` argument that keeps only every ``stride``-th window (and its label) along the rolling dimension.
48+
By `Matthias Meyer <https://github.com/niowniow>`_.
4749

4850
Bug fixes
4951
~~~~~~~~~

xarray/core/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,7 @@ def rolling(
742742
dim: Mapping[Hashable, int] = None,
743743
min_periods: int = None,
744744
center: bool = False,
745+
stride: int = 1,
745746
**window_kwargs: int,
746747
):
747748
"""
@@ -758,6 +759,8 @@ def rolling(
758759
setting min_periods equal to the size of the window.
759760
center : boolean, default False
760761
Set the labels at the center of the window.
762+
stride : int, default 1
763+
Stride of the moving window: only every ``stride``-th window (and its label) along the rolling dimension is kept.
761764
**window_kwargs : optional
762765
The keyword arguments form of ``dim``.
763766
One of dim or window_kwargs must be provided.
@@ -800,7 +803,9 @@ def rolling(
800803
core.rolling.DatasetRolling
801804
"""
802805
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
803-
return self._rolling_cls(self, dim, min_periods=min_periods, center=center)
806+
return self._rolling_cls(
807+
self, dim, min_periods=min_periods, center=center, stride=stride
808+
)
804809

805810
def rolling_exp(
806811
self,

xarray/core/rolling.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def count(self):
141141

142142

143143
class DataArrayRolling(Rolling):
144-
__slots__ = ("window_labels",)
144+
__slots__ = ("window_labels", "stride")
145145

146-
def __init__(self, obj, windows, min_periods=None, center=False):
146+
def __init__(self, obj, windows, min_periods=None, center=False, stride=1):
147147
"""
148148
Moving window object for DataArray.
149149
You should use DataArray.rolling() method to construct this object
@@ -165,6 +165,8 @@ def __init__(self, obj, windows, min_periods=None, center=False):
165165
setting min_periods equal to the size of the window.
166166
center : boolean, default False
167167
Set the labels at the center of the window.
168+
stride : int, default 1
169+
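Stride of the moving window: only every ``stride``-th window (and its label) along the rolling dimension is kept. ``None`` is treated as 1.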
168170
169171
Returns
170172
-------
@@ -179,21 +181,33 @@ def __init__(self, obj, windows, min_periods=None, center=False):
179181
"""
180182
super().__init__(obj, windows, min_periods=min_periods, center=center)
181183

182-
self.window_labels = self.obj[self.dim]
184+
if stride is None:
185+
self.stride = 1
186+
else:
187+
self.stride = stride
188+
189+
window_labels = self.obj[self.dim]
190+
self.window_labels = window_labels[:: self.stride]
183191

184192
def __iter__(self):
185-
stops = np.arange(1, len(self.window_labels) + 1)
193+
stops = np.arange(1, len(self.window_labels) * self.stride + 1)
186194
starts = stops - int(self.window)
187195
starts[: int(self.window)] = 0
188-
for (label, start, stop) in zip(self.window_labels, starts, stops):
196+
197+
# apply striding
198+
stops = stops[:: self.stride]
199+
starts = starts[:: self.stride]
200+
window_labels = self.window_labels
201+
202+
for (label, start, stop) in zip(window_labels, starts, stops):
189203
window = self.obj.isel(**{self.dim: slice(start, stop)})
190204

191205
counts = window.count(dim=self.dim)
192206
window = window.where(counts >= self._min_periods)
193207

194208
yield (label, window)
195209

196-
def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
210+
def construct(self, window_dim, stride=None, fill_value=dtypes.NA):
197211
"""
198212
Convert this rolling object to xr.DataArray,
199213
where the window dimension is stacked as a new dimension
@@ -233,6 +247,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
233247

234248
from .dataarray import DataArray
235249

250+
if stride is None:
251+
stride = self.stride
252+
236253
window = self.obj.variable.rolling_window(
237254
self.dim, self.window, window_dim, self.center, fill_value=fill_value
238255
)
@@ -283,7 +300,7 @@ def reduce(self, func, **kwargs):
283300
[ 4., 9., 15., 18.]])
284301
"""
285302
rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
286-
windows = self.construct(rolling_dim)
303+
windows = self.construct(rolling_dim, stride=self.stride)
287304
result = windows.reduce(func, dim=rolling_dim, **kwargs)
288305

289306
# Find valid windows based on count.
@@ -301,7 +318,7 @@ def _counts(self):
301318
counts = (
302319
self.obj.notnull()
303320
.rolling(center=self.center, **{self.dim: self.window})
304-
.construct(rolling_dim, fill_value=False)
321+
.construct(rolling_dim, fill_value=False, stride=self.stride)
305322
.sum(dim=rolling_dim, skipna=False)
306323
)
307324
return counts
@@ -347,7 +364,7 @@ def _bottleneck_reduce(self, func, **kwargs):
347364
values = values[valid]
348365
result = DataArray(values, self.obj.coords)
349366

350-
return result
367+
return result.isel(**{self.dim: slice(None, None, self.stride)})
351368

352369
def _numpy_or_bottleneck_reduce(
353370
self, array_agg_func, bottleneck_move_func, **kwargs
@@ -372,9 +389,9 @@ def _numpy_or_bottleneck_reduce(
372389

373390

374391
class DatasetRolling(Rolling):
375-
__slots__ = ("rollings",)
392+
__slots__ = ("rollings", "stride")
376393

377-
def __init__(self, obj, windows, min_periods=None, center=False):
394+
def __init__(self, obj, windows, min_periods=None, center=False, stride=1):
378395
"""
379396
Moving window object for Dataset.
380397
You should use Dataset.rolling() method to construct this object
@@ -396,6 +413,8 @@ def __init__(self, obj, windows, min_periods=None, center=False):
396413
setting min_periods equal to the size of the window.
397414
center : boolean, default False
398415
Set the labels at the center of the window.
416+
stride : int, default 1
417+
Stride of the moving window: only every ``stride``-th window (and its label) along the rolling dimension is kept.
399418
400419
Returns
401420
-------
@@ -411,12 +430,15 @@ def __init__(self, obj, windows, min_periods=None, center=False):
411430
super().__init__(obj, windows, min_periods, center)
412431
if self.dim not in self.obj.dims:
413432
raise KeyError(self.dim)
433+
self.stride = stride
414434
# Keep each Rolling object as a dictionary
415435
self.rollings = {}
416436
for key, da in self.obj.data_vars.items():
417437
# keep a rolling object only for the data variables that depend on self.dim
418438
if self.dim in da.dims:
419-
self.rollings[key] = DataArrayRolling(da, windows, min_periods, center)
439+
self.rollings[key] = DataArrayRolling(
440+
da, windows, min_periods, center, stride=stride
441+
)
420442

421443
def _dataset_implementation(self, func, **kwargs):
422444
from .dataset import Dataset
@@ -427,7 +449,9 @@ def _dataset_implementation(self, func, **kwargs):
427449
reduced[key] = func(self.rollings[key], **kwargs)
428450
else:
429451
reduced[key] = self.obj[key]
430-
return Dataset(reduced, coords=self.obj.coords)
452+
return Dataset(reduced, coords=self.obj.coords).isel(
453+
**{self.dim: slice(None, None, self.stride)}
454+
)
431455

432456
def reduce(self, func, **kwargs):
433457
"""Reduce the items in this group by applying `func` along some
@@ -466,7 +490,7 @@ def _numpy_or_bottleneck_reduce(
466490
**kwargs,
467491
)
468492

469-
def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
493+
def construct(self, window_dim, stride=None, fill_value=dtypes.NA):
470494
"""
471495
Convert this rolling object to xr.Dataset,
472496
where the window dimension is stacked as a new dimension
@@ -487,6 +511,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
487511

488512
from .dataset import Dataset
489513

514+
if stride is None:
515+
stride = self.stride
516+
490517
dataset = {}
491518
for key, da in self.obj.data_vars.items():
492519
if self.dim in da.dims:

xarray/tests/test_dataarray.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4309,6 +4309,73 @@ def test_rolling_construct(center, window):
43094309
assert (da_rolling_mean == 0.0).sum() >= 0
43104310

43114311

4312+
@pytest.mark.parametrize("center", (True, False))
4313+
@pytest.mark.parametrize("window", (1, 2, 3, 4))
4314+
@pytest.mark.parametrize("stride", (1, 2, None))
4315+
def test_rolling_stride(center, window, stride):
4316+
s = pd.Series(np.arange(10))
4317+
da = DataArray.from_series(s)
4318+
4319+
s_rolling = s.rolling(window, center=center, min_periods=1).mean()
4320+
da_rolling_strided = da.rolling(
4321+
index=window, center=center, min_periods=1, stride=stride
4322+
)
4323+
4324+
if stride is None:
4325+
stride_index = 1
4326+
else:
4327+
stride_index = stride
4328+
4329+
# with construct
4330+
da_rolling_mean = da_rolling_strided.construct("window").mean("window")
4331+
np.testing.assert_allclose(s_rolling.values[::stride_index], da_rolling_mean.values)
4332+
np.testing.assert_allclose(
4333+
s_rolling.index[::stride_index], da_rolling_mean["index"]
4334+
)
4335+
np.testing.assert_allclose(
4336+
s_rolling.index[::stride_index], da_rolling_mean["index"]
4337+
)
4338+
4339+
# with bottleneck
4340+
da_rolling_strided_mean = da_rolling_strided.mean()
4341+
np.testing.assert_allclose(
4342+
s_rolling.values[::stride_index], da_rolling_strided_mean.values
4343+
)
4344+
np.testing.assert_allclose(
4345+
s_rolling.index[::stride_index], da_rolling_strided_mean["index"]
4346+
)
4347+
np.testing.assert_allclose(
4348+
s_rolling.index[::stride_index], da_rolling_strided_mean["index"]
4349+
)
4350+
4351+
# with fill_value
4352+
da_rolling_mean = da_rolling_strided.construct("window", fill_value=0.0).mean(
4353+
"window"
4354+
)
4355+
assert da_rolling_mean.isnull().sum() == 0
4356+
assert (da_rolling_mean == 0.0).sum() >= 0
4357+
4358+
# with iter
4359+
assert len(da_rolling_strided.window_labels) == len(da["index"]) // stride_index
4360+
assert_identical(da_rolling_strided.window_labels, da["index"][::stride_index])
4361+
4362+
for i, (label, window_da) in enumerate(da_rolling_strided):
4363+
assert label == da["index"].isel(index=i * stride_index)
4364+
4365+
with warnings.catch_warnings():
4366+
warnings.filterwarnings("ignore", "Mean of empty slice")
4367+
actual = da_rolling_strided_mean.isel(index=i)
4368+
expected = window_da.mean("index")
4369+
4370+
# TODO add assert_allclose_with_nan, which compares nan position
4371+
# as well as the closeness of the values.
4372+
assert_array_equal(actual.isnull(), expected.isnull())
4373+
if (~actual.isnull()).sum() > 0:
4374+
np.allclose(
4375+
actual.values, expected.values,
4376+
)
4377+
4378+
43124379
@pytest.mark.parametrize("da", (1, 2), indirect=True)
43134380
@pytest.mark.parametrize("center", (True, False))
43144381
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))

xarray/tests/test_dataset.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5668,6 +5668,61 @@ def test_rolling_construct(center, window):
56685668
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0
56695669

56705670

5671+
@pytest.mark.parametrize("center", (True, False))
5672+
@pytest.mark.parametrize("window", (1, 2, 3, 4))
5673+
@pytest.mark.parametrize("stride", (1, 2, None))
5674+
def test_rolling_stride(center, window, stride):
5675+
df = pd.DataFrame(
5676+
{
5677+
"x": np.random.randn(20),
5678+
"y": np.random.randn(20),
5679+
"time": np.linspace(0, 1, 20),
5680+
}
5681+
)
5682+
ds = Dataset.from_dataframe(df)
5683+
5684+
df_rolling = df.rolling(window, center=center, min_periods=1).mean()
5685+
ds_rolling_strided = ds.rolling(
5686+
index=window, center=center, min_periods=1, stride=stride
5687+
)
5688+
5689+
if stride is None:
5690+
stride_index = 1
5691+
else:
5692+
stride_index = stride
5693+
5694+
# with construct
5695+
ds_rolling_mean = ds_rolling_strided.construct("window").mean("window")
5696+
np.testing.assert_allclose(
5697+
df_rolling["x"].values[::stride_index], ds_rolling_mean["x"].values
5698+
)
5699+
np.testing.assert_allclose(
5700+
df_rolling.index[::stride_index], ds_rolling_mean["index"]
5701+
)
5702+
np.testing.assert_allclose(
5703+
df_rolling.index[::stride_index], ds_rolling_mean["index"]
5704+
)
5705+
5706+
# with bottleneck
5707+
ds_rolling_strided_mean = ds_rolling_strided.mean()
5708+
np.testing.assert_allclose(
5709+
df_rolling["x"].values[::stride_index], ds_rolling_strided_mean["x"].values
5710+
)
5711+
np.testing.assert_allclose(
5712+
df_rolling.index[::stride_index], ds_rolling_strided_mean["index"]
5713+
)
5714+
np.testing.assert_allclose(
5715+
df_rolling.index[::stride_index], ds_rolling_strided_mean["index"]
5716+
)
5717+
5718+
# with fill_value
5719+
ds_rolling_mean = ds_rolling_strided.construct("window", fill_value=0.0).mean(
5720+
"window"
5721+
)
5722+
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
5723+
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0
5724+
5725+
56715726
@pytest.mark.slow
56725727
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
56735728
@pytest.mark.parametrize("center", (True, False))

0 commit comments

Comments
 (0)