|
| 1 | +""" |
| 2 | +Property-based tests for roundtripping between xarray and pandas objects. |
| 3 | +""" |
| 4 | +import pytest |
| 5 | + |
| 6 | +pytest.importorskip("hypothesis") |
| 7 | + |
| 8 | +from functools import partial |
| 9 | +import hypothesis.extra.numpy as npst |
| 10 | +import hypothesis.extra.pandas as pdst |
| 11 | +import hypothesis.strategies as st |
| 12 | +from hypothesis import given |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +import pandas as pd |
| 16 | +import xarray as xr |
| 17 | + |
| 18 | +numeric_dtypes = st.one_of( |
| 19 | + npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() |
| 20 | +) |
| 21 | + |
| 22 | +numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt)) |
| 23 | + |
| 24 | +an_array = npst.arrays( |
| 25 | + dtype=numeric_dtypes, |
| 26 | + shape=npst.array_shapes(max_dims=2), # can only convert 1D/2D to pandas |
| 27 | +) |
| 28 | + |
| 29 | + |
| 30 | +@st.composite |
| 31 | +def datasets_1d_vars(draw): |
| 32 | + """Generate datasets with only 1D variables |
| 33 | +
|
| 34 | + Suitable for converting to pandas dataframes. |
| 35 | + """ |
| 36 | + # Generate an index for the dataset |
| 37 | + idx = draw(pdst.indexes(dtype="u8", min_size=0, max_size=100)) |
| 38 | + |
| 39 | + # Generate 1-3 variables, 1D with the same length as the index |
| 40 | + vars_strategy = st.dictionaries( |
| 41 | + keys=st.text(), |
| 42 | + values=npst.arrays(dtype=numeric_dtypes, shape=len(idx)).map( |
| 43 | + partial(xr.Variable, ("rows",)) |
| 44 | + ), |
| 45 | + min_size=1, |
| 46 | + max_size=3, |
| 47 | + ) |
| 48 | + return xr.Dataset(draw(vars_strategy), coords={"rows": idx}) |
| 49 | + |
| 50 | + |
| 51 | +@given(st.data(), an_array) |
| 52 | +def test_roundtrip_dataarray(data, arr): |
| 53 | + names = data.draw( |
| 54 | + st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( |
| 55 | + tuple |
| 56 | + ) |
| 57 | + ) |
| 58 | + coords = {name: np.arange(n) for (name, n) in zip(names, arr.shape)} |
| 59 | + original = xr.DataArray(arr, dims=names, coords=coords) |
| 60 | + roundtripped = xr.DataArray(original.to_pandas()) |
| 61 | + xr.testing.assert_identical(original, roundtripped) |
| 62 | + |
| 63 | + |
| 64 | +@given(datasets_1d_vars()) |
| 65 | +def test_roundtrip_dataset(dataset): |
| 66 | + df = dataset.to_dataframe() |
| 67 | + assert isinstance(df, pd.DataFrame) |
| 68 | + roundtripped = xr.Dataset(df) |
| 69 | + xr.testing.assert_identical(dataset, roundtripped) |
| 70 | + |
| 71 | + |
| 72 | +@given(numeric_series, st.text()) |
| 73 | +def test_roundtrip_pandas_series(ser, ix_name): |
| 74 | + # Need to name the index, otherwise Xarray calls it 'dim_0'. |
| 75 | + ser.index.name = ix_name |
| 76 | + arr = xr.DataArray(ser) |
| 77 | + roundtripped = arr.to_pandas() |
| 78 | + pd.testing.assert_series_equal(ser, roundtripped) |
| 79 | + xr.testing.assert_identical(arr, roundtripped.to_xarray()) |
| 80 | + |
| 81 | + |
| 82 | +# Dataframes with columns of all the same dtype - for roundtrip to DataArray |
| 83 | +numeric_homogeneous_dataframe = numeric_dtypes.flatmap( |
| 84 | + lambda dt: pdst.data_frames(columns=pdst.columns(["a", "b", "c"], dtype=dt)) |
| 85 | +) |
| 86 | + |
| 87 | + |
| 88 | +@pytest.mark.xfail |
| 89 | +@given(numeric_homogeneous_dataframe) |
| 90 | +def test_roundtrip_pandas_dataframe(df): |
| 91 | + # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'. |
| 92 | + df.index.name = "rows" |
| 93 | + df.columns.name = "cols" |
| 94 | + arr = xr.DataArray(df) |
| 95 | + roundtripped = arr.to_pandas() |
| 96 | + pd.testing.assert_frame_equal(df, roundtripped) |
| 97 | + xr.testing.assert_identical(arr, roundtripped.to_xarray()) |
0 commit comments