Skip to content

Commit 314f007

Browse files
committed
Add a workaround for dask.array.pad with mode="mean", which converts integers to floats, and add an additional check that the shape of the padded output matches the shape implied by pad_width.
1 parent 742487e commit 314f007

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

xarray/core/dask_array_compat.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,43 @@ def meta_from_array(x, ndim=None, dtype=None):
9999
return meta
100100

101101

102-
# TODO figure out how Dask versioning works
103-
# if LooseVersion(dask_version) >= LooseVersion("1.7.0"):
104-
try:
105-
pad = da.pad
106-
except AttributeError:
102+
def _validate_pad_output_shape(input_shape, pad_width, output_shape):
103+
""" Dask.array.pad with mode='reflect' does not always return the correct output_shape. """
104+
isint = lambda i: isinstance(i, int)
105+
106+
if isint(pad_width):
107+
pass
108+
elif len(pad_width) == 2 and all(map(isint, pad_width)):
109+
pad_width = sum(pad_width)
110+
elif (
111+
len(pad_width) == len(input_shape)
112+
and all(map(lambda x: len(x) == 2, pad_width))
113+
and all((isint(i) for p in pad_width for i in p))
114+
):
115+
pad_width = np.sum(pad_width, axis=1)
116+
else:
117+
return # should be impossible
118+
119+
if not np.array_equal(np.array(input_shape) + pad_width, output_shape):
120+
raise RuntimeError(
121+
"There seems to be something wrong with the shape of the output of dask.array.pad, "
122+
"try upgrading Dask, use a different pad mode e.g. mode='constant' or first convert "
123+
"your DataArray/Dataset to one backed by a numpy array by calling the `compute()` method."
124+
)
125+
126+
127+
if LooseVersion(dask_version) >= LooseVersion("0.18.1"):
128+
129+
def pad(array, pad_width, mode="constant", **kwargs):
    """Pad a dask array via ``da.pad``, smoothing over known dask quirks.

    Parameters
    ----------
    array : dask array
        Array to pad.
    pad_width : int or sequence
        Padding specification, forwarded unchanged to ``da.pad``.
    mode : str, default "constant"
        Padding mode, forwarded to ``da.pad``.
    **kwargs
        Additional keyword arguments forwarded to ``da.pad``.

    Returns
    -------
    The padded array. For ``mode="mean"`` on an integer input the result is
    rounded and cast back to the input dtype.

    Raises
    ------
    RuntimeError
        If the padded shape does not match the shape implied by ``pad_width``.
    """
    padded = da.pad(array, pad_width, mode=mode, **kwargs)
    # Validate before any early return so that every mode is shape-checked,
    # including the integer mode="mean" path below.
    _validate_pad_output_shape(array.shape, pad_width, padded.shape)
    # workaround for inconsistency between numpy and dask: https://github.com/dask/dask/issues/5303
    if mode == "mean" and issubclass(array.dtype.type, np.integer):
        return da.round(padded).astype(array.dtype)
    return padded
136+
137+
138+
else:
107139

108140
def pad(array, pad_width, mode="constant", **kwargs):
109141
"""

xarray/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5820,8 +5820,8 @@ def pad(
58205820
reflect_type=reflect_type,
58215821
)
58225822
else:
5823-
variables[name] = var.pad( # type: ignore
5824-
pad_width=var_pad_width, mode=coord_pad_mode, **coord_pad_options,
5823+
variables[name] = var.pad(
5824+
pad_width=var_pad_width, mode=coord_pad_mode, **coord_pad_options, # type: ignore
58255825
)
58265826

58275827
return self._replace_vars_and_dims(variables)

xarray/core/duck_array_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import pandas as pd
1313

14-
from . import dask_array_ops, dask_array_compat, dtypes, npcompat, nputils
14+
from . import dask_array_ops, dtypes, npcompat, nputils
1515
from .nputils import nanfirst, nanlast
1616
from .pycompat import dask_array_type
1717

xarray/tests/test_variable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def test_getitem_error(self):
788788
@pytest.mark.parametrize(
789789
"mode",
790790
[
791-
pytest.param("mean", marks=pytest.mark.xfail),
791+
"mean",
792792
pytest.param("median", marks=pytest.mark.xfail),
793793
pytest.param("reflect", marks=pytest.mark.xfail),
794794
"edge",
@@ -2070,7 +2070,7 @@ def test_getitem_uint(self):
20702070
@pytest.mark.parametrize(
20712071
"mode",
20722072
[
2073-
pytest.param("mean", marks=pytest.mark.xfail),
2073+
"mean",
20742074
pytest.param("median", marks=pytest.mark.xfail),
20752075
pytest.param("reflect", marks=pytest.mark.xfail),
20762076
"edge",

0 commit comments

Comments
 (0)