Skip to content
forked from pydata/xarray

Commit d2f162d

Browse files
committed
optimizations for dask array equality comparisons
Closes pydata#3068, pydata#3311
1 parent 79b3cdd commit d2f162d

File tree

6 files changed

+100
-14
lines changed

6 files changed

+100
-14
lines changed

xarray/core/concat.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,28 +183,32 @@ def process_subset_opt(opt, subset):
183183
if opt == "different":
184184
if compat == "override":
185185
raise ValueError(
186-
"Cannot specify both %s='different' and compat='override'."
187-
% subset
186+
f"Cannot specify both {subset}='different' and compat='override'."
188187
)
189188
# all nonindexes that are not the same in each dataset
190189
for k in getattr(datasets[0], subset):
191190
if k not in concat_over:
191+
variables = [ds.variables[k] for ds in datasets]
192+
equals[k] = utils.dask_name_equal(variables)
193+
if equals[k]:
194+
continue
195+
192196
# Compare the variable of all datasets vs. the one
193197
# of the first dataset. Perform the minimum amount of
194198
# loads in order to avoid multiple loads from disk
195199
# while keeping the RAM footprint low.
196-
v_lhs = datasets[0].variables[k].load()
200+
v_lhs = variables[0].load()
197201
# We'll need to know later on if variables are equal.
198-
computed = []
199-
for ds_rhs in datasets[1:]:
200-
v_rhs = ds_rhs.variables[k].compute()
202+
computed = [v_lhs]
203+
for v_rhs in variables[1:]:
204+
v_rhs = v_rhs.compute()
201205
computed.append(v_rhs)
202206
if not getattr(v_lhs, compat)(v_rhs):
203207
concat_over.add(k)
204208
equals[k] = False
205209
# computed variables are not to be re-computed
206210
# again in the future
207-
for ds, v in zip(datasets[1:], computed):
211+
for ds, v in zip(datasets, computed):
208212
ds.variables[k].data = v.data
209213
break
210214
else:

xarray/core/duck_array_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,14 @@ def array_equiv(arr1, arr2):
189189
"""
190190
arr1 = asarray(arr1)
191191
arr2 = asarray(arr2)
192+
if (
193+
dask_array
194+
and isinstance(arr1, dask_array.Array)
195+
and isinstance(arr2, dask_array.Array)
196+
):
197+
# GH3068
198+
if arr1.name == arr2.name:
199+
return True
192200
if arr1.shape != arr2.shape:
193201
return False
194202
with warnings.catch_warnings():

xarray/core/merge.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from . import dtypes, pdcompat
2121
from .alignment import deep_align
22-
from .utils import Frozen, dict_equiv
22+
from .utils import Frozen, dict_equiv, dask_name_equal
2323
from .variable import Variable, as_variable, assert_unique_multiindex_level_names
2424

2525
if TYPE_CHECKING:
@@ -123,11 +123,14 @@ def unique_variable(
123123
combine_method = "fillna"
124124

125125
if equals is None:
126-
out = out.compute()
127-
for var in variables[1:]:
128-
equals = getattr(out, compat)(var)
129-
if not equals:
130-
break
126+
equals = dask_name_equal(variables)
127+
128+
if not equals:
129+
out = out.compute()
130+
for var in variables[1:]:
131+
equals = getattr(out, compat)(var)
132+
if not equals:
133+
break
131134

132135
if not equals:
133136
raise MergeError(

xarray/core/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,3 +676,34 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
676676
while new_dim in dims:
677677
new_dim = "_" + str(new_dim)
678678
return new_dim
679+
680+
681+
def dask_name_equal(variables):
    """
    Test variable data for equality by comparing dask names if possible.

    Only the ``name`` attribute of each dask array is inspected, so nothing
    is computed. The first element is the reference; every later element is
    compared against it in order and the scan stops at the first element
    that settles the answer.

    Parameters
    ----------
    variables : sequence
        Objects exposing a ``data`` attribute.

    Returns
    -------
    True
        If there are at least two elements, all of them hold dask arrays,
        and every dask name equals the name of the first element.
    False
        If a dask array whose name differs from the first element's name is
        reached before any non-dask element.
    None
        If equality cannot be determined: dask is not importable,
        ``variables`` is empty or has a single element, the first element
        is not a dask array, or a non-dask element is reached before any
        name mismatch.
    """
    # An empty sequence has nothing to compare; report "unknown" instead of
    # failing on variables[0] below.
    if not variables:
        return None

    try:
        import dask.array as dask_array
    except ImportError:
        return None

    reference = variables[0].data
    if not isinstance(reference, dask_array.Array):
        return None

    # NOTE(review): only the dask names are compared here; nothing else on
    # the variables is looked at. Presumably callers check any other
    # properties they care about separately -- confirm.
    equals = None
    for var in variables[1:]:
        candidate = var.data
        if not isinstance(candidate, dask_array.Array):
            # Mixed dask / non-dask input: cannot decide from names alone.
            return None
        if candidate.name != reference.name:
            return False
        equals = True
    return equals

xarray/core/variable.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,9 @@ def transpose(self, *dims) -> "Variable":
12291229
if len(dims) == 0:
12301230
dims = self.dims[::-1]
12311231
axes = self.get_axis_num(dims)
1232-
if len(dims) < 2: # no need to transpose if only one dimension
1232+
if len(dims) < 2 or dims == self.dims:
1233+
# no need to transpose if only one dimension
1234+
# or dims are in same order
12331235
return self.copy(deep=False)
12341236

12351237
data = as_indexable(self._data).transpose(axes)

xarray/tests/test_dask.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
assert_identical,
2323
raises_regex,
2424
)
25+
from ..core.utils import dask_name_equal
2526

2627
dask = pytest.importorskip("dask")
2728
da = pytest.importorskip("dask.array")
@@ -1135,3 +1136,40 @@ def test_make_meta(map_ds):
11351136
for variable in map_ds.data_vars:
11361137
assert variable in meta.data_vars
11371138
assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
1139+
1140+
1141+
def test_identical_coords_no_computes():
    """Adding two arrays that share one dask-backed coordinate must not compute."""
    shared_lons = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))

    def make_operand():
        # Each operand gets its own dask data but the very same coordinate object.
        return xr.DataArray(
            da.zeros((10, 10), chunks=2),
            dims=("y", "x"),
            coords={"lons": shared_lons},
        )

    lhs = make_operand()
    rhs = make_operand()

    with raise_if_dask_computes():
        total = lhs + rhs

    # NOTE(review): the scraped source lost its indentation; this assertion is
    # assumed to sit outside the no-compute guard -- confirm.
    assert_identical(total, lhs)
1152+
1153+
1154+
def test_dask_name_equal():
    """Exercise the True / False / None outcomes of ``dask_name_equal`` without computing."""

    def lazy_zeros():
        return xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))

    lons1 = lazy_zeros()
    lons2 = lazy_zeros()

    # Both arrays are dask-backed: the answer comes from names alone.
    with raise_if_dask_computes():
        assert dask_name_equal([lons1, lons2])
    with raise_if_dask_computes():
        assert not dask_name_equal([lons1, lons2 / 2])

    # As soon as one side is not dask-backed the result is undetermined.
    # NOTE(review): indentation was lost in the scraped source; these two
    # checks call ``.compute()`` and are assumed to sit outside the
    # no-compute guard -- confirm.
    assert dask_name_equal([lons1, lons2.compute()]) is None
    assert dask_name_equal([lons1.compute(), lons2.compute()]) is None

    # Transposing to the order the dims already have.
    with raise_if_dask_computes():
        assert dask_name_equal([lons1, lons1.transpose("y", "x")])

    compat_modes = [
        "broadcast_equals",
        "equals",
        "override",
        "identical",
        "no_conflicts",
    ]
    with raise_if_dask_computes():
        for compat in compat_modes:
            xr.merge([lons1, lons2], compat=compat)

0 commit comments

Comments
 (0)