Skip to content

Commit 9959405

Browse files
Use encoding['dtype'] over data.dtype when possible within CFMaskCoder.encode (#3652)
* Use encoding['dtype'] over data.dtype when possible * Add what's new entry * Fix typo in what's new Co-authored-by: Deepak Cherian <[email protected]>
1 parent e0fd480 commit 9959405

File tree

3 files changed

+33
-11
lines changed

3 files changed

+33
-11
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ Bug fixes
6969
By `Tom Augspurger <https://github.com/TomAugspurger>`_.
7070
- Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error
7171
when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser <https://github.com/mathause>`_.
72+
- Fix regression in xarray 0.14.1 that prevented encoding times with certain
73+
``dtype``, ``_FillValue``, and ``missing_value`` encodings (:issue:`3624`).
74+
By `Spencer Clark <https://github.com/spencerkclark>`_
7275
- Raise an error when trying to use :py:meth:`Dataset.rename_dims` to
7376
rename to an existing name (:issue:`3438`, :pull:`3645`)
7477
By `Justus Magin <https://github.com/keewis>`_.

xarray/coding/variables.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ class CFMaskCoder(VariableCoder):
148148
def encode(self, variable, name=None):
149149
dims, data, attrs, encoding = unpack_for_encoding(variable)
150150

151+
dtype = np.dtype(encoding.get("dtype", data.dtype))
151152
fv = encoding.get("_FillValue")
152153
mv = encoding.get("missing_value")
153154

@@ -162,14 +163,14 @@ def encode(self, variable, name=None):
162163

163164
if fv is not None:
164165
# Ensure _FillValue is cast to same dtype as data's
165-
encoding["_FillValue"] = data.dtype.type(fv)
166+
encoding["_FillValue"] = dtype.type(fv)
166167
fill_value = pop_to(encoding, attrs, "_FillValue", name=name)
167168
if not pd.isnull(fill_value):
168169
data = duck_array_ops.fillna(data, fill_value)
169170

170171
if mv is not None:
171172
# Ensure missing_value is cast to same dtype as data's
172-
encoding["missing_value"] = data.dtype.type(mv)
173+
encoding["missing_value"] = dtype.type(mv)
173174
fill_value = pop_to(encoding, attrs, "missing_value", name=name)
174175
if not pd.isnull(fill_value) and fv is None:
175176
data = duck_array_ops.fillna(data, fill_value)

xarray/tests/test_coding.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from contextlib import suppress
22

33
import numpy as np
4+
import pandas as pd
45
import pytest
56

67
import xarray as xr
78
from xarray.coding import variables
9+
from xarray.conventions import decode_cf_variable, encode_cf_variable
810

911
from . import assert_equal, assert_identical, requires_dask
1012

@@ -20,20 +22,36 @@ def test_CFMaskCoder_decode():
2022
assert_identical(expected, encoded)
2123

2224

23-
def test_CFMaskCoder_encode_missing_fill_values_conflict():
24-
original = xr.Variable(
25-
("x",),
26-
[0.0, -1.0, 1.0],
27-
encoding={"_FillValue": np.float32(1e20), "missing_value": np.float64(1e20)},
28-
)
29-
coder = variables.CFMaskCoder()
30-
encoded = coder.encode(original)
25+
encoding_with_dtype = {
26+
"dtype": np.dtype("float64"),
27+
"_FillValue": np.float32(1e20),
28+
"missing_value": np.float64(1e20),
29+
}
30+
encoding_without_dtype = {
31+
"_FillValue": np.float32(1e20),
32+
"missing_value": np.float64(1e20),
33+
}
34+
CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS = {
35+
"numeric-with-dtype": ([0.0, -1.0, 1.0], encoding_with_dtype),
36+
"numeric-without-dtype": ([0.0, -1.0, 1.0], encoding_without_dtype),
37+
"times-with-dtype": (pd.date_range("2000", periods=3), encoding_with_dtype),
38+
}
39+
40+
41+
@pytest.mark.parametrize(
42+
("data", "encoding"),
43+
CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(),
44+
ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()),
45+
)
46+
def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding):
47+
original = xr.Variable(("x",), data, encoding=encoding)
48+
encoded = encode_cf_variable(original)
3149

3250
assert encoded.dtype == encoded.attrs["missing_value"].dtype
3351
assert encoded.dtype == encoded.attrs["_FillValue"].dtype
3452

3553
with pytest.warns(variables.SerializationWarning):
36-
roundtripped = coder.decode(coder.encode(original))
54+
roundtripped = decode_cf_variable("foo", encoded)
3755
assert_identical(roundtripped, original)
3856

3957

0 commit comments

Comments
 (0)