25
25
26
26
import numpy as np
27
27
28
+ from .alignment import align
28
29
from .dataarray import DataArray
29
30
from .dataset import Dataset
30
31
@@ -35,6 +36,13 @@ def get_index_vars(obj):
35
36
return {dim : obj [dim ] for dim in obj .indexes }
36
37
37
38
39
def to_object_array(iterable):
    """Pack the items of ``iterable`` into a 1-D NumPy object array.

    Every item becomes exactly one element of the result, even when the
    item is itself a sequence or an array-like object.

    Parameters
    ----------
    iterable : iterable
        Items to store. Any iterable is accepted, including generators;
        it is materialised once to determine the output length.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_items,)`` with ``dtype=object``.
    """
    items = list(iterable)
    # Use the builtin ``object`` as dtype: the ``np.object`` alias was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24.
    npargs = np.empty((len(items),), dtype=object)
    # Fill element by element; building the array directly from ``items``
    # would let NumPy unpack nested sequences into extra dimensions.
    for idx, item in enumerate(items):
        npargs[idx] = item
    return npargs
44
+
45
+
38
46
def assert_chunks_compatible (a : Dataset , b : Dataset ):
39
47
a = a .unify_chunks ()
40
48
b = b .unify_chunks ()
@@ -358,32 +366,30 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
358
366
if not dask .is_dask_collection (obj ):
359
367
return func (obj , * args , ** kwargs )
360
368
361
- if isinstance (obj , DataArray ):
362
- dataset = dataarray_to_dataset (obj )
363
- input_is_array = True
364
- else :
365
- dataset = obj
366
- input_is_array = False
367
-
368
- # TODO: align args and dataset here?
369
- input_chunks = dict (dataset .chunks )
370
- input_indexes = get_index_vars (dataset )
371
- converted_args = []
372
- arg_is_array = []
373
- for arg in args :
374
- arg_is_array .append (isinstance (arg , DataArray ))
375
- if isinstance (arg , (Dataset , DataArray )):
376
- if isinstance (arg , DataArray ):
377
- converted_args .append (dataarray_to_dataset (arg ))
378
- assert_chunks_compatible (dataset , converted_args [- 1 ])
379
- input_chunks .update (converted_args [- 1 ].chunks )
380
- input_indexes .update (converted_args [- 1 ].indexes )
381
- else :
382
- converted_args .append (arg )
369
+ npargs = to_object_array ([obj ] + list (args ))
370
+ is_xarray = [isinstance (arg , (Dataset , DataArray )) for arg in npargs ]
371
+ is_array = [isinstance (arg , DataArray ) for arg in npargs ]
372
+
373
+ # align all xarray objects
374
+ # TODO: should we allow join as a kwarg or force everything to be aligned to the first object?
375
+ aligned = align (* npargs [is_xarray ], join = "left" )
376
+ # assigning to object arrays works better when RHS is object array
377
+ # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
378
+ npargs [is_xarray ] = to_object_array (aligned )
379
+ npargs [is_array ] = to_object_array (
380
+ [dataarray_to_dataset (da ) for da in npargs [is_array ]]
381
+ )
382
+
383
+ input_chunks = dict (npargs [0 ].chunks )
384
+ input_indexes = get_index_vars (npargs [0 ])
385
+ for arg in npargs [1 :][is_xarray [1 :]]:
386
+ assert_chunks_compatible (npargs [0 ], arg )
387
+ input_chunks .update (arg .chunks )
388
+ input_indexes .update (arg .indexes )
383
389
384
390
if template is None :
385
391
# infer template by providing zero-shaped arrays
386
- template = infer_template (func , obj , * args , ** kwargs )
392
+ template = infer_template (func , aligned [ 0 ] , * args , ** kwargs )
387
393
template_indexes = set (template .indexes )
388
394
preserved_indexes = template_indexes & set (input_indexes )
389
395
new_indexes = template_indexes - set (input_indexes )
@@ -420,7 +426,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
420
426
graph : Dict [Any , Any ] = {}
421
427
new_layers : DefaultDict [str , Dict [Any , Any ]] = collections .defaultdict (dict )
422
428
gname = "{}-{}" .format (
423
- dask .utils .funcname (func ), dask .base .tokenize (dataset , args , kwargs )
429
+ dask .utils .funcname (func ), dask .base .tokenize (npargs [ 0 ] , args , kwargs )
424
430
)
425
431
426
432
# map dims to list of chunk indexes
@@ -433,9 +439,9 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
433
439
434
440
blocked_args = [
435
441
subset_dataset_to_block (graph , gname , arg , input_chunks , chunk_tuple )
436
- if isinstance ( arg , ( DataArray , Dataset ))
442
+ if isxr
437
443
else arg
438
- for arg in ( dataset ,) + tuple ( converted_args )
444
+ for isxr , arg in zip ( is_xarray , npargs )
439
445
]
440
446
441
447
# expected["shapes", "coords", "data_vars"] are used to raise nice error messages in _wrapper
@@ -451,14 +457,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
451
457
expected ["coords" ] = set (template .coords .keys ()) # type: ignore
452
458
453
459
from_wrapper = (gname ,) + chunk_tuple
454
- graph [from_wrapper ] = (
455
- _wrapper ,
456
- func ,
457
- blocked_args ,
458
- kwargs ,
459
- [input_is_array ] + arg_is_array ,
460
- expected ,
461
- )
460
+ graph [from_wrapper ] = (_wrapper , func , blocked_args , kwargs , is_array , expected )
462
461
463
462
# mapping from variable name to dask graph key
464
463
var_key_map : Dict [Hashable , str ] = {}
@@ -491,7 +490,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected):
491
490
hlg = HighLevelGraph .from_collections (
492
491
gname ,
493
492
graph ,
494
- dependencies = [dataset ] + [ arg for arg in args if dask .is_dask_collection (arg )],
493
+ dependencies = [arg for arg in npargs if dask .is_dask_collection (arg )],
495
494
)
496
495
497
496
for gname_l , layer in new_layers .items ():
0 commit comments