1
- from typing import TYPE_CHECKING , Generic , Hashable , Iterable , Optional , Union
1
+ from typing import TYPE_CHECKING , Generic , Hashable , Iterable , Optional , Union , cast
2
+
3
+ import numpy as np
2
4
3
5
from . import duck_array_ops
4
6
from .computation import dot
35
37
"""
36
38
37
39
_SUM_OF_WEIGHTS_DOCSTRING = """
38
- Calculate the sum of weights, accounting for missing values in the data
40
+ Calculate the sum of weights, accounting for missing values in the data.
39
41
40
42
Parameters
41
43
----------
@@ -177,13 +179,25 @@ def _sum_of_weights(
177
179
178
180
return sum_of_weights .where (valid_weights )
179
181
182
+ def _sum_of_squares (
183
+ self ,
184
+ da : "DataArray" ,
185
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
186
+ skipna : Optional [bool ] = None ,
187
+ ) -> "DataArray" :
188
+ """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
189
+
190
+ demeaned = da - da .weighted (self .weights ).mean (dim = dim )
191
+
192
+ return self ._reduce ((demeaned ** 2 ), self .weights , dim = dim , skipna = skipna )
193
+
180
194
def _weighted_sum (
181
195
self ,
182
196
da : "DataArray" ,
183
197
dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
184
198
skipna : Optional [bool ] = None ,
185
199
) -> "DataArray" :
186
- """Reduce a DataArray by a by a weighted ``sum`` along some dimension(s)."""
200
+ """Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
187
201
188
202
return self ._reduce (da , self .weights , dim = dim , skipna = skipna )
189
203
@@ -201,6 +215,30 @@ def _weighted_mean(
201
215
202
216
return weighted_sum / sum_of_weights
203
217
218
+ def _weighted_var (
219
+ self ,
220
+ da : "DataArray" ,
221
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
222
+ skipna : Optional [bool ] = None ,
223
+ ) -> "DataArray" :
224
+ """Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
225
+
226
+ sum_of_squares = self ._sum_of_squares (da , dim = dim , skipna = skipna )
227
+
228
+ sum_of_weights = self ._sum_of_weights (da , dim = dim )
229
+
230
+ return sum_of_squares / sum_of_weights
231
+
232
+ def _weighted_std (
233
+ self ,
234
+ da : "DataArray" ,
235
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
236
+ skipna : Optional [bool ] = None ,
237
+ ) -> "DataArray" :
238
+ """Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
239
+
240
+ return cast ("DataArray" , np .sqrt (self ._weighted_var (da , dim , skipna )))
241
+
204
242
def _implementation (self , func , dim , ** kwargs ):
205
243
206
244
raise NotImplementedError ("Use `Dataset.weighted` or `DataArray.weighted`" )
@@ -215,6 +253,17 @@ def sum_of_weights(
215
253
self ._sum_of_weights , dim = dim , keep_attrs = keep_attrs
216
254
)
217
255
256
+ def sum_of_squares (
257
+ self ,
258
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
259
+ skipna : Optional [bool ] = None ,
260
+ keep_attrs : Optional [bool ] = None ,
261
+ ) -> T_Xarray :
262
+
263
+ return self ._implementation (
264
+ self ._sum_of_squares , dim = dim , skipna = skipna , keep_attrs = keep_attrs
265
+ )
266
+
218
267
def sum (
219
268
self ,
220
269
dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
@@ -237,6 +286,28 @@ def mean(
237
286
self ._weighted_mean , dim = dim , skipna = skipna , keep_attrs = keep_attrs
238
287
)
239
288
289
+ def var (
290
+ self ,
291
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
292
+ skipna : Optional [bool ] = None ,
293
+ keep_attrs : Optional [bool ] = None ,
294
+ ) -> T_Xarray :
295
+
296
+ return self ._implementation (
297
+ self ._weighted_var , dim = dim , skipna = skipna , keep_attrs = keep_attrs
298
+ )
299
+
300
+ def std (
301
+ self ,
302
+ dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
303
+ skipna : Optional [bool ] = None ,
304
+ keep_attrs : Optional [bool ] = None ,
305
+ ) -> T_Xarray :
306
+
307
+ return self ._implementation (
308
+ self ._weighted_std , dim = dim , skipna = skipna , keep_attrs = keep_attrs
309
+ )
310
+
240
311
def __repr__ (self ):
241
312
"""provide a nice str repr of our Weighted object"""
242
313
@@ -275,6 +346,18 @@ def _inject_docstring(cls, cls_name):
275
346
cls = cls_name , fcn = "mean" , on_zero = "NaN"
276
347
)
277
348
349
+ cls .sum_of_squares .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
350
+ cls = cls_name , fcn = "sum_of_squares" , on_zero = "0"
351
+ )
352
+
353
+ cls .var .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
354
+ cls = cls_name , fcn = "var" , on_zero = "NaN"
355
+ )
356
+
357
+ cls .std .__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE .format (
358
+ cls = cls_name , fcn = "std" , on_zero = "NaN"
359
+ )
360
+
278
361
279
362
_inject_docstring (DataArrayWeighted , "DataArray" )
280
363
_inject_docstring (DatasetWeighted , "Dataset" )
0 commit comments