10
10
from xarray .core .alignment import align , broadcast
11
11
from xarray .core .computation import apply_ufunc , dot
12
12
from xarray .core .pycompat import is_duck_dask_array
13
- from xarray .core .types import Dims , T_Xarray
13
+ from xarray .core .types import Dims , T_DataArray , T_Xarray
14
14
from xarray .util .deprecation_helpers import _deprecate_positional_args
15
15
16
16
# Weighted quantile methods are a subset of the numpy supported quantile methods.
@@ -145,7 +145,7 @@ class Weighted(Generic[T_Xarray]):
145
145
146
146
__slots__ = ("obj" , "weights" )
147
147
148
- def __init__ (self , obj : T_Xarray , weights : DataArray ) -> None :
148
+ def __init__ (self , obj : T_Xarray , weights : T_DataArray ) -> None :
149
149
"""
150
150
Create a Weighted object
151
151
@@ -189,7 +189,7 @@ def _weight_check(w):
189
189
_weight_check (weights .data )
190
190
191
191
self .obj : T_Xarray = obj
192
- self .weights : DataArray = weights
192
+ self .weights : T_DataArray = weights
193
193
194
194
def _check_dim (self , dim : Dims ):
195
195
"""raise an error if any dimension is missing"""
@@ -208,11 +208,11 @@ def _check_dim(self, dim: Dims):
208
208
209
209
@staticmethod
210
210
def _reduce (
211
- da : DataArray ,
212
- weights : DataArray ,
211
+ da : T_DataArray ,
212
+ weights : T_DataArray ,
213
213
dim : Dims = None ,
214
214
skipna : bool | None = None ,
215
- ) -> DataArray :
215
+ ) -> T_DataArray :
216
216
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
217
217
218
218
for internal use only
@@ -230,7 +230,7 @@ def _reduce(
230
230
# DataArray (if `weights` has additional dimensions)
231
231
return dot (da , weights , dim = dim )
232
232
233
- def _sum_of_weights (self , da : DataArray , dim : Dims = None ) -> DataArray :
233
+ def _sum_of_weights (self , da : T_DataArray , dim : Dims = None ) -> T_DataArray :
234
234
"""Calculate the sum of weights, accounting for missing values"""
235
235
236
236
# we need to mask data values that are nan; else the weights are wrong
@@ -255,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
255
255
256
256
def _sum_of_squares (
257
257
self ,
258
- da : DataArray ,
258
+ da : T_DataArray ,
259
259
dim : Dims = None ,
260
260
skipna : bool | None = None ,
261
- ) -> DataArray :
261
+ ) -> T_DataArray :
262
262
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
263
263
264
264
demeaned = da - da .weighted (self .weights ).mean (dim = dim )
@@ -267,20 +267,20 @@ def _sum_of_squares(
267
267
268
268
def _weighted_sum (
269
269
self ,
270
- da : DataArray ,
270
+ da : T_DataArray ,
271
271
dim : Dims = None ,
272
272
skipna : bool | None = None ,
273
- ) -> DataArray :
273
+ ) -> T_DataArray :
274
274
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
275
275
276
276
return self ._reduce (da , self .weights , dim = dim , skipna = skipna )
277
277
278
278
def _weighted_mean (
279
279
self ,
280
- da : DataArray ,
280
+ da : T_DataArray ,
281
281
dim : Dims = None ,
282
282
skipna : bool | None = None ,
283
- ) -> DataArray :
283
+ ) -> T_DataArray :
284
284
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
285
285
286
286
weighted_sum = self ._weighted_sum (da , dim = dim , skipna = skipna )
@@ -291,10 +291,10 @@ def _weighted_mean(
291
291
292
292
def _weighted_var (
293
293
self ,
294
- da : DataArray ,
294
+ da : T_DataArray ,
295
295
dim : Dims = None ,
296
296
skipna : bool | None = None ,
297
- ) -> DataArray :
297
+ ) -> T_DataArray :
298
298
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
299
299
300
300
sum_of_squares = self ._sum_of_squares (da , dim = dim , skipna = skipna )
@@ -305,21 +305,21 @@ def _weighted_var(
305
305
306
306
def _weighted_std (
307
307
self ,
308
- da : DataArray ,
308
+ da : T_DataArray ,
309
309
dim : Dims = None ,
310
310
skipna : bool | None = None ,
311
- ) -> DataArray :
311
+ ) -> T_DataArray :
312
312
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
313
313
314
- return cast ("DataArray " , np .sqrt (self ._weighted_var (da , dim , skipna )))
314
+ return cast ("T_DataArray " , np .sqrt (self ._weighted_var (da , dim , skipna )))
315
315
316
316
def _weighted_quantile (
317
317
self ,
318
- da : DataArray ,
318
+ da : T_DataArray ,
319
319
q : ArrayLike ,
320
320
dim : Dims = None ,
321
321
skipna : bool | None = None ,
322
- ) -> DataArray :
322
+ ) -> T_DataArray :
323
323
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""
324
324
325
325
def _get_h (n : float , q : np .ndarray , method : QUANTILE_METHODS ) -> np .ndarray :
0 commit comments