Skip to content

Commit a9d7ac5

Browse files
committed
ENH: fill missing variables during concat by reindexing
1 parent 41fef6f commit a9d7ac5

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

xarray/core/concat.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,9 @@ def process_subset_opt(opt, subset):
378378

379379
elif opt == "all":
380380
concat_over.update(
381-
set(getattr(datasets[0], subset)) - set(datasets[0].dims)
381+
set().union(
382+
*list((set(getattr(d, subset)) - set(d.dims) for d in datasets))
383+
)
382384
)
383385
elif opt == "minimal":
384386
pass
@@ -553,16 +555,35 @@ def get_indexes(name):
553555
data = var.set_dims(dim).values
554556
yield PandasIndex(data, dim, coord_dtype=var.dtype)
555557

558+
# preserve variable order for variables in first dataset
559+
data_var_order = list(datasets[0].variables)
560+
# append additional variables to the end
561+
data_var_order += [e for e in data_names if e not in data_var_order]
562+
# create concatenation index, needed for later reindexing
563+
concat_index = list(range(sum(concat_dim_lengths)))
564+
556565
# stack up each variable and/or index to fill-out the dataset (in order)
557566
# n.b. this loop preserves variable order, needed for groupby.
558-
for name in datasets[0].variables:
567+
for name in data_var_order:
559568
if name in concat_over and name not in result_indexes:
560-
try:
561-
vars = ensure_common_dims([ds[name].variable for ds in datasets])
562-
except KeyError:
563-
raise ValueError(f"{name!r} is not present in all datasets.")
564-
565-
# Try concatenate the indexes, concatenate the variables when no index
569+
variables = []
570+
variable_index = []
571+
for i, ds in enumerate(datasets):
572+
if name in ds.variables:
573+
variables.append(ds.variables[name])
574+
# add to variable index, needed for reindexing
575+
variable_index.extend(
576+
[sum(concat_dim_lengths[:i]) + k for k in range(concat_dim_lengths[i])]
577+
)
578+
else:
579+
# raise if coordinate not in all datasets
580+
if name in coord_names:
581+
raise ValueError(
582+
f"coordinate {name!r} not present in all datasets."
583+
)
584+
vars = list(ensure_common_dims(variables))
585+
586+
# Try to concatenate the indexes, concatenate the variables when no index
566587
# is found on all datasets.
567588
indexes: list[Index] = list(get_indexes(name))
568589
if indexes:
@@ -586,9 +607,28 @@ def get_indexes(name):
586607
)
587608
result_vars[k] = v
588609
else:
589-
combined_var = concat_vars(
590-
vars, dim, positions, combine_attrs=combine_attrs
591-
)
610+
# if variable is only present in one dataset of multiple datasets,
611+
# then do not concat
612+
if len(variables) == 1 and len(datasets) > 1:
613+
combined_var = variables[0]
614+
# only concat if variable is in multiple datasets
615+
# or if single dataset (GH1988)
616+
else:
617+
combined_var = concat_vars(
618+
vars, dim, positions, combine_attrs=combine_attrs
619+
)
620+
# reindex if variable is not present in all datasets
621+
if len(variable_index) < len(concat_index):
622+
try:
623+
fill = fill_value[name]
624+
except (TypeError, KeyError):
625+
fill = fill_value
626+
combined_var = (
627+
DataArray(data=combined_var, name=name)
628+
.assign_coords({dim: variable_index})
629+
.reindex({dim: concat_index}, fill_value=fill)
630+
.variable
631+
)
592632
result_vars[name] = combined_var
593633

594634
elif name in result_vars:

xarray/tests/test_concat.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ def test_concat_compat() -> None:
5050
ValueError, match=r"coordinates in some datasets but not others"
5151
):
5252
concat([ds1, ds2], dim="q")
53-
with pytest.raises(ValueError, match=r"'q' is not present in all datasets"):
54-
concat([ds2, ds1], dim="q")
5553

5654

5755
class TestConcatDataset:
@@ -776,15 +774,15 @@ def test_concat_merge_single_non_dim_coord():
776774
actual = concat([da1, da2], "x", coords=coords)
777775
assert_identical(actual, expected)
778776

779-
with pytest.raises(ValueError, match=r"'y' is not present in all datasets."):
777+
with pytest.raises(ValueError, match=r"'y' not present in all datasets."):
780778
concat([da1, da2], dim="x", coords="all")
781779

782780
da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1})
783781
da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]})
784782
da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1})
785783
for coords in ["different", "all"]:
786784
with pytest.raises(ValueError, match=r"'y' not present in all datasets"):
787-
concat([da1, da2, da3], dim="x")
785+
concat([da1, da2, da3], dim="x", coords=coords)
788786

789787

790788
def test_concat_preserve_coordinate_order() -> None:

0 commit comments

Comments
 (0)