Skip to content

Commit 213e352

Browse files
authored
Added support for numpy.bool_ (#4986)
1 parent 6ff27ca commit 213e352

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ Deprecations
6565

6666
Bug fixes
6767
~~~~~~~~~
68+
- Added support for ``numpy.bool_`` attributes in roundtrips using the ``h5netcdf`` engine with ``invalid_netcdf=True`` [which casts ``bool`` values to ``numpy.bool_``] (:issue:`4981`, :pull:`4986`).
69+
By `Victor Negîrneac <https://github.com/caenrigen>`_.
6870
- Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`).
6971
By `Justus Magin <https://github.com/keewis>`_.
7072

xarray/backends/api.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,21 @@ def check_name(name):
141141
check_name(k)
142142

143143

144-
def _validate_attrs(dataset):
144+
def _validate_attrs(dataset, invalid_netcdf=False):
145145
"""`attrs` must have a string key and a value which is either: a number,
146-
a string, an ndarray or a list/tuple of numbers/strings.
146+
a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_.
147+
148+
Notes
149+
-----
150+
A numpy.bool_ is only allowed when using the h5netcdf engine with
151+
`invalid_netcdf=True`.
147152
"""
148153

149-
def check_attr(name, value):
154+
valid_types = (str, Number, np.ndarray, np.number, list, tuple)
155+
if invalid_netcdf:
156+
valid_types += (np.bool_,)
157+
158+
def check_attr(name, value, valid_types):
150159
if isinstance(name, str):
151160
if not name:
152161
raise ValueError(
@@ -160,22 +169,21 @@ def check_attr(name, value):
160169
"serialization to netCDF files"
161170
)
162171

163-
if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)):
172+
if not isinstance(value, valid_types):
164173
raise TypeError(
165-
f"Invalid value for attr {name!r}: {value!r} must be a number, "
166-
"a string, an ndarray or a list/tuple of "
167-
"numbers/strings for serialization to netCDF "
168-
"files"
174+
f"Invalid value for attr {name!r}: {value!r}. For serialization to "
175+
"netCDF files, its value must be of one of the following types: "
176+
f"{', '.join([vtype.__name__ for vtype in valid_types])}"
169177
)
170178

171179
# Check attrs on the dataset itself
172180
for k, v in dataset.attrs.items():
173-
check_attr(k, v)
181+
check_attr(k, v, valid_types)
174182

175183
# Check attrs on each variable within the dataset
176184
for variable in dataset.variables.values():
177185
for k, v in variable.attrs.items():
178-
check_attr(k, v)
186+
check_attr(k, v, valid_types)
179187

180188

181189
def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders):
@@ -1019,7 +1027,7 @@ def to_netcdf(
10191027

10201028
# validate Dataset keys, DataArray names, and attr keys/values
10211029
_validate_dataset_names(dataset)
1022-
_validate_attrs(dataset)
1030+
_validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf")
10231031

10241032
try:
10251033
store_open = WRITEABLE_STORES[engine]

xarray/tests/test_backends.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,6 +2541,14 @@ def test_complex(self, invalid_netcdf, warntype, num_warns):
25412541

25422542
assert recorded_num_warns == num_warns
25432543

2544+
def test_numpy_bool_(self):
2545+
# h5netcdf loads booleans as numpy.bool_; this type needs to be supported
2546+
# when writing invalid_netcdf datasets in order to support a roundtrip
2547+
expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})})
2548+
save_kwargs = {"invalid_netcdf": True}
2549+
with self.roundtrip(expected, save_kwargs=save_kwargs) as actual:
2550+
assert_identical(expected, actual)
2551+
25442552
def test_cross_engine_read_write_netcdf4(self):
25452553
# Drop dim3, because its labels include strings. These appear to be
25462554
# not properly read with python-netCDF4, which converts them into

0 commit comments

Comments
 (0)