
Commit 79b0054

committed: fix tests with, finalize typing
1 parent 9909074 commit 79b0054


xarray/tests/test_concat.py

Lines changed: 113 additions & 92 deletions
@@ -2,7 +2,7 @@
 
 import random
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable
 
 import numpy as np
 import pandas as pd
@@ -230,13 +230,12 @@ def test_concat_missing_multiple_consecutive_var() -> None:
             "day": ["day1", "day2", "day3", "day4", "day5", "day6"],
         },
     )
-    # assign here, as adding above gave switched pressure/humidity-order every once in a while
-    ds_result = ds_result.assign({"humidity": (["x", "y", "day"], humidity_result)})
-    ds_result = ds_result.assign({"pressure": (["x", "y", "day"], pressure_result)})
     result = concat(datasets, dim="day")
     r1 = [var for var in result.data_vars]
     r2 = [var for var in ds_result.data_vars]
-    assert r1 == r2  # check the variables orders are the same
+    # check the variables orders are the same for the first three variables
+    assert r1[:3] == r2[:3]
+    assert set(r1[3:]) == set(r2[3:])  # just check availability for the remaining vars
     assert_equal(result, ds_result)
 
 
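Note on the relaxed assertion above: once the input datasets are missing different variables, only the variables that appear in the first dataset have a deterministic position in the concat result, so the tail is checked by membership only. A minimal standalone sketch (values are illustrative, not from this commit):

r1 = ["temperature", "pressure", "humidity", "precipitation", "cloud cover"]
r2 = ["temperature", "pressure", "humidity", "cloud cover", "precipitation"]
assert r1[:3] == r2[:3]            # leading variables keep a stable order
assert set(r1[3:]) == set(r2[3:])  # the rest are checked for presence only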
@@ -301,56 +300,60 @@ def test_multiple_missing_variables() -> None:
     assert_equal(result, ds_result)
 
 
-@pytest.mark.xfail(strict=True)
-def test_concat_multiple_datasets_missing_vars_and_new_dim() -> None:
+@pytest.mark.parametrize("include_day", [True, False])
+def test_concat_multiple_datasets_missing_vars_and_new_dim(include_day: bool) -> None:
     vars_to_drop = [
         "temperature",
         "pressure",
         "humidity",
         "precipitation",
         "cloud cover",
     ]
-    datasets = create_concat_datasets(len(vars_to_drop), 123, include_day=False)
+
+    datasets = create_concat_datasets(len(vars_to_drop), 123, include_day=include_day)
     # set up the test data
     datasets = [datasets[i].drop_vars(vars_to_drop[i]) for i in range(len(datasets))]
 
+    dim_size = 2 if include_day else 1
+
     # set up the validation data
     # the below code just drops one var per dataset depending on the location of the
     # dataset in the list and allows us to quickly catch any boundaries cases across
     # the three equivalence classes of beginning, middle and end of the concat list
-    result_vars = dict.fromkeys(vars_to_drop)
+    result_vars = dict.fromkeys(vars_to_drop, np.array([]))
     for i in range(len(vars_to_drop)):
         for d in range(len(datasets)):
             if d != i:
-                if result_vars[vars_to_drop[i]] is None:
-                    result_vars[vars_to_drop[i]] = datasets[d][vars_to_drop[i]].values
+                if include_day:
+                    ds_vals = datasets[d][vars_to_drop[i]].values
+                else:
+                    ds_vals = datasets[d][vars_to_drop[i]].values[..., None]
+                if not result_vars[vars_to_drop[i]].size:
+                    result_vars[vars_to_drop[i]] = ds_vals
                 else:
                     result_vars[vars_to_drop[i]] = np.concatenate(
                         (
                             result_vars[vars_to_drop[i]],
-                            datasets[d][vars_to_drop[i]].values,
+                            ds_vals,
                         ),
-                        axis=1,
+                        axis=-1,
                     )
             else:
-                if result_vars[vars_to_drop[i]] is None:
-                    result_vars[vars_to_drop[i]] = np.full([1, 4], np.nan)
+                if not result_vars[vars_to_drop[i]].size:
+                    result_vars[vars_to_drop[i]] = np.full([1, 4, dim_size], np.nan)
                 else:
                     result_vars[vars_to_drop[i]] = np.concatenate(
-                        (result_vars[vars_to_drop[i]], np.full([1, 4], np.nan)),
-                        axis=1,
+                        (
+                            result_vars[vars_to_drop[i]],
+                            np.full([1, 4, dim_size], np.nan),
+                        ),
+                        axis=-1,
                     )
-    # TODO: this test still has two unexpected errors:
-
-    # 1: concat throws a mergeerror expecting the temperature values to be the same, this doesn't seem to be correct in this case
-    # as we are concating on new dims
-    # 2: if the values are the same for a variable (working around #1) then it will likely not correct add the new dim to the first variable
-    # the resulting set
 
     ds_result = Dataset(
         data_vars={
-            # pressure will be first in this since the first dataset is missing this var
-            # and there isn't a good way to determine that this should be first
+            # pressure will be first here since it is first in first dataset and
+            # there isn't a good way to determine that temperature should be first
            # this also means temperature will be last as the first data vars will
             # determine the order for all that exist in that dataset
             "pressure": (["x", "y", "day"], result_vars["pressure"]),
@@ -362,11 +365,17 @@ def test_concat_multiple_datasets_missing_vars_and_new_dim() -> None:
         coords={
             "lat": (["x", "y"], datasets[0].lat.values),
             "lon": (["x", "y"], datasets[0].lon.values),
-            # "day": ["day" + str(d + 1) for d in range(2 * len(vars_to_drop))],
         },
     )
+    if include_day:
+        ds_result = ds_result.assign_coords(
+            {"day": ["day" + str(d + 1) for d in range(2 * len(vars_to_drop))]}
+        )
+    else:
+        ds_result = ds_result.transpose("day", "x", "y")
 
     result = concat(datasets, dim="day")
+
     r1 = list(result.data_vars.keys())
     r2 = list(ds_result.data_vars.keys())
     assert r1 == r2  # check the variables orders are the same
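The transpose in the include_day=False branch accounts for concat creating the requested dimension when it does not already exist: xarray places such a new dimension first. A standalone sketch (not from the commit):

import numpy as np
from xarray import Dataset, concat

ds = Dataset({"t": (["x", "y"], np.zeros((1, 4)))})
out = concat([ds, ds], dim="day")          # "day" is new to both inputs
assert out["t"].dims == ("day", "x", "y")  # the new dimension leads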
@@ -390,11 +399,11 @@ def test_multiple_datasets_with_missing_variables() -> None:
     # the below code just drops one var per dataset depending on the location of the
     # dataset in the list and allows us to quickly catch any boundaries cases across
     # the three equivalence classes of beginning, middle and end of the concat list
-    result_vars = dict.fromkeys(vars_to_drop)
+    result_vars = dict.fromkeys(vars_to_drop, np.array([]))
     for i in range(len(vars_to_drop)):
         for d in range(len(datasets)):
             if d != i:
-                if result_vars[vars_to_drop[i]] is None:
+                if not result_vars[vars_to_drop[i]].size:
                     result_vars[vars_to_drop[i]] = datasets[d][vars_to_drop[i]].values
                 else:
                     result_vars[vars_to_drop[i]] = np.concatenate(
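One subtlety in the dict.fromkeys(vars_to_drop, np.array([])) initialization above: every key starts out bound to the same empty-array object. That is safe here because the loop only rebinds result_vars[key], never mutates the shared array in place, and an empty array has size 0, which is what the new truthiness check relies on. An illustrative sketch (names are hypothetical):

import numpy as np

result_vars = dict.fromkeys(["a", "b"], np.array([]))
assert result_vars["a"] is result_vars["b"]    # one shared sentinel object
assert not result_vars["a"].size               # size == 0 replaces the old `is None` test
result_vars["a"] = np.full([1, 4, 2], np.nan)  # rebinding "a" leaves "b" untouched
assert not result_vars["b"].size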
@@ -405,7 +414,7 @@ def test_multiple_datasets_with_missing_variables() -> None:
                         axis=2,
                     )
             else:
-                if result_vars[vars_to_drop[i]] is None:
+                if not result_vars[vars_to_drop[i]].size:
                     result_vars[vars_to_drop[i]] = np.full([1, 4, 2], np.nan)
                 else:
                     result_vars[vars_to_drop[i]] = np.concatenate(
@@ -483,8 +492,9 @@ def test_multiple_datasets_with_multiple_missing_variables() -> None:
 
     r1 = list(result.data_vars.keys())
     r2 = list(ds_result.data_vars.keys())
-    assert r1 == r2  # check the variables orders are the same
-
+    # check the variables orders are the same for the first three variables
+    assert r1[:3] == r2[:3]
+    assert set(r1[3:]) == set(r2[3:])  # just check availability for the remaining vars
     assert_equal(result, ds_result)
 
 
@@ -581,7 +591,7 @@ def test_type_of_missing_fill() -> None:
 
 
 def test_order_when_filling_missing() -> None:
-    vars_to_drop_in_first = []
+    vars_to_drop_in_first: list[str] = []
     # drop middle
     vars_to_drop_in_second = ["humidity"]
     datasets = create_concat_datasets(2, 123)
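The annotation added above is the usual fix for type-checking an empty literal: a bare `[]` gives mypy no element type to infer, typically producing a "need type annotation" error. Illustrative only:

vars_to_drop_in_first: list[str] = []     # bare `= []` leaves the element type unknown to mypy
vars_to_drop_in_first.append("humidity")  # now checked as list[str]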
@@ -649,6 +659,77 @@ def test_order_when_filling_missing() -> None:
         result_index += 1
 
 
+@pytest.fixture
+def concat_var_names() -> Callable:
+    # create var names list with one missing value
+    def get_varnames(var_cnt: int = 10, list_cnt: int = 10) -> list[list[str]]:
+        orig = [f"d{i:02d}" for i in range(var_cnt)]
+        var_names = []
+        for i in range(0, list_cnt):
+            l1 = orig.copy()
+            var_names.append(l1)
+        return var_names
+
+    return get_varnames
+
+
+@pytest.fixture
+def create_concat_ds() -> Callable:
+    def create_ds(
+        var_names: list[list[str]],
+        dim: bool = False,
+        coord: bool = False,
+        drop_idx: list[int] | None = None,
+    ) -> list[Dataset]:
+        out_ds = []
+        ds = Dataset()
+        ds = ds.assign_coords({"x": np.arange(2)})
+        ds = ds.assign_coords({"y": np.arange(3)})
+        ds = ds.assign_coords({"z": np.arange(4)})
+        for i, dsl in enumerate(var_names):
+            vlist = dsl.copy()
+            if drop_idx is not None:
+                vlist.pop(drop_idx[i])
+            foo_data = np.arange(48, dtype=float).reshape(2, 2, 3, 4)
+            dsi = ds.copy()
+            if coord:
+                dsi = ds.assign({"time": (["time"], [i * 2, i * 2 + 1])})
+            for k in vlist:
+                dsi = dsi.assign({k: (["time", "x", "y", "z"], foo_data.copy())})
+            if not dim:
+                dsi = dsi.isel(time=0)
+            out_ds.append(dsi)
+        return out_ds
+
+    return create_ds
+
+
+@pytest.mark.parametrize("dim", [True, False])
+@pytest.mark.parametrize("coord", [True, False])
+def test_concat_fill_missing_variables(
+    concat_var_names, create_concat_ds, dim: bool, coord: bool
+) -> None:
+    var_names = concat_var_names()
+
+    random.seed(42)
+    drop_idx = [random.randrange(len(vlist)) for vlist in var_names]
+    expected = concat(
+        create_concat_ds(var_names, dim=dim, coord=coord), dim="time", data_vars="all"
+    )
+    for i, idx in enumerate(drop_idx):
+        if dim:
+            expected[var_names[0][idx]][i * 2 : i * 2 + 2] = np.nan
+        else:
+            expected[var_names[0][idx]][i] = np.nan
+
+    concat_ds = create_concat_ds(var_names, dim=dim, coord=coord, drop_idx=drop_idx)
+    actual = concat(concat_ds, dim="time", data_vars="all")
+
+    for name in var_names[0]:
+        assert_equal(expected[name], actual[name])
+    assert_equal(expected, actual)
+
+
 class TestConcatDataset:
     @pytest.fixture
     def data(self) -> Dataset:
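The two fixtures added above follow pytest's factory-as-fixture pattern: each fixture returns a callable, so every test can parametrize the data it builds. A hypothetical usage example (not part of the commit) under the same fixture names:

def test_created_ds_shapes(concat_var_names, create_concat_ds) -> None:
    var_names = concat_var_names(var_cnt=3, list_cnt=2)  # two lists of d00..d02
    ds_list = create_concat_ds(var_names, dim=True, coord=True)
    assert len(ds_list) == 2
    assert ds_list[0]["d00"].dims == ("time", "x", "y", "z")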
@@ -1168,66 +1249,6 @@ def test_concat_str_dtype(self, dtype, dim) -> None:
 
         assert np.issubdtype(actual.x2.dtype, dtype)
 
-    @pytest.mark.parametrize("dim", [True, False])
-    @pytest.mark.parametrize("coord", [True, False])
-    def test_concat_fill_missing_variables(self, dim: bool, coord: bool) -> None:
-        # create var names list with one missing value
-        def get_var_names(var_cnt: int = 10, list_cnt: int = 10) -> list[list[str]]:
-            orig = [f"d{i:02d}" for i in range(var_cnt)]
-            var_names = []
-            for i in range(0, list_cnt):
-                l1 = orig.copy()
-                var_names.append(l1)
-            return var_names
-
-        def create_ds(
-            var_names: list[list[str]],
-            dim: bool = False,
-            coord: bool = False,
-            drop_idx: list[int] | None = None,
-        ) -> list[Dataset]:
-            out_ds = []
-            ds = Dataset()
-            ds = ds.assign_coords({"x": np.arange(2)})
-            ds = ds.assign_coords({"y": np.arange(3)})
-            ds = ds.assign_coords({"z": np.arange(4)})
-            for i, dsl in enumerate(var_names):
-                vlist = dsl.copy()
-                if drop_idx is not None:
-                    vlist.pop(drop_idx[i])
-                foo_data = np.arange(48, dtype=float).reshape(2, 2, 3, 4)
-                dsi = ds.copy()
-                if coord:
-                    dsi = ds.assign({"time": (["time"], [i * 2, i * 2 + 1])})
-                for k in vlist:
-                    dsi = dsi.assign({k: (["time", "x", "y", "z"], foo_data.copy())})
-                if not dim:
-                    dsi = dsi.isel(time=0)
-                out_ds.append(dsi)
-            return out_ds
-
-        var_names = get_var_names()
-
-        import random
-
-        random.seed(42)
-        drop_idx = [random.randrange(len(vlist)) for vlist in var_names]
-        expected = concat(
-            create_ds(var_names, dim=dim, coord=coord), dim="time", data_vars="all"
-        )
-        for i, idx in enumerate(drop_idx):
-            if dim:
-                expected[var_names[0][idx]][i * 2 : i * 2 + 2] = np.nan
-            else:
-                expected[var_names[0][idx]][i] = np.nan
-
-        concat_ds = create_ds(var_names, dim=dim, coord=coord, drop_idx=drop_idx)
-        actual = concat(concat_ds, dim="time", data_vars="all")
-
-        for name in var_names[0]:
-            assert_equal(expected[name], actual[name])
-        assert_equal(expected, actual)
-
 
 class TestConcatDataArray:
     def test_concat(self) -> None:
