
Commit a6838f8

Align all xarray objects
1 parent 162133e commit a6838f8

2 files changed: +38 −35 lines

xarray/core/parallel.py

Lines changed: 34 additions & 35 deletions
@@ -25,6 +25,7 @@
 
 import numpy as np
 
+from .alignment import align
 from .dataarray import DataArray
 from .dataset import Dataset
 
@@ -35,6 +36,13 @@ def get_index_vars(obj):
     return {dim: obj[dim] for dim in obj.indexes}
 
 
+def to_object_array(iterable):
+    npargs = np.empty((len(iterable),), dtype=np.object)
+    for idx, item in enumerate(iterable):
+        npargs[idx] = item
+    return npargs
+
+
 def assert_chunks_compatible(a: Dataset, b: Dataset):
     a = a.unify_chunks()
     b = b.unify_chunks()
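Context for the helper added above: the element-by-element loop is needed because passing a list of xarray objects straight to np.array() would coerce them into one numeric array instead of keeping each argument as an opaque element. A minimal sketch of that pitfall (hypothetical, not part of this commit; it spells the dtype as object rather than the deprecated np.object):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4), dims="x")

# np.array([da, da]) would broadcast the DataArrays into a (2, 4) float array;
# pre-allocating an object-dtype array keeps each argument as a distinct element.
npargs = np.empty((2,), dtype=object)
for idx, item in enumerate([da, da]):
    npargs[idx] = item

print(npargs.shape, npargs.dtype)  # (2,) object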
@@ -358,32 +366,30 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
     if not dask.is_dask_collection(obj):
         return func(obj, *args, **kwargs)
 
-    if isinstance(obj, DataArray):
-        dataset = dataarray_to_dataset(obj)
-        input_is_array = True
-    else:
-        dataset = obj
-        input_is_array = False
-
-    # TODO: align args and dataset here?
-    input_chunks = dict(dataset.chunks)
-    input_indexes = get_index_vars(dataset)
-    converted_args = []
-    arg_is_array = []
-    for arg in args:
-        arg_is_array.append(isinstance(arg, DataArray))
-        if isinstance(arg, (Dataset, DataArray)):
-            if isinstance(arg, DataArray):
-                converted_args.append(dataarray_to_dataset(arg))
-            assert_chunks_compatible(dataset, converted_args[-1])
-            input_chunks.update(converted_args[-1].chunks)
-            input_indexes.update(converted_args[-1].indexes)
-        else:
-            converted_args.append(arg)
+    npargs = to_object_array([obj] + list(args))
+    is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs]
+    is_array = [isinstance(arg, DataArray) for arg in npargs]
+
+    # align all xarray objects
+    # TODO: should we allow join as a kwarg or force everything to be aligned to the first object?
+    aligned = align(*npargs[is_xarray], join="left")
+    # assigning to object arrays works better when RHS is object array
+    # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
+    npargs[is_xarray] = to_object_array(aligned)
+    npargs[is_array] = to_object_array(
+        [dataarray_to_dataset(da) for da in npargs[is_array]]
+    )
+
+    input_chunks = dict(npargs[0].chunks)
+    input_indexes = get_index_vars(npargs[0])
+    for arg in npargs[1:][is_xarray[1:]]:
+        assert_chunks_compatible(npargs[0], arg)
+        input_chunks.update(arg.chunks)
+        input_indexes.update(arg.indexes)
 
     if template is None:
         # infer template by providing zero-shaped arrays
-        template = infer_template(func, obj, *args, **kwargs)
+        template = infer_template(func, aligned[0], *args, **kwargs)
         template_indexes = set(template.indexes)
         preserved_indexes = template_indexes & set(input_indexes)
         new_indexes = template_indexes - set(input_indexes)
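The join="left" call in this hunk means every other xarray argument is reindexed to the first object's indexes (the TODO notes this might later become a user-facing kwarg). A small sketch of that behaviour, using only public xarray API and made-up data:

import numpy as np
import xarray as xr

a = xr.DataArray(np.arange(3), dims="x", coords={"x": [0, 1, 2]})
b = xr.DataArray(np.arange(4), dims="x", coords={"x": [0, 1, 2, 3]})

# join="left": the first object's index wins; later objects are reindexed to it.
a2, b2 = xr.align(a, b, join="left")
print(list(b2.x.values))  # [0, 1, 2]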
@@ -420,7 +426,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
     graph: Dict[Any, Any] = {}
     new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
     gname = "{}-{}".format(
-        dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
+        dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs)
     )
 
     # map dims to list of chunk indexes
@@ -433,9 +439,9 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
 
         blocked_args = [
             subset_dataset_to_block(graph, gname, arg, input_chunks, chunk_tuple)
-            if isinstance(arg, (DataArray, Dataset))
+            if isxr
            else arg
-            for arg in (dataset,) + tuple(converted_args)
+            for isxr, arg in zip(is_xarray, npargs)
         ]
 
         # expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper
@@ -451,14 +457,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
         expected["coords"] = set(template.coords.keys())  # type: ignore
 
         from_wrapper = (gname,) + chunk_tuple
-        graph[from_wrapper] = (
-            _wrapper,
-            func,
-            blocked_args,
-            kwargs,
-            [input_is_array] + arg_is_array,
-            expected,
-        )
+        graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
 
         # mapping from variable name to dask graph key
         var_key_map: Dict[Hashable, str] = {}
@@ -491,7 +490,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
     hlg = HighLevelGraph.from_collections(
         gname,
         graph,
-        dependencies=[dataset] + [arg for arg in args if dask.is_dask_collection(arg)],
+        dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)],
     )
 
     for gname_l, layer in new_layers.items():
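Taken together, map_blocks now left-aligns obj and every xarray argument before building the graph, so callers no longer need to pre-align them by hand. A hedged usage sketch mirroring the new test below (synthetic data, not taken from the commit):

import operator
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.ones((10, 20)),
    dims=["x", "y"],
    coords={"x": np.arange(10), "y": np.arange(20)},
).chunk({"x": 4, "y": 5})

# The second argument has a larger x index; map_blocks now aligns it to da's
# index (join="left") before applying the function chunk by chunk.
mapped = xr.map_blocks(operator.add, da, args=[da.reindex(x=np.arange(20))])
xr.testing.assert_equal(da + da, mapped)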

xarray/tests/test_dask.py

Lines changed: 4 additions & 0 deletions
@@ -1124,6 +1124,10 @@ def sumda(da1, da2):
     with raises_regex(ValueError, "Chunk sizes along dimension 'x'"):
         xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})])
 
+    with raise_if_dask_computes():
+        mapped = xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))])
+    xr.testing.assert_equal(da1 + da1, mapped)
+
 
 @pytest.mark.parametrize("obj", [make_da(), make_ds()])
 def test_map_blocks_add_attrs(obj):
