Skip to content

Commit 4c8b03b

Browse files
josephnowakpre-commit-ci[bot]phofldcherian
authored
Optimize ffill, bfill with dask when limit is specified (#9771)
* Reduce the number of tasks when the limit parameter is set on the push function * Reduce the number of tasks when the limit parameter is set on the push function, and incorporate the method parameter for the cumreduction on the push method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/core/dask_array_ops.py Co-authored-by: Deepak Cherian <[email protected]> * Use last instead of creating a custom function, and add a keepdims parameter for the last and first to make it compatible with the blelloch method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove the keepdims on the last and first method and use the nanlast method directly, they already have the parameter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Include the optimization of ffill and bfill on the whats-new.rst * Use map_overlap when the n is smaller than all the chunks * Avoid creating a numpy array to check if all the chunks are bigger than N on the push method * Updating the whats-new.rst * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Hoefler <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent e674286 commit 4c8b03b

File tree

4 files changed

+69
-27
lines changed

4 files changed

+69
-27
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ New Features
2929
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
3030
(:issue:`2852`, :issue:`757`).
3131
By `Deepak Cherian <https://github.com/dcherian>`_.
32+
- Optimize ffill, bfill with dask when limit is specified
33+
(:pull:`9771`).
34+
By `Joseph Nowak <https://github.com/josephnowak>`_, and
35+
`Patrick Hoefler <https://github.com/phofl>`_.
3236
- Allow wrapping ``np.ndarray`` subclasses, e.g. ``astropy.units.Quantity`` (:issue:`9704`, :pull:`9760`).
3337
By `Sam Levang <https://github.com/slevang>`_ and `Tien Vo <https://github.com/tien-vo>`_.
3438
- Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with

xarray/core/dask_array_ops.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,41 +75,71 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
7575
return coeffs, residuals
7676

7777

78-
def push(array, n, axis):
78+
def push(array, n, axis, method="blelloch"):
7979
"""
8080
Dask-aware bottleneck.push
8181
"""
8282
import dask.array as da
8383
import numpy as np
8484

8585
from xarray.core.duck_array_ops import _push
86+
from xarray.core.nputils import nanlast
87+
88+
if n is not None and all(n <= size for size in array.chunks[axis]):
89+
return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)
90+
91+
# TODO: Replace all this function
92+
# once https://github.com/pydata/xarray/issues/9229 is implemented
8693

8794
def _fill_with_last_one(a, b):
88-
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
89-
# the missing values using the last data of the previous chunk
90-
return np.where(~np.isnan(b), b, a)
95+
# cumreduction applies the push func over all the blocks first, so
96+
# the only missing part is filling the missing values using the
97+
# last data of the previous chunk
98+
return np.where(np.isnan(b), a, b)
9199

92-
if n is not None and 0 < n < array.shape[axis] - 1:
93-
arange = da.broadcast_to(
94-
da.arange(
95-
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
96-
).reshape(
97-
tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
98-
),
99-
array.shape,
100-
array.chunks,
101-
)
102-
valid_arange = da.where(da.notnull(array), arange, np.nan)
103-
valid_limits = (arange - push(valid_arange, None, axis)) <= n
104-
# omit the forward fill that violate the limit
105-
return da.where(valid_limits, push(array, None, axis), np.nan)
106-
107-
# The method parameter makes that the tests for python 3.7 fails.
108-
return da.reductions.cumreduction(
109-
func=_push,
100+
def _dtype_push(a, axis, dtype=None):
101+
# Not sure why the blelloch algorithm forces func to accept a dtype
102+
return _push(a, axis=axis)
103+
104+
pushed_array = da.reductions.cumreduction(
105+
func=_dtype_push,
110106
binop=_fill_with_last_one,
111107
ident=np.nan,
112108
x=array,
113109
axis=axis,
114110
dtype=array.dtype,
111+
method=method,
112+
preop=nanlast,
115113
)
114+
115+
if n is not None and 0 < n < array.shape[axis] - 1:
116+
117+
def _reset_cumsum(a, axis, dtype=None):
118+
cumsum = np.cumsum(a, axis=axis)
119+
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
120+
return cumsum - reset_points
121+
122+
def _last_reset_cumsum(a, axis, keepdims=None):
123+
# Take the last cumulative sum taking into account the reset
124+
# This is useful for blelloch method
125+
return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
126+
127+
def _combine_reset_cumsum(a, b):
128+
# Add the previous result to each position before the first
129+
# non-NaN value
130+
bitmask = np.cumprod(b != 0, axis=axis)
131+
return np.where(bitmask, b + a, b)
132+
133+
valid_positions = da.reductions.cumreduction(
134+
func=_reset_cumsum,
135+
binop=_combine_reset_cumsum,
136+
ident=0,
137+
x=da.isnan(array, dtype=int),
138+
axis=axis,
139+
dtype=int,
140+
method=method,
141+
preop=_last_reset_cumsum,
142+
)
143+
pushed_array = da.where(valid_positions <= n, pushed_array, np.nan)
144+
145+
return pushed_array

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,7 @@ def first(values, axis, skipna=None):
716716
return chunked_nanfirst(values, axis)
717717
else:
718718
return nputils.nanfirst(values, axis)
719+
719720
return take(values, 0, axis=axis)
720721

721722

@@ -729,6 +730,7 @@ def last(values, axis, skipna=None):
729730
return chunked_nanlast(values, axis)
730731
else:
731732
return nputils.nanlast(values, axis)
733+
732734
return take(values, -1, axis=axis)
733735

734736

@@ -769,14 +771,14 @@ def _push(array, n: int | None = None, axis: int = -1):
769771
return bn.push(array, limit, axis)
770772

771773

772-
def push(array, n, axis):
774+
def push(array, n, axis, method="blelloch"):
773775
if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
774776
raise RuntimeError(
775777
"ffill & bfill requires bottleneck or numbagg to be enabled."
776778
" Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
777779
)
778780
if is_duck_dask_array(array):
779-
return dask_array_ops.push(array, n, axis)
781+
return dask_array_ops.push(array, n, axis, method=method)
780782
else:
781783
return _push(array, n, axis)
782784

xarray/tests/test_duck_array_ops.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,8 @@ def test_least_squares(use_dask, skipna):
10081008

10091009
@requires_dask
10101010
@requires_bottleneck
1011-
def test_push_dask():
1011+
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
1012+
def test_push_dask(method):
10121013
import bottleneck
10131014
import dask.array
10141015

@@ -1018,13 +1019,18 @@ def test_push_dask():
10181019
expected = bottleneck.push(array, axis=0, n=n)
10191020
for c in range(1, 11):
10201021
with raise_if_dask_computes():
1021-
actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n)
1022+
actual = push(
1023+
dask.array.from_array(array, chunks=c), axis=0, n=n, method=method
1024+
)
10221025
np.testing.assert_equal(actual, expected)
10231026

10241027
# some chunks of size-1 with NaN
10251028
with raise_if_dask_computes():
10261029
actual = push(
1027-
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n
1030+
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
1031+
axis=0,
1032+
n=n,
1033+
method=method,
10281034
)
10291035
np.testing.assert_equal(actual, expected)
10301036

0 commit comments

Comments
 (0)