Skip to content

Commit e571d1c

Browse files
authored
Use T_DataArray in Weighted (#8630)
* Use `T_DataArray` in `Weighted` Allows subtypes. (I had this in my git stash, so commiting it...) * Apply suggestions from code review
1 parent 5bd3d8b commit e571d1c

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

xarray/core/weighted.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from xarray.core.alignment import align, broadcast
1111
from xarray.core.computation import apply_ufunc, dot
1212
from xarray.core.pycompat import is_duck_dask_array
13-
from xarray.core.types import Dims, T_Xarray
13+
from xarray.core.types import Dims, T_DataArray, T_Xarray
1414
from xarray.util.deprecation_helpers import _deprecate_positional_args
1515

1616
# Weighted quantile methods are a subset of the numpy supported quantile methods.
@@ -145,7 +145,7 @@ class Weighted(Generic[T_Xarray]):
145145

146146
__slots__ = ("obj", "weights")
147147

148-
def __init__(self, obj: T_Xarray, weights: DataArray) -> None:
148+
def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:
149149
"""
150150
Create a Weighted object
151151
@@ -189,7 +189,7 @@ def _weight_check(w):
189189
_weight_check(weights.data)
190190

191191
self.obj: T_Xarray = obj
192-
self.weights: DataArray = weights
192+
self.weights: T_DataArray = weights
193193

194194
def _check_dim(self, dim: Dims):
195195
"""raise an error if any dimension is missing"""
@@ -208,11 +208,11 @@ def _check_dim(self, dim: Dims):
208208

209209
@staticmethod
210210
def _reduce(
211-
da: DataArray,
212-
weights: DataArray,
211+
da: T_DataArray,
212+
weights: T_DataArray,
213213
dim: Dims = None,
214214
skipna: bool | None = None,
215-
) -> DataArray:
215+
) -> T_DataArray:
216216
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
217217
218218
for internal use only
@@ -230,7 +230,7 @@ def _reduce(
230230
# DataArray (if `weights` has additional dimensions)
231231
return dot(da, weights, dim=dim)
232232

233-
def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
233+
def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray:
234234
"""Calculate the sum of weights, accounting for missing values"""
235235

236236
# we need to mask data values that are nan; else the weights are wrong
@@ -255,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
255255

256256
def _sum_of_squares(
257257
self,
258-
da: DataArray,
258+
da: T_DataArray,
259259
dim: Dims = None,
260260
skipna: bool | None = None,
261-
) -> DataArray:
261+
) -> T_DataArray:
262262
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
263263

264264
demeaned = da - da.weighted(self.weights).mean(dim=dim)
@@ -267,20 +267,20 @@ def _sum_of_squares(
267267

268268
def _weighted_sum(
269269
self,
270-
da: DataArray,
270+
da: T_DataArray,
271271
dim: Dims = None,
272272
skipna: bool | None = None,
273-
) -> DataArray:
273+
) -> T_DataArray:
274274
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
275275

276276
return self._reduce(da, self.weights, dim=dim, skipna=skipna)
277277

278278
def _weighted_mean(
279279
self,
280-
da: DataArray,
280+
da: T_DataArray,
281281
dim: Dims = None,
282282
skipna: bool | None = None,
283-
) -> DataArray:
283+
) -> T_DataArray:
284284
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
285285

286286
weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna)
@@ -291,10 +291,10 @@ def _weighted_mean(
291291

292292
def _weighted_var(
293293
self,
294-
da: DataArray,
294+
da: T_DataArray,
295295
dim: Dims = None,
296296
skipna: bool | None = None,
297-
) -> DataArray:
297+
) -> T_DataArray:
298298
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
299299

300300
sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)
@@ -305,21 +305,21 @@ def _weighted_var(
305305

306306
def _weighted_std(
307307
self,
308-
da: DataArray,
308+
da: T_DataArray,
309309
dim: Dims = None,
310310
skipna: bool | None = None,
311-
) -> DataArray:
311+
) -> T_DataArray:
312312
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
313313

314-
return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
314+
return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
315315

316316
def _weighted_quantile(
317317
self,
318-
da: DataArray,
318+
da: T_DataArray,
319319
q: ArrayLike,
320320
dim: Dims = None,
321321
skipna: bool | None = None,
322-
) -> DataArray:
322+
) -> T_DataArray:
323323
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""
324324

325325
def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:

0 commit comments

Comments
 (0)