
Commit 6b1d97a

Typing for open_dataset/array/mfdataset and to_netcdf/zarr (#6612)
* type filename and chunks
* type open_dataset, open_dataarray, open_mfdataset
* type to_netcdf
* add return doc to Dataset.to_netcdf
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix import error
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* replace tuple[x] by Tuple[x] for py3.8
* fix some merge errors
* add overloads to to_zarr
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix absolute import
* CamelCase type vars
* move some literal types to core.types
* add JoinOptions to core.types
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add some blank lines under bullet lists in docs
* add comments to overloads
* some more typing
* fix absolute import
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Delete mypy.py (whoops, accidental upload)
* fix typo
* fix absolute import
* fix some absolute imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* replace Dict by dict
* fix DataArray import
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix _dataset_concat arg name
* fix DataArray not imported
* remove xr import in Dataset
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* some more typing
* replace some Sequence by Iterable
* fix wrong default in docstring
* fix docstring indentation
* fix overloads and type some tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix open_mfdataset typing
* minor update of docstring
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove unnecessary import
* fix overloads of to_netcdf
* minor docstring update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e712270 commit 6b1d97a

16 files changed: +885, -364 lines

xarray/backends/api.py

Lines changed: 259 additions & 105 deletions
Large diffs are not rendered by default.
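The api.py diff is collapsed here, but per the commit message it adds typed signatures and @overload variants so that the return type of to_netcdf follows its arguments (None when writing to a path, bytes when no target is given). A minimal sketch of that overload pattern; the function below is a simplified, hypothetical stand-in, not the real api.py signatures:

from __future__ import annotations

from os import PathLike
from typing import overload

# Hypothetical, heavily simplified stand-in for Dataset.to_netcdf;
# the real signatures in xarray/backends/api.py take many more arguments.
@overload
def to_netcdf(data: bytes, path: str | PathLike) -> None: ...
@overload
def to_netcdf(data: bytes, path: None = None) -> bytes: ...


def to_netcdf(data: bytes, path: str | PathLike | None = None) -> bytes | None:
    # With a target path, write to disk and return None; without one,
    # hand the serialized bytes back to the caller.
    if path is None:
        return data
    with open(path, "wb") as f:
        f.write(data)
    return None


buf = to_netcdf(b"\x89HDF")       # mypy infers: bytes
to_netcdf(b"\x89HDF", "out.nc")   # mypy infers: None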

xarray/backends/common.py

Lines changed: 8 additions & 6 deletions
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import logging
 import os
 import time
 import traceback
-from typing import Any, Dict, Tuple, Type, Union
+from typing import Any
 
 import numpy as np
 
@@ -369,13 +371,13 @@ class BackendEntrypoint:
     method is not mandatory.
     """
 
-    open_dataset_parameters: Union[Tuple, None] = None
+    open_dataset_parameters: tuple | None = None
    """list of ``open_dataset`` method parameters"""
 
     def open_dataset(
         self,
-        filename_or_obj: str,
-        drop_variables: Tuple[str] = None,
+        filename_or_obj: str | os.PathLike,
+        drop_variables: tuple[str] | None = None,
         **kwargs: Any,
     ):
         """
@@ -384,12 +386,12 @@ def open_dataset(
 
         raise NotImplementedError
 
-    def guess_can_open(self, filename_or_obj):
+    def guess_can_open(self, filename_or_obj: str | os.PathLike):
         """
         Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`.
         """
 
         return False
 
 
-BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {}
+BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {}
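With the base class annotated as above, a third-party backend can mirror the new signatures. A minimal sketch; the class name and the ".my" extension are made up for illustration:

from __future__ import annotations

import os
from typing import Any

from xarray.backends import BackendEntrypoint


class MyBackendEntrypoint(BackendEntrypoint):
    # Mirror the newly typed base-class signatures.
    def open_dataset(
        self,
        filename_or_obj: str | os.PathLike,
        drop_variables: tuple[str] | None = None,
        **kwargs: Any,
    ):
        raise NotImplementedError

    def guess_can_open(self, filename_or_obj: str | os.PathLike):
        # Claim only files with our (made-up) extension.
        return str(filename_or_obj).endswith(".my")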

xarray/core/alignment.py

Lines changed: 5 additions & 3 deletions
@@ -30,6 +30,7 @@
 if TYPE_CHECKING:
     from .dataarray import DataArray
     from .dataset import Dataset
+    from .types import JoinOptions
 
 DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
 
@@ -557,7 +558,7 @@ def align(self) -> None:
 
 def align(
     *objects: DataAlignable,
-    join="inner",
+    join: JoinOptions = "inner",
     copy=True,
     indexes=None,
     exclude=frozenset(),
@@ -590,6 +591,7 @@ def align(
         - "override": if indexes are of same size, rewrite indexes to be
           those of the first object with that dimension. Indexes for the same
           dimension must have the same size in all objects.
+
     copy : bool, optional
         If ``copy=True``, data in the return values is always copied. If
         ``copy=False`` and reindexing is unnecessary, or can be performed with
@@ -764,7 +766,7 @@ def align(
 
 def deep_align(
     objects,
-    join="inner",
+    join: JoinOptions = "inner",
     copy=True,
     indexes=None,
     exclude=frozenset(),
@@ -834,7 +836,7 @@ def is_alignable(obj):
         if key is no_key:
             out[position] = aligned_obj
         else:
-            out[position][key] = aligned_obj
+            out[position][key] = aligned_obj  # type: ignore[index] # maybe someone can fix this?
 
     # something went wrong: we should have replaced all sentinel values
     for arg in out:
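JoinOptions, imported above from .types, is a Literal alias this commit adds to xarray/core/types.py. A minimal sketch of how such an alias constrains call sites, assuming the member list matches the join modes documented for align:

from typing import Literal

# Sketch of the alias added in xarray/core/types.py; the members match
# the join modes documented in the align() docstring.
JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"]


def align_stub(*, join: JoinOptions = "inner") -> None:
    # The annotation alone is enough for mypy; no runtime check needed.
    ...


align_stub(join="outer")   # accepted
align_stub(join="outter")  # flagged by mypy: not a valid JoinOptions literal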

xarray/core/combine.py

Lines changed: 39 additions & 32 deletions
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import itertools
 import warnings
 from collections import Counter
-from typing import Iterable, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Literal, Sequence, Union
 
 import pandas as pd
 
@@ -12,6 +14,9 @@
 from .merge import merge
 from .utils import iterate_nested
 
+if TYPE_CHECKING:
+    from .types import CombineAttrsOptions, CompatOptions, JoinOptions
+
 
 def _infer_concat_order_from_positions(datasets):
     return dict(_infer_tile_ids_from_nested_list(datasets, ()))
@@ -188,10 +193,10 @@ def _combine_nd(
     concat_dims,
     data_vars="all",
     coords="different",
-    compat="no_conflicts",
+    compat: CompatOptions = "no_conflicts",
     fill_value=dtypes.NA,
-    join="outer",
-    combine_attrs="drop",
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "drop",
 ):
     """
     Combines an N-dimensional structure of datasets into one by applying a
@@ -250,10 +255,10 @@ def _combine_all_along_first_dim(
     dim,
     data_vars,
     coords,
-    compat,
+    compat: CompatOptions,
     fill_value=dtypes.NA,
-    join="outer",
-    combine_attrs="drop",
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "drop",
 ):
 
     # Group into lines of datasets which must be combined along dim
@@ -276,12 +281,12 @@ def _combine_all_along_first_dim(
 def _combine_1d(
     datasets,
     concat_dim,
-    compat="no_conflicts",
+    compat: CompatOptions = "no_conflicts",
     data_vars="all",
     coords="different",
     fill_value=dtypes.NA,
-    join="outer",
-    combine_attrs="drop",
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "drop",
 ):
     """
     Applies either concat or merge to 1D list of datasets depending on value
@@ -336,8 +341,8 @@ def _nested_combine(
     coords,
     ids,
     fill_value=dtypes.NA,
-    join="outer",
-    combine_attrs="drop",
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "drop",
 ):
 
     if len(datasets) == 0:
@@ -377,15 +382,13 @@ def _nested_combine(
 
 def combine_nested(
     datasets: DATASET_HYPERCUBE,
-    concat_dim: Union[
-        str, DataArray, None, Sequence[Union[str, "DataArray", pd.Index, None]]
-    ],
+    concat_dim: (str | DataArray | None | Sequence[str | DataArray | pd.Index | None]),
     compat: str = "no_conflicts",
     data_vars: str = "all",
     coords: str = "different",
     fill_value: object = dtypes.NA,
-    join: str = "outer",
-    combine_attrs: str = "drop",
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "drop",
 ) -> Dataset:
     """
     Explicitly combine an N-dimensional grid of datasets into one by using a
@@ -603,9 +606,9 @@ def _combine_single_variable_hypercube(
     fill_value=dtypes.NA,
     data_vars="all",
     coords="different",
-    compat="no_conflicts",
-    join="outer",
-    combine_attrs="no_conflicts",
+    compat: CompatOptions = "no_conflicts",
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "no_conflicts",
 ):
     """
     Attempt to combine a list of Datasets into a hypercube using their
@@ -659,15 +662,15 @@ def _combine_single_variable_hypercube(
 
 # TODO remove empty list default param after version 0.21, see PR4696
 def combine_by_coords(
-    data_objects: Sequence[Union[Dataset, DataArray]] = [],
-    compat: str = "no_conflicts",
-    data_vars: str = "all",
+    data_objects: Iterable[Dataset | DataArray] = [],
+    compat: CompatOptions = "no_conflicts",
+    data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
     coords: str = "different",
     fill_value: object = dtypes.NA,
-    join: str = "outer",
-    combine_attrs: str = "no_conflicts",
-    datasets: Sequence[Dataset] = None,
-) -> Union[Dataset, DataArray]:
+    join: JoinOptions = "outer",
+    combine_attrs: CombineAttrsOptions = "no_conflicts",
+    datasets: Iterable[Dataset] = None,
+) -> Dataset | DataArray:
     """
 
     Attempt to auto-magically combine the given datasets (or data arrays)
@@ -695,7 +698,7 @@ def combine_by_coords(
 
     Parameters
     ----------
-    data_objects : sequence of xarray.Dataset or sequence of xarray.DataArray
+    data_objects : Iterable of Datasets or DataArrays
         Data objects to combine.
 
     compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional
@@ -711,18 +714,19 @@ def combine_by_coords(
           must be equal. The returned dataset then contains the combination
          of all non-null values.
         - "override": skip comparing and pick variable from first dataset
+
     data_vars : {"minimal", "different", "all" or list of str}, optional
         These data variables will be concatenated together:
 
-        * "minimal": Only data variables in which the dimension already
+        - "minimal": Only data variables in which the dimension already
           appears are included.
-        * "different": Data variables which are not equal (ignoring
+        - "different": Data variables which are not equal (ignoring
           attributes) across all datasets are also concatenated (as well as
           all for which dimension already appears). Beware: this option may
           load the data payload of data variables into memory if they are not
           already loaded.
-        * "all": All data variables will be concatenated.
-        * list of str: The listed data variables will be concatenated, in
+        - "all": All data variables will be concatenated.
+        - list of str: The listed data variables will be concatenated, in
           addition to the "minimal" data variables.
 
         If objects are DataArrays, `data_vars` must be "all".
@@ -745,6 +749,7 @@ def combine_by_coords(
         - "override": if indexes are of same size, rewrite indexes to be
           those of the first object with that dimension. Indexes for the same
           dimension must have the same size in all objects.
+
     combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
                      "override"} or callable, default: "drop"
         A callable or a string indicating how to combine attrs of the objects being
@@ -762,6 +767,8 @@ def combine_by_coords(
         If a callable, it must expect a sequence of ``attrs`` dicts and a context object
         as its only parameters.
 
+    datasets : Iterable of Datasets
+
     Returns
     -------
     combined : xarray.Dataset or xarray.DataArray
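At runtime nothing changes; the Literal annotations only let a type checker reject bad strings before execution. A small usage example:

import numpy as np
import xarray as xr

# Two datasets tiling a shared 1-D coordinate.
ds0 = xr.Dataset({"v": ("x", np.arange(3))}, coords={"x": [0, 1, 2]})
ds1 = xr.Dataset({"v": ("x", np.arange(3))}, coords={"x": [3, 4, 5]})

# join and combine_attrs are now Literal-typed, so a typo such as
# join="outter" is rejected by mypy before it ever reaches runtime.
combined = xr.combine_by_coords([ds0, ds1], join="outer", combine_attrs="no_conflicts")
print(combined.sizes)  # Frozen({'x': 6})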

xarray/core/computation.py

Lines changed: 26 additions & 9 deletions
@@ -37,7 +37,7 @@
     from .coordinates import Coordinates
     from .dataarray import DataArray
     from .dataset import Dataset
-    from .types import T_Xarray
+    from .types import CombineAttrsOptions, JoinOptions, T_Xarray
 
 _NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
 _DEFAULT_NAME = utils.ReprObject("<default-name>")
@@ -184,7 +184,7 @@ def _enumerate(dim):
     return str(alt_signature)
 
 
-def result_name(objects: list) -> Any:
+def result_name(objects: Iterable[Any]) -> Any:
     # use the same naming heuristics as pandas:
     # https://github.com/blaze/blaze/issues/458#issuecomment-51936356
     names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
@@ -196,7 +196,7 @@ def result_name(objects: list) -> Any:
     return name
 
 
-def _get_coords_list(args) -> list[Coordinates]:
+def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]:
     coords_list = []
     for arg in args:
         try:
@@ -209,23 +209,39 @@ def _get_coords_list(args) -> list[Coordinates]:
 
 
 def build_output_coords_and_indexes(
-    args: list,
+    args: Iterable[Any],
     signature: _UFuncSignature,
     exclude_dims: AbstractSet = frozenset(),
-    combine_attrs: str = "override",
+    combine_attrs: CombineAttrsOptions = "override",
 ) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]:
     """Build output coordinates and indexes for an operation.
 
     Parameters
     ----------
-    args : list
+    args : Iterable
         List of raw operation arguments. Any valid types for xarray operations
         are OK, e.g., scalars, Variable, DataArray, Dataset.
     signature : _UfuncSignature
         Core dimensions signature for the operation.
     exclude_dims : set, optional
         Dimensions excluded from the operation. Coordinates along these
         dimensions are dropped.
+    combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+                     "override"} or callable, default: "drop"
+        A callable or a string indicating how to combine attrs of the objects being
+        merged:
+
+        - "drop": empty attrs on returned Dataset.
+        - "identical": all attrs must be the same on every object.
+        - "no_conflicts": attrs from all objects are combined, any that have
+          the same name must also have the same value.
+        - "drop_conflicts": attrs from all objects are combined, any that have
+          the same name but different values are dropped.
+        - "override": skip comparing and copy attrs from the first dataset to
+          the result.
+
+        If a callable, it must expect a sequence of ``attrs`` dicts and a context object
+        as its only parameters.
 
     Returns
     -------
@@ -267,10 +283,10 @@ def apply_dataarray_vfunc(
     func,
     *args,
     signature,
-    join="inner",
+    join: JoinOptions = "inner",
     exclude_dims=frozenset(),
     keep_attrs="override",
-):
+) -> tuple[DataArray, ...] | DataArray:
     """Apply a variable level function over DataArray, Variable and/or ndarray
     objects.
     """
@@ -295,6 +311,7 @@ def apply_dataarray_vfunc(
     data_vars = [getattr(a, "variable", a) for a in args]
     result_var = func(*data_vars)
 
+    out: tuple[DataArray, ...] | DataArray
     if signature.num_outputs > 1:
         out = tuple(
             DataArray(
@@ -829,7 +846,7 @@ def apply_ufunc(
     output_core_dims: Sequence[Sequence] | None = ((),),
     exclude_dims: AbstractSet = frozenset(),
     vectorize: bool = False,
-    join: str = "exact",
+    join: JoinOptions = "exact",
     dataset_join: str = "exact",
     dataset_fill_value: object = _NO_FILL_VALUE,
     keep_attrs: bool | str | None = None,