Skip to content

Commit 4e9240a

Browse files
keewismax-sixty
authored andcommitted
add missing pint integration tests (#3508)
* add tests for broadcast_like * add tests for DataArray head / tail / thin * update whats-new.rst
1 parent f14edf3 commit 4e9240a

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

doc/whats-new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Internal Changes
111111
~~~~~~~~~~~~~~~~
112112

113113
- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
114-
(:pull:`3238`, :pull:`3447`) by `Justus Magin <https://github.com/keewis>`_.
114+
(:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin <https://github.com/keewis>`_.
115115

116116
.. note::
117117

xarray/tests/test_units.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,36 @@ def test_comparisons(self, func, variation, unit, dtype):
10451045

10461046
assert expected == result
10471047

1048+
@pytest.mark.xfail(reason="blocked by `where`")
1049+
@pytest.mark.parametrize(
1050+
"unit",
1051+
(
1052+
pytest.param(1, id="no_unit"),
1053+
pytest.param(unit_registry.dimensionless, id="dimensionless"),
1054+
pytest.param(unit_registry.s, id="incompatible_unit"),
1055+
pytest.param(unit_registry.cm, id="compatible_unit"),
1056+
pytest.param(unit_registry.m, id="identical_unit"),
1057+
),
1058+
)
1059+
def test_broadcast_like(self, unit, dtype):
1060+
array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
1061+
array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa
1062+
1063+
x1 = np.arange(2) * unit_registry.m
1064+
x2 = np.arange(2) * unit
1065+
y1 = np.array([0]) * unit_registry.m
1066+
y2 = np.arange(3) * unit
1067+
1068+
arr1 = xr.DataArray(data=array1, coords={"x": x1, "y": y1}, dims=("x", "y"))
1069+
arr2 = xr.DataArray(data=array2, coords={"x": x2, "y": y2}, dims=("x", "y"))
1070+
1071+
expected = attach_units(
1072+
strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1)
1073+
)
1074+
result = arr1.broadcast_like(arr2)
1075+
1076+
assert_equal_with_units(expected, result)
1077+
10481078
@pytest.mark.parametrize(
10491079
"unit",
10501080
(
@@ -1303,6 +1333,49 @@ def test_squeeze(self, shape, dtype):
13031333
np.squeeze(array, axis=index), data_array.squeeze(dim=name)
13041334
)
13051335

1336+
@pytest.mark.xfail(
1337+
reason="indexes strip units and head / tail / thin only support integers"
1338+
)
1339+
@pytest.mark.parametrize(
1340+
"unit,error",
1341+
(
1342+
pytest.param(1, DimensionalityError, id="no_unit"),
1343+
pytest.param(
1344+
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
1345+
),
1346+
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
1347+
pytest.param(unit_registry.cm, None, id="compatible_unit"),
1348+
pytest.param(unit_registry.m, None, id="identical_unit"),
1349+
),
1350+
)
1351+
@pytest.mark.parametrize(
1352+
"func",
1353+
(method("head", x=7, y=3), method("tail", x=7, y=3), method("thin", x=7, y=3)),
1354+
ids=repr,
1355+
)
1356+
def test_head_tail_thin(self, func, unit, error, dtype):
1357+
array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
1358+
1359+
coords = {
1360+
"x": np.arange(10) * unit_registry.m,
1361+
"y": np.arange(5) * unit_registry.m,
1362+
}
1363+
1364+
arr = xr.DataArray(data=array, coords=coords, dims=("x", "y"))
1365+
1366+
kwargs = {name: value * unit for name, value in func.kwargs.items()}
1367+
1368+
if error is not None:
1369+
with pytest.raises(error):
1370+
func(arr, **kwargs)
1371+
1372+
return
1373+
1374+
expected = attach_units(func(strip_units(arr)), extract_units(arr))
1375+
result = func(arr, **kwargs)
1376+
1377+
assert_equal_with_units(expected, result)
1378+
13061379
@pytest.mark.parametrize(
13071380
"unit,error",
13081381
(
@@ -2472,6 +2545,40 @@ def test_comparisons(self, func, variation, unit, dtype):
24722545

24732546
assert expected == result
24742547

2548+
@pytest.mark.xfail(reason="blocked by `where`")
2549+
@pytest.mark.parametrize(
2550+
"unit",
2551+
(
2552+
pytest.param(1, id="no_unit"),
2553+
pytest.param(unit_registry.dimensionless, id="dimensionless"),
2554+
pytest.param(unit_registry.s, id="incompatible_unit"),
2555+
pytest.param(unit_registry.cm, id="compatible_unit"),
2556+
pytest.param(unit_registry.m, id="identical_unit"),
2557+
),
2558+
)
2559+
def test_broadcast_like(self, unit, dtype):
2560+
array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
2561+
array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa
2562+
2563+
x1 = np.arange(2) * unit_registry.m
2564+
x2 = np.arange(2) * unit
2565+
y1 = np.array([0]) * unit_registry.m
2566+
y2 = np.arange(3) * unit
2567+
2568+
ds1 = xr.Dataset(
2569+
data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1}
2570+
)
2571+
ds2 = xr.Dataset(
2572+
data_vars={"a": (("x", "y"), array2)}, coords={"x": x2, "y": y2}
2573+
)
2574+
2575+
expected = attach_units(
2576+
strip_units(ds1).broadcast_like(strip_units(ds2)), extract_units(ds1)
2577+
)
2578+
result = ds1.broadcast_like(ds2)
2579+
2580+
assert_equal_with_units(expected, result)
2581+
24752582
@pytest.mark.parametrize(
24762583
"unit",
24772584
(

0 commit comments

Comments
 (0)