Skip to content

Commit b3b77f5

Browse files
cjauvin and Illviljan authored
Add var and std to weighted computations (#5870)
Co-authored-by: Illviljan <[email protected]>
1 parent 7b93333 commit b3b77f5

File tree

5 files changed

+290
-8
lines changed

5 files changed

+290
-8
lines changed

doc/api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,12 +779,18 @@ Weighted objects
779779

780780
core.weighted.DataArrayWeighted
781781
core.weighted.DataArrayWeighted.mean
782+
core.weighted.DataArrayWeighted.std
782783
core.weighted.DataArrayWeighted.sum
784+
core.weighted.DataArrayWeighted.sum_of_squares
783785
core.weighted.DataArrayWeighted.sum_of_weights
786+
core.weighted.DataArrayWeighted.var
784787
core.weighted.DatasetWeighted
785788
core.weighted.DatasetWeighted.mean
789+
core.weighted.DatasetWeighted.std
786790
core.weighted.DatasetWeighted.sum
791+
core.weighted.DatasetWeighted.sum_of_squares
787792
core.weighted.DatasetWeighted.sum_of_weights
793+
core.weighted.DatasetWeighted.var
788794

789795

790796
Coarsen objects

doc/user-guide/computation.rst

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ Weighted array reductions
263263

264264
:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
265265
and :py:meth:`Dataset.weighted` array reduction methods. They currently
266-
support weighted ``sum`` and weighted ``mean``.
266+
support weighted ``sum``, ``mean``, ``std`` and ``var``.
267267

268268
.. ipython:: python
269269
@@ -298,13 +298,27 @@ The weighted sum corresponds to:
298298
weighted_sum = (prec * weights).sum()
299299
weighted_sum
300300
301-
and the weighted mean to:
301+
the weighted mean to:
302302

303303
.. ipython:: python
304304
305305
weighted_mean = weighted_sum / weights.sum()
306306
weighted_mean
307307
308+
the weighted variance to:
309+
310+
.. ipython:: python
311+
312+
weighted_var = weighted_prec.sum_of_squares() / weights.sum()
313+
weighted_var
314+
315+
and the weighted standard deviation to:
316+
317+
.. ipython:: python
318+
319+
weighted_std = np.sqrt(weighted_var)
320+
weighted_std
321+
308322
However, the functions also take missing values in the data into account:
309323

310324
.. ipython:: python
@@ -327,7 +341,7 @@ If the weights add up to 0, ``sum`` returns 0:
327341
328342
data.weighted(weights).sum()
329343
330-
and ``mean`` returns ``NaN``:
344+
and ``mean``, ``std`` and ``var`` return ``NaN``:
331345

332346
.. ipython:: python
333347

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ v0.19.1 (unreleased)
2323
2424
New Features
2525
~~~~~~~~~~~~
26+
- Add :py:meth:`var`, :py:meth:`std` and :py:meth:`sum_of_squares` to :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted`.
27+
By `Christian Jauvin <https://github.com/cjauvin>`_.
2628
- Added a :py:func:`get_options` method to xarray's root namespace (:issue:`5698`, :pull:`5716`)
2729
By `Pushkar Kopparla <https://github.com/pkopparla>`_.
2830
- Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`).

xarray/core/weighted.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union
1+
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union, cast
2+
3+
import numpy as np
24

35
from . import duck_array_ops
46
from .computation import dot
@@ -35,7 +37,7 @@
3537
"""
3638

3739
_SUM_OF_WEIGHTS_DOCSTRING = """
38-
Calculate the sum of weights, accounting for missing values in the data
40+
Calculate the sum of weights, accounting for missing values in the data.
3941
4042
Parameters
4143
----------
@@ -177,13 +179,25 @@ def _sum_of_weights(
177179

178180
return sum_of_weights.where(valid_weights)
179181

182+
def _sum_of_squares(
    self,
    da: "DataArray",
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
) -> "DataArray":
    """Weighted sum of squared deviations of ``da`` from its weighted mean.

    The weighted mean is obtained through ``da.weighted(self.weights).mean``
    along ``dim``; the squared deviations from it are then reduced with
    ``self.weights`` along the same ``dim``.

    Note that ``skipna`` is forwarded to the final reduction only; the
    intermediate mean is computed with its default ``skipna`` handling.
    """

    weighted_mean = da.weighted(self.weights).mean(dim=dim)
    squared_deviations = (da - weighted_mean) ** 2

    return self._reduce(squared_deviations, self.weights, dim=dim, skipna=skipna)
193+
180194
def _weighted_sum(
    self,
    da: "DataArray",
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
) -> "DataArray":
    """Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

    # Pure delegation: ``_reduce`` receives the data, this object's weights,
    # and the caller's ``dim`` / ``skipna`` unchanged.
    return self._reduce(da, self.weights, dim=dim, skipna=skipna)
189203

@@ -201,6 +215,30 @@ def _weighted_mean(
201215

202216
return weighted_sum / sum_of_weights
203217

218+
def _weighted_var(
    self,
    da: "DataArray",
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
) -> "DataArray":
    """Weighted variance of ``da`` along ``dim``.

    Computed as the weighted sum of squares (``_sum_of_squares``) divided by
    the sum of weights (``_sum_of_weights``). The divisor is the plain sum
    of weights; no degrees-of-freedom correction is applied here.
    """

    # Operands are evaluated left to right: numerator first, then divisor.
    return self._sum_of_squares(
        da, dim=dim, skipna=skipna
    ) / self._sum_of_weights(da, dim=dim)
231+
232+
def _weighted_std(
    self,
    da: "DataArray",
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
) -> "DataArray":
    """Weighted standard deviation of ``da`` along ``dim``.

    Defined as the square root of the weighted variance returned by
    ``_weighted_var``.
    """

    variance = self._weighted_var(da, dim=dim, skipna=skipna)
    std = np.sqrt(variance)

    # ``cast`` is for the type checker only; it has no runtime effect.
    return cast("DataArray", std)
241+
204242
def _implementation(self, func, dim, **kwargs):
    """Dispatch hook used by the public weighted reductions.

    The public methods in this file (e.g. ``sum_of_squares``, ``var``,
    ``std``) pass their private reducer as ``func`` together with ``dim``
    and the remaining keyword arguments. This version always raises.
    NOTE(review): presumably overridden by the DataArray/Dataset weighted
    classes -- not visible in this view, confirm.
    """

    raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
@@ -215,6 +253,17 @@ def sum_of_weights(
215253
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
216254
)
217255

256+
def sum_of_squares(
    self,
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
    keep_attrs: Optional[bool] = None,
) -> T_Xarray:

    # The user-facing docstring is assigned to ``__doc__`` by
    # ``_inject_docstring``, so none is written inline here.
    reduce_kwargs = dict(dim=dim, skipna=skipna, keep_attrs=keep_attrs)

    return self._implementation(self._sum_of_squares, **reduce_kwargs)
266+
218267
def sum(
219268
self,
220269
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
@@ -237,6 +286,28 @@ def mean(
237286
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
238287
)
239288

289+
def var(
    self,
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
    keep_attrs: Optional[bool] = None,
) -> T_Xarray:

    # The user-facing docstring is assigned to ``__doc__`` by
    # ``_inject_docstring``, so none is written inline here.
    reduce_kwargs = dict(dim=dim, skipna=skipna, keep_attrs=keep_attrs)

    return self._implementation(self._weighted_var, **reduce_kwargs)
299+
300+
def std(
    self,
    dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
    skipna: Optional[bool] = None,
    keep_attrs: Optional[bool] = None,
) -> T_Xarray:

    # The user-facing docstring is assigned to ``__doc__`` by
    # ``_inject_docstring``, so none is written inline here.
    reduce_kwargs = dict(dim=dim, skipna=skipna, keep_attrs=keep_attrs)

    return self._implementation(self._weighted_std, **reduce_kwargs)
310+
240311
def __repr__(self):
241312
"""provide a nice str repr of our Weighted object"""
242313

@@ -275,6 +346,18 @@ def _inject_docstring(cls, cls_name):
275346
cls=cls_name, fcn="mean", on_zero="NaN"
276347
)
277348

349+
cls.sum_of_squares.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
350+
cls=cls_name, fcn="sum_of_squares", on_zero="0"
351+
)
352+
353+
cls.var.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
354+
cls=cls_name, fcn="var", on_zero="NaN"
355+
)
356+
357+
cls.std.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
358+
cls=cls_name, fcn="std", on_zero="NaN"
359+
)
360+
278361

279362
_inject_docstring(DataArrayWeighted, "DataArray")
280363
_inject_docstring(DatasetWeighted, "Dataset")

0 commit comments

Comments
 (0)