
Commit e712270

{full,zeros,ones}_like typing (#6611)
* type {full,zeros,ones}_like
* fix modern numpy
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* python3.8 support
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix typo
* apply patch from max-sixty
* add link to numpy.typing.DTypeLike

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8de7061 commit e712270

File tree

6 files changed: +185 / -40 lines

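As orientation for the diffs below: the overloads added in xarray/core/common.py let a type checker narrow the return type of full_like (and zeros_like/ones_like) from the type of `other` -- a DataArray in gives a DataArray out, a Dataset in gives a Dataset out, and only Dataset accepts a dict-like dtype. A minimal usage sketch, not part of the commit, assuming an xarray build that includes this change:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"), name="a")
ds = da.to_dataset()

# A type checker infers DataArray here; a dict-like dtype would be rejected.
filled_da = xr.full_like(da, fill_value=1, dtype="float32")

# A type checker infers Dataset here; a mapping of variable name -> dtype is allowed.
filled_ds = xr.full_like(ds, fill_value=1, dtype={"a": "int64"})

print(type(filled_da).__name__, filled_da.dtype)       # DataArray float32
print(type(filled_ds).__name__, filled_ds["a"].dtype)  # Dataset int64
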

xarray/core/common.py

Lines changed: 104 additions & 18 deletions
@@ -13,14 +13,15 @@
     Iterator,
     Mapping,
     TypeVar,
+    Union,
     overload,
 )

 import numpy as np
 import pandas as pd

 from . import dtypes, duck_array_ops, formatting, formatting_html, ops
-from .npcompat import DTypeLike
+from .npcompat import DTypeLike, DTypeLikeSave
 from .options import OPTIONS, _get_keep_attrs
 from .pycompat import is_duck_dask_array
 from .rolling_exp import RollingExp
@@ -1577,26 +1578,45 @@ def __getitem__(self, value):
         raise NotImplementedError()


+DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]]
+
+
 @overload
-def full_like(
-    other: Dataset,
-    fill_value,
-    dtype: DTypeLike | Mapping[Any, DTypeLike] = None,
-) -> Dataset:
+def full_like(other: DataArray, fill_value: Any, dtype: DTypeLikeSave) -> DataArray:
+    ...
+
+
+@overload
+def full_like(other: Dataset, fill_value: Any, dtype: DTypeMaybeMapping) -> Dataset:
     ...


 @overload
-def full_like(other: DataArray, fill_value, dtype: DTypeLike = None) -> DataArray:
+def full_like(other: Variable, fill_value: Any, dtype: DTypeLikeSave) -> Variable:
     ...


 @overload
-def full_like(other: Variable, fill_value, dtype: DTypeLike = None) -> Variable:
+def full_like(
+    other: Dataset | DataArray, fill_value: Any, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray:
     ...


-def full_like(other, fill_value, dtype=None):
+@overload
+def full_like(
+    other: Dataset | DataArray | Variable,
+    fill_value: Any,
+    dtype: DTypeMaybeMapping = None,
+) -> Dataset | DataArray | Variable:
+    ...
+
+
+def full_like(
+    other: Dataset | DataArray | Variable,
+    fill_value: Any,
+    dtype: DTypeMaybeMapping = None,
+) -> Dataset | DataArray | Variable:
     """Return a new object with the same shape and type as a given object.

     Parameters
@@ -1711,26 +1731,26 @@ def full_like(other, fill_value, dtype=None):
             f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead."
         )

-    if not isinstance(other, Dataset) and isinstance(dtype, Mapping):
-        raise ValueError(
-            "'dtype' cannot be dict-like when passing a DataArray or Variable"
-        )
-
     if isinstance(other, Dataset):
         if not isinstance(fill_value, dict):
             fill_value = {k: fill_value for k in other.data_vars.keys()}

+        dtype_: Mapping[Any, DTypeLikeSave]
         if not isinstance(dtype, Mapping):
             dtype_ = {k: dtype for k in other.data_vars.keys()}
         else:
             dtype_ = dtype

         data_vars = {
-            k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype_.get(k, None))
+            k: _full_like_variable(
+                v.variable, fill_value.get(k, dtypes.NA), dtype_.get(k, None)
+            )
             for k, v in other.data_vars.items()
         }
         return Dataset(data_vars, coords=other.coords, attrs=other.attrs)
     elif isinstance(other, DataArray):
+        if isinstance(dtype, Mapping):
+            raise ValueError("'dtype' cannot be dict-like when passing a DataArray")
         return DataArray(
             _full_like_variable(other.variable, fill_value, dtype),
             dims=other.dims,
@@ -1739,12 +1759,16 @@ def full_like(other, fill_value, dtype=None):
             name=other.name,
         )
     elif isinstance(other, Variable):
+        if isinstance(dtype, Mapping):
+            raise ValueError("'dtype' cannot be dict-like when passing a Variable")
         return _full_like_variable(other, fill_value, dtype)
     else:
         raise TypeError("Expected DataArray, Dataset, or Variable")


-def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
+def _full_like_variable(
+    other: Variable, fill_value: Any, dtype: DTypeLike = None
+) -> Variable:
     """Inner function of full_like, where other must be a variable"""
     from .variable import Variable

@@ -1765,7 +1789,38 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
     return Variable(dims=other.dims, data=data, attrs=other.attrs)


-def zeros_like(other, dtype: DTypeLike = None):
+@overload
+def zeros_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray:
+    ...
+
+
+@overload
+def zeros_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset:
+    ...
+
+
+@overload
+def zeros_like(other: Variable, dtype: DTypeLikeSave) -> Variable:
+    ...
+
+
+@overload
+def zeros_like(
+    other: Dataset | DataArray, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray:
+    ...
+
+
+@overload
+def zeros_like(
+    other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray | Variable:
+    ...
+
+
+def zeros_like(
+    other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray | Variable:
     """Return a new object of zeros with the same shape and
     type as a given dataarray or dataset.

@@ -1821,7 +1876,38 @@ def zeros_like(other, dtype: DTypeLike = None):
     return full_like(other, 0, dtype)


-def ones_like(other, dtype: DTypeLike = None):
+@overload
+def ones_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray:
+    ...
+
+
+@overload
+def ones_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset:
+    ...
+
+
+@overload
+def ones_like(other: Variable, dtype: DTypeLikeSave) -> Variable:
+    ...
+
+
+@overload
+def ones_like(
+    other: Dataset | DataArray, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray:
+    ...
+
+
+@overload
+def ones_like(
+    other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray | Variable:
+    ...
+
+
+def ones_like(
+    other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
+) -> Dataset | DataArray | Variable:
     """Return a new object of ones with the same shape and
     type as a given dataarray or dataset.

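The per-type checks above replace the single combined check that was removed: with this change a dict-like dtype is rejected specifically for DataArray and Variable inputs, with a message naming the offending type. A small sketch of the resulting runtime behaviour, assuming this commit is applied:

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros(3), dims="x")
var = xr.Variable(dims="x", data=np.zeros(3))

try:
    xr.full_like(da, 0, dtype={"x": "int32"})  # dict-like dtype is Dataset-only
except ValueError as err:
    print(err)  # 'dtype' cannot be dict-like when passing a DataArray

try:
    xr.full_like(var, 0, dtype={"x": "int32"})
except ValueError as err:
    print(err)  # 'dtype' cannot be dict-like when passing a Variable
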
xarray/core/computation.py

Lines changed: 2 additions & 2 deletions
@@ -1905,7 +1905,7 @@ def polyval(
     coeffs = coeffs.reindex(
         {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False
     )
-    coord = _ensure_numeric(coord)  # type: ignore # https://github.com/python/mypy/issues/1533 ?
+    coord = _ensure_numeric(coord)

     # using Horner's method
     # https://en.wikipedia.org/wiki/Horner%27s_method
@@ -1917,7 +1917,7 @@
     return res


-def _ensure_numeric(data: T_Xarray) -> T_Xarray:
+def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray:
     """Converts all datetime64 variables to float64

     Parameters

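This hunk drops a now unneeded `type: ignore` on the `_ensure_numeric` call and widens its annotation to `Dataset | DataArray`. The surrounding context evaluates the polynomial with Horner's method; a standalone illustration of that scheme in plain NumPy (not xarray's implementation):

import numpy as np

def horner(x: np.ndarray, coeffs: np.ndarray) -> np.ndarray:
    """Evaluate coeffs[0] + coeffs[1]*x + ... + coeffs[n]*x**n with n multiplies.

    coeffs[k] is the coefficient of x**k (lowest degree first).
    """
    res = np.full_like(x, coeffs[-1], dtype=float)
    for c in coeffs[-2::-1]:
        res = res * x + c
    return res

x = np.linspace(0.0, 1.0, 5)
print(horner(x, np.array([1.0, 2.0, 3.0])))                   # 1 + 2x + 3x**2
print(np.polynomial.polynomial.polyval(x, [1.0, 2.0, 3.0]))   # same values
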
xarray/core/missing.py

Lines changed: 15 additions & 7 deletions
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import datetime as dt
 import warnings
 from functools import partial
 from numbers import Number
-from typing import Any, Callable, Dict, Hashable, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence

 import numpy as np
 import pandas as pd
@@ -17,8 +19,14 @@
 from .utils import OrderedSet, is_scalar
 from .variable import Variable, broadcast_variables

+if TYPE_CHECKING:
+    from .dataarray import DataArray
+    from .dataset import Dataset
+

-def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
+def _get_nan_block_lengths(
+    obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable
+):
     """
     Return an object where each NaN element in 'obj' is replaced by the
     length of the gap the element is in.
@@ -48,8 +56,8 @@ def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
 class BaseInterpolator:
     """Generic interpolator class for normalizing interpolation methods"""

-    cons_kwargs: Dict[str, Any]
-    call_kwargs: Dict[str, Any]
+    cons_kwargs: dict[str, Any]
+    call_kwargs: dict[str, Any]
     f: Callable
     method: str

@@ -213,7 +221,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):


 def get_clean_interp_index(
-    arr, dim: Hashable, use_coordinate: Union[str, bool] = True, strict: bool = True
+    arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True
 ):
     """Return index to use for x values in interpolation or curve fitting.

@@ -300,10 +308,10 @@ def get_clean_interp_index(
 def interp_na(
     self,
     dim: Hashable = None,
-    use_coordinate: Union[bool, str] = True,
+    use_coordinate: bool | str = True,
     method: str = "linear",
     limit: int = None,
-    max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None,
+    max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None,
     keep_attrs: bool = None,
     **kwargs,
 ):

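The missing.py hunks only modernize the annotations (PEP 604 unions enabled by `from __future__ import annotations`); runtime behaviour is unchanged. For orientation, the `max_gap` union mirrors what the public `interpolate_na` method accepts. A hedged usage sketch, not taken from the commit:

import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", periods=8, freq="D")
da = xr.DataArray(
    [0.0, np.nan, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0],
    dims="time",
    coords={"time": time},
)

# max_gap may be a number, a timedelta-like string, or one of the timedelta types
# in the union above; gaps wider than max_gap along 'time' are left as NaN.
filled = da.interpolate_na(dim="time", method="linear", max_gap=np.timedelta64(2, "D"))
print(filled.values)
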
xarray/core/npcompat.py

Lines changed: 52 additions & 2 deletions
@@ -28,14 +28,49 @@
 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    List,
+    Literal,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 import numpy as np
 from packaging.version import Version

 # Type annotations stubs
 try:
     from numpy.typing import ArrayLike, DTypeLike
+    from numpy.typing._dtype_like import _DTypeLikeNested, _ShapeLike, _SupportsDType
+
+    # Xarray requires a Mapping[Hashable, dtype] in many places which
+    # conflicts with numpy's own DTypeLike (with dtypes for fields).
+    # https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike
+    # This is a copy of this DTypeLike that allows only non-Mapping dtypes.
+    DTypeLikeSave = Union[
+        np.dtype,
+        # default data type (float64)
+        None,
+        # array-scalar types and generic types
+        Type[Any],
+        # character codes, type strings or comma-separated fields, e.g., 'float64'
+        str,
+        # (flexible_dtype, itemsize)
+        Tuple[_DTypeLikeNested, int],
+        # (fixed_dtype, shape)
+        Tuple[_DTypeLikeNested, _ShapeLike],
+        # (base_dtype, new_dtype)
+        Tuple[_DTypeLikeNested, _DTypeLikeNested],
+        # because numpy does the same?
+        List[Any],
+        # anything with a dtype attribute
+        _SupportsDType[np.dtype],
+    ]
 except ImportError:
     # fall back for numpy < 1.20, ArrayLike adapted from numpy.typing._array_like
     from typing import Protocol
@@ -46,8 +81,14 @@ class _SupportsArray(Protocol):
             def __array__(self) -> np.ndarray:
                 ...

+        class _SupportsDTypeFallback(Protocol):
+            @property
+            def dtype(self) -> np.dtype:
+                ...
+
     else:
         _SupportsArray = Any
+        _SupportsDTypeFallback = Any

     _T = TypeVar("_T")
     _NestedSequence = Union[
@@ -72,7 +113,16 @@ def __array__(self) -> np.ndarray:
     # with the same name (ArrayLike and DTypeLike from the try block)
     ArrayLike = _ArrayLikeFallback  # type: ignore
     # fall back for numpy < 1.20
-    DTypeLike = Union[np.dtype, str]  # type: ignore[misc]
+    DTypeLikeSave = Union[  # type: ignore[misc]
+        np.dtype,
+        str,
+        None,
+        Type[Any],
+        Tuple[Any, Any],
+        List[Any],
+        _SupportsDTypeFallback,
+    ]
+    DTypeLike = DTypeLikeSave  # type: ignore[misc]


 if Version(np.__version__) >= Version("1.20.0"):

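The reason for `DTypeLikeSave`, per the comment in the hunk above: numpy's own `DTypeLike` also covers dict-based structured-dtype specifications, so a plain `Mapping` argument would be ambiguous between "one structured dtype" and xarray's "one dtype per variable". A small sketch of the two readings a dict can have, assuming NumPy >= 1.20 and an xarray build with this commit:

import numpy as np
import xarray as xr

# To numpy, a dict is a structured-dtype (fields) specification ...
structured = np.dtype({"names": ["a", "b"], "formats": ["f4", "i8"]})
print(structured)  # e.g. dtype([('a', '<f4'), ('b', '<i8')])

# ... while for full_like on a Dataset, a dict maps variable names to dtypes.
ds = xr.Dataset({"a": ("x", np.zeros(3)), "b": ("x", np.zeros(3))})
filled = xr.full_like(ds, 0, dtype={"a": "f4", "b": "i8"})
print(filled["a"].dtype, filled["b"].dtype)  # float32 int64
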