Skip to content

Commit 751f76a

Browse files
authored
combine keep_attrs and combine_attrs in apply_ufunc (pydata#5041)
1 parent 1f52ae0 commit 751f76a

File tree

4 files changed

+460
-34
lines changed

4 files changed

+460
-34
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ v0.18.1 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24+
- allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of
25+
:py:func:`apply_ufunc` (:pull:`5041`)
26+
By `Justus Magin <https://github.com/keewis>`_.
2427
- :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes,
2528
such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`).
2629
By `Jimmy Westling <https://github.com/illviljan>`_.

xarray/core/computation.py

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
from . import dtypes, duck_array_ops, utils
2929
from .alignment import align, deep_align
30-
from .merge import merge_coordinates_without_align
31-
from .options import OPTIONS
30+
from .merge import merge_attrs, merge_coordinates_without_align
31+
from .options import OPTIONS, _get_keep_attrs
3232
from .pycompat import is_duck_dask_array
3333
from .utils import is_dict_like
3434
from .variable import Variable
@@ -50,6 +50,11 @@ def _first_of_type(args, kind):
5050
raise ValueError("This should be unreachable.")
5151

5252

53+
def _all_of_type(args, kind):
54+
"""Return all objects of type 'kind'"""
55+
return [arg for arg in args if isinstance(arg, kind)]
56+
57+
5358
class _UFuncSignature:
5459
"""Core dimensions signature for a given function.
5560
@@ -202,7 +207,10 @@ def _get_coords_list(args) -> List["Coordinates"]:
202207

203208

204209
def build_output_coords(
205-
args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset()
210+
args: list,
211+
signature: _UFuncSignature,
212+
exclude_dims: AbstractSet = frozenset(),
213+
combine_attrs: str = "override",
206214
) -> "List[Dict[Any, Variable]]":
207215
"""Build output coordinates for an operation.
208216
@@ -230,7 +238,7 @@ def build_output_coords(
230238
else:
231239
# TODO: save these merged indexes, instead of re-computing them later
232240
merged_vars, unused_indexes = merge_coordinates_without_align(
233-
coords_list, exclude_dims=exclude_dims
241+
coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs
234242
)
235243

236244
output_coords = []
@@ -248,7 +256,12 @@ def build_output_coords(
248256

249257

250258
def apply_dataarray_vfunc(
251-
func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False
259+
func,
260+
*args,
261+
signature,
262+
join="inner",
263+
exclude_dims=frozenset(),
264+
keep_attrs="override",
252265
):
253266
"""Apply a variable level function over DataArray, Variable and/or ndarray
254267
objects.
@@ -260,12 +273,16 @@ def apply_dataarray_vfunc(
260273
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
261274
)
262275

263-
if keep_attrs:
276+
objs = _all_of_type(args, DataArray)
277+
278+
if keep_attrs == "drop":
279+
name = result_name(args)
280+
else:
264281
first_obj = _first_of_type(args, DataArray)
265282
name = first_obj.name
266-
else:
267-
name = result_name(args)
268-
result_coords = build_output_coords(args, signature, exclude_dims)
283+
result_coords = build_output_coords(
284+
args, signature, exclude_dims, combine_attrs=keep_attrs
285+
)
269286

270287
data_vars = [getattr(a, "variable", a) for a in args]
271288
result_var = func(*data_vars)
@@ -279,13 +296,12 @@ def apply_dataarray_vfunc(
279296
(coords,) = result_coords
280297
out = DataArray(result_var, coords, name=name, fastpath=True)
281298

282-
if keep_attrs:
283-
if isinstance(out, tuple):
284-
for da in out:
285-
# This is adding attrs in place
286-
da._copy_attrs_from(first_obj)
287-
else:
288-
out._copy_attrs_from(first_obj)
299+
attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
300+
if isinstance(out, tuple):
301+
for da in out:
302+
da.attrs = attrs
303+
else:
304+
out.attrs = attrs
289305

290306
return out
291307

@@ -400,7 +416,7 @@ def apply_dataset_vfunc(
400416
dataset_join="exact",
401417
fill_value=_NO_FILL_VALUE,
402418
exclude_dims=frozenset(),
403-
keep_attrs=False,
419+
keep_attrs="override",
404420
):
405421
"""Apply a variable level function over Dataset, dict of DataArray,
406422
DataArray, Variable and/or ndarray objects.
@@ -414,15 +430,16 @@ def apply_dataset_vfunc(
414430
"dataset_fill_value argument."
415431
)
416432

417-
if keep_attrs:
418-
first_obj = _first_of_type(args, Dataset)
433+
objs = _all_of_type(args, Dataset)
419434

420435
if len(args) > 1:
421436
args = deep_align(
422437
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
423438
)
424439

425-
list_of_coords = build_output_coords(args, signature, exclude_dims)
440+
list_of_coords = build_output_coords(
441+
args, signature, exclude_dims, combine_attrs=keep_attrs
442+
)
426443
args = [getattr(arg, "data_vars", arg) for arg in args]
427444

428445
result_vars = apply_dict_of_variables_vfunc(
@@ -435,13 +452,13 @@ def apply_dataset_vfunc(
435452
(coord_vars,) = list_of_coords
436453
out = _fast_dataset(result_vars, coord_vars)
437454

438-
if keep_attrs:
439-
if isinstance(out, tuple):
440-
for ds in out:
441-
# This is adding attrs in place
442-
ds._copy_attrs_from(first_obj)
443-
else:
444-
out._copy_attrs_from(first_obj)
455+
attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
456+
if isinstance(out, tuple):
457+
for ds in out:
458+
ds.attrs = attrs
459+
else:
460+
out.attrs = attrs
461+
445462
return out
446463

447464

@@ -609,14 +626,12 @@ def apply_variable_ufunc(
609626
dask="forbidden",
610627
output_dtypes=None,
611628
vectorize=False,
612-
keep_attrs=False,
629+
keep_attrs="override",
613630
dask_gufunc_kwargs=None,
614631
):
615632
"""Apply a ndarray level function over Variable and/or ndarray objects."""
616633
from .variable import Variable, as_compatible_data
617634

618-
first_obj = _first_of_type(args, Variable)
619-
620635
dim_sizes = unified_dim_sizes(
621636
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
622637
)
@@ -736,6 +751,12 @@ def func(*arrays):
736751
)
737752
)
738753

754+
objs = _all_of_type(args, Variable)
755+
attrs = merge_attrs(
756+
[obj.attrs for obj in objs],
757+
combine_attrs=keep_attrs,
758+
)
759+
739760
output = []
740761
for dims, data in zip(output_dims, result_data):
741762
data = as_compatible_data(data)
@@ -758,8 +779,7 @@ def func(*arrays):
758779
)
759780
)
760781

761-
if keep_attrs:
762-
var.attrs.update(first_obj.attrs)
782+
var.attrs = attrs
763783
output.append(var)
764784

765785
if signature.num_outputs == 1:
@@ -801,7 +821,7 @@ def apply_ufunc(
801821
join: str = "exact",
802822
dataset_join: str = "exact",
803823
dataset_fill_value: object = _NO_FILL_VALUE,
804-
keep_attrs: bool = False,
824+
keep_attrs: Union[bool, str] = None,
805825
kwargs: Mapping = None,
806826
dask: str = "forbidden",
807827
output_dtypes: Sequence = None,
@@ -1098,6 +1118,12 @@ def apply_ufunc(
10981118
if kwargs:
10991119
func = functools.partial(func, **kwargs)
11001120

1121+
if keep_attrs is None:
1122+
keep_attrs = _get_keep_attrs(default=False)
1123+
1124+
if isinstance(keep_attrs, bool):
1125+
keep_attrs = "override" if keep_attrs else "drop"
1126+
11011127
variables_vfunc = functools.partial(
11021128
apply_variable_ufunc,
11031129
func,

xarray/core/merge.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def merge_coordinates_without_align(
314314
objects: "List[Coordinates]",
315315
prioritized: Mapping[Hashable, MergeElement] = None,
316316
exclude_dims: AbstractSet = frozenset(),
317+
combine_attrs: str = "override",
317318
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]:
318319
"""Merge variables/indexes from coordinates without automatic alignments.
319320
@@ -335,7 +336,7 @@ def merge_coordinates_without_align(
335336
else:
336337
filtered = collected
337338

338-
return merge_collected(filtered, prioritized)
339+
return merge_collected(filtered, prioritized, combine_attrs=combine_attrs)
339340

340341

341342
def determine_coords(

0 commit comments

Comments
 (0)