27
27
28
28
from . import dtypes , duck_array_ops , utils
29
29
from .alignment import align , deep_align
30
- from .merge import merge_coordinates_without_align
31
- from .options import OPTIONS
30
+ from .merge import merge_attrs , merge_coordinates_without_align
31
+ from .options import OPTIONS , _get_keep_attrs
32
32
from .pycompat import is_duck_dask_array
33
33
from .utils import is_dict_like
34
34
from .variable import Variable
@@ -50,6 +50,11 @@ def _first_of_type(args, kind):
50
50
raise ValueError ("This should be unreachable." )
51
51
52
52
53
+ def _all_of_type (args , kind ):
54
+ """Return all objects of type 'kind'"""
55
+ return [arg for arg in args if isinstance (arg , kind )]
56
+
57
+
53
58
class _UFuncSignature :
54
59
"""Core dimensions signature for a given function.
55
60
@@ -202,7 +207,10 @@ def _get_coords_list(args) -> List["Coordinates"]:
202
207
203
208
204
209
def build_output_coords (
205
- args : list , signature : _UFuncSignature , exclude_dims : AbstractSet = frozenset ()
210
+ args : list ,
211
+ signature : _UFuncSignature ,
212
+ exclude_dims : AbstractSet = frozenset (),
213
+ combine_attrs : str = "override" ,
206
214
) -> "List[Dict[Any, Variable]]" :
207
215
"""Build output coordinates for an operation.
208
216
@@ -230,7 +238,7 @@ def build_output_coords(
230
238
else :
231
239
# TODO: save these merged indexes, instead of re-computing them later
232
240
merged_vars , unused_indexes = merge_coordinates_without_align (
233
- coords_list , exclude_dims = exclude_dims
241
+ coords_list , exclude_dims = exclude_dims , combine_attrs = combine_attrs
234
242
)
235
243
236
244
output_coords = []
@@ -248,7 +256,12 @@ def build_output_coords(
248
256
249
257
250
258
def apply_dataarray_vfunc (
251
- func , * args , signature , join = "inner" , exclude_dims = frozenset (), keep_attrs = False
259
+ func ,
260
+ * args ,
261
+ signature ,
262
+ join = "inner" ,
263
+ exclude_dims = frozenset (),
264
+ keep_attrs = "override" ,
252
265
):
253
266
"""Apply a variable level function over DataArray, Variable and/or ndarray
254
267
objects.
@@ -260,12 +273,16 @@ def apply_dataarray_vfunc(
260
273
args , join = join , copy = False , exclude = exclude_dims , raise_on_invalid = False
261
274
)
262
275
263
- if keep_attrs :
276
+ objs = _all_of_type (args , DataArray )
277
+
278
+ if keep_attrs == "drop" :
279
+ name = result_name (args )
280
+ else :
264
281
first_obj = _first_of_type (args , DataArray )
265
282
name = first_obj .name
266
- else :
267
- name = result_name ( args )
268
- result_coords = build_output_coords ( args , signature , exclude_dims )
283
+ result_coords = build_output_coords (
284
+ args , signature , exclude_dims , combine_attrs = keep_attrs
285
+ )
269
286
270
287
data_vars = [getattr (a , "variable" , a ) for a in args ]
271
288
result_var = func (* data_vars )
@@ -279,13 +296,12 @@ def apply_dataarray_vfunc(
279
296
(coords ,) = result_coords
280
297
out = DataArray (result_var , coords , name = name , fastpath = True )
281
298
282
- if keep_attrs :
283
- if isinstance (out , tuple ):
284
- for da in out :
285
- # This is adding attrs in place
286
- da ._copy_attrs_from (first_obj )
287
- else :
288
- out ._copy_attrs_from (first_obj )
299
+ attrs = merge_attrs ([x .attrs for x in objs ], combine_attrs = keep_attrs )
300
+ if isinstance (out , tuple ):
301
+ for da in out :
302
+ da .attrs = attrs
303
+ else :
304
+ out .attrs = attrs
289
305
290
306
return out
291
307
@@ -400,7 +416,7 @@ def apply_dataset_vfunc(
400
416
dataset_join = "exact" ,
401
417
fill_value = _NO_FILL_VALUE ,
402
418
exclude_dims = frozenset (),
403
- keep_attrs = False ,
419
+ keep_attrs = "override" ,
404
420
):
405
421
"""Apply a variable level function over Dataset, dict of DataArray,
406
422
DataArray, Variable and/or ndarray objects.
@@ -414,15 +430,16 @@ def apply_dataset_vfunc(
414
430
"dataset_fill_value argument."
415
431
)
416
432
417
- if keep_attrs :
418
- first_obj = _first_of_type (args , Dataset )
433
+ objs = _all_of_type (args , Dataset )
419
434
420
435
if len (args ) > 1 :
421
436
args = deep_align (
422
437
args , join = join , copy = False , exclude = exclude_dims , raise_on_invalid = False
423
438
)
424
439
425
- list_of_coords = build_output_coords (args , signature , exclude_dims )
440
+ list_of_coords = build_output_coords (
441
+ args , signature , exclude_dims , combine_attrs = keep_attrs
442
+ )
426
443
args = [getattr (arg , "data_vars" , arg ) for arg in args ]
427
444
428
445
result_vars = apply_dict_of_variables_vfunc (
@@ -435,13 +452,13 @@ def apply_dataset_vfunc(
435
452
(coord_vars ,) = list_of_coords
436
453
out = _fast_dataset (result_vars , coord_vars )
437
454
438
- if keep_attrs :
439
- if isinstance (out , tuple ):
440
- for ds in out :
441
- # This is adding attrs in place
442
- ds . _copy_attrs_from ( first_obj )
443
- else :
444
- out . _copy_attrs_from ( first_obj )
455
+ attrs = merge_attrs ([ x . attrs for x in objs ], combine_attrs = keep_attrs )
456
+ if isinstance (out , tuple ):
457
+ for ds in out :
458
+ ds . attrs = attrs
459
+ else :
460
+ out . attrs = attrs
461
+
445
462
return out
446
463
447
464
@@ -609,14 +626,12 @@ def apply_variable_ufunc(
609
626
dask = "forbidden" ,
610
627
output_dtypes = None ,
611
628
vectorize = False ,
612
- keep_attrs = False ,
629
+ keep_attrs = "override" ,
613
630
dask_gufunc_kwargs = None ,
614
631
):
615
632
"""Apply a ndarray level function over Variable and/or ndarray objects."""
616
633
from .variable import Variable , as_compatible_data
617
634
618
- first_obj = _first_of_type (args , Variable )
619
-
620
635
dim_sizes = unified_dim_sizes (
621
636
(a for a in args if hasattr (a , "dims" )), exclude_dims = exclude_dims
622
637
)
@@ -736,6 +751,12 @@ def func(*arrays):
736
751
)
737
752
)
738
753
754
+ objs = _all_of_type (args , Variable )
755
+ attrs = merge_attrs (
756
+ [obj .attrs for obj in objs ],
757
+ combine_attrs = keep_attrs ,
758
+ )
759
+
739
760
output = []
740
761
for dims , data in zip (output_dims , result_data ):
741
762
data = as_compatible_data (data )
@@ -758,8 +779,7 @@ def func(*arrays):
758
779
)
759
780
)
760
781
761
- if keep_attrs :
762
- var .attrs .update (first_obj .attrs )
782
+ var .attrs = attrs
763
783
output .append (var )
764
784
765
785
if signature .num_outputs == 1 :
@@ -801,7 +821,7 @@ def apply_ufunc(
801
821
join : str = "exact" ,
802
822
dataset_join : str = "exact" ,
803
823
dataset_fill_value : object = _NO_FILL_VALUE ,
804
- keep_attrs : bool = False ,
824
+ keep_attrs : Union [ bool , str ] = None ,
805
825
kwargs : Mapping = None ,
806
826
dask : str = "forbidden" ,
807
827
output_dtypes : Sequence = None ,
@@ -1098,6 +1118,12 @@ def apply_ufunc(
1098
1118
if kwargs :
1099
1119
func = functools .partial (func , ** kwargs )
1100
1120
1121
+ if keep_attrs is None :
1122
+ keep_attrs = _get_keep_attrs (default = False )
1123
+
1124
+ if isinstance (keep_attrs , bool ):
1125
+ keep_attrs = "override" if keep_attrs else "drop"
1126
+
1101
1127
variables_vfunc = functools .partial (
1102
1128
apply_variable_ufunc ,
1103
1129
func ,
0 commit comments