Skip to content

Commit cdaac64

Browse files
aaronspringmax-sixtykeewis
authored
Implement skipna kwarg in xr.quantile (#3844)
* quick fix, no docs, no tests * added tests * docstrings * added whatsnew * Update doc/whats-new.rst Co-Authored-By: Maximilian Roos <[email protected]> * Update doc/whats-new.rst Co-Authored-By: keewis <[email protected]> Co-authored-by: Maximilian Roos <[email protected]> Co-authored-by: keewis <[email protected]>
1 parent 9fbb417 commit cdaac64

File tree

8 files changed

+69
-19
lines changed

8 files changed

+69
-19
lines changed

doc/whats-new.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@ New Features
4545
By `Julia Signell <https://github.com/jsignell>`_.
4646
- :py:meth:`Dataset.where` and :py:meth:`DataArray.where` accept a lambda as a
4747
first argument, which is then called on the input; replicating pandas' behavior.
48-
By `Maximilian Roos <https://github.com/max-sixty>`_
48+
By `Maximilian Roos <https://github.com/max-sixty>`_.
49+
- Implement ``skipna`` in :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile`,
50+
:py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile`
51+
(:issue:`3843`, :pull:`3844`)
52+
By `Aaron Spring <https://github.com/aaronspring>`_.
53+
4954

5055
Bug fixes
5156
~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2939,6 +2939,7 @@ def quantile(
29392939
dim: Union[Hashable, Sequence[Hashable], None] = None,
29402940
interpolation: str = "linear",
29412941
keep_attrs: bool = None,
2942+
skipna: bool = True,
29422943
) -> "DataArray":
29432944
"""Compute the qth quantile of the data along the specified dimension.
29442945
@@ -2966,6 +2967,8 @@ def quantile(
29662967
If True, the dataset's attributes (`attrs`) will be copied from
29672968
the original object to the new one. If False (default), the new
29682969
object will be returned without attributes.
2970+
skipna : bool, optional
2971+
Whether to skip missing values when aggregating.
29692972
29702973
Returns
29712974
-------
@@ -2978,7 +2981,7 @@ def quantile(
29782981
29792982
See Also
29802983
--------
2981-
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile
2984+
numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile
29822985
29832986
Examples
29842987
--------
@@ -3015,7 +3018,11 @@ def quantile(
30153018
"""
30163019

30173020
ds = self._to_temp_dataset().quantile(
3018-
q, dim=dim, keep_attrs=keep_attrs, interpolation=interpolation
3021+
q,
3022+
dim=dim,
3023+
keep_attrs=keep_attrs,
3024+
interpolation=interpolation,
3025+
skipna=skipna,
30193026
)
30203027
return self._from_temp_dataset(ds)
30213028

xarray/core/dataset.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5140,7 +5140,13 @@ def sortby(self, variables, ascending=True):
51405140
return aligned_self.isel(**indices)
51415141

51425142
def quantile(
5143-
self, q, dim=None, interpolation="linear", numeric_only=False, keep_attrs=None
5143+
self,
5144+
q,
5145+
dim=None,
5146+
interpolation="linear",
5147+
numeric_only=False,
5148+
keep_attrs=None,
5149+
skipna=True,
51445150
):
51455151
"""Compute the qth quantile of the data along the specified dimension.
51465152
@@ -5171,6 +5177,8 @@ def quantile(
51715177
object will be returned without attributes.
51725178
numeric_only : bool, optional
51735179
If True, only apply ``func`` to variables with a numeric dtype.
5180+
skipna : bool, optional
5181+
Whether to skip missing values when aggregating.
51745182
51755183
Returns
51765184
-------
@@ -5183,7 +5191,7 @@ def quantile(
51835191
51845192
See Also
51855193
--------
5186-
numpy.nanquantile, pandas.Series.quantile, DataArray.quantile
5194+
numpy.nanquantile, numpy.quantile, pandas.Series.quantile, DataArray.quantile
51875195
51885196
Examples
51895197
--------
@@ -5258,6 +5266,7 @@ def quantile(
52585266
dim=reduce_dims,
52595267
interpolation=interpolation,
52605268
keep_attrs=keep_attrs,
5269+
skipna=skipna,
52615270
)
52625271

52635272
else:

xarray/core/groupby.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,9 @@ def fillna(self, value):
558558
out = ops.fillna(self, value)
559559
return out
560560

561-
def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
561+
def quantile(
562+
self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True
563+
):
562564
"""Compute the qth quantile over each array in the groups and
563565
concatenate them together into a new array.
564566
@@ -582,6 +584,8 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
582584
* higher: ``j``.
583585
* nearest: ``i`` or ``j``, whichever is nearest.
584586
* midpoint: ``(i + j) / 2``.
587+
skipna : bool, optional
588+
Whether to skip missing values when aggregating.
585589
586590
Returns
587591
-------
@@ -595,7 +599,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
595599
596600
See Also
597601
--------
598-
numpy.nanquantile, pandas.Series.quantile, Dataset.quantile,
602+
numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile,
599603
DataArray.quantile
600604
601605
Examples
@@ -656,6 +660,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
656660
dim=dim,
657661
interpolation=interpolation,
658662
keep_attrs=keep_attrs,
663+
skipna=skipna,
659664
)
660665

661666
return out

xarray/core/variable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,9 @@ def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
16781678
"""
16791679
return self.broadcast_equals(other, equiv=equiv)
16801680

1681-
def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
1681+
def quantile(
1682+
self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True
1683+
):
16821684
"""Compute the qth quantile of the data along the specified dimension.
16831685
16841686
Returns the qth quantiles(s) of the array elements.
@@ -1725,6 +1727,8 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
17251727

17261728
from .computation import apply_ufunc
17271729

1730+
_quantile_func = np.nanquantile if skipna else np.quantile
1731+
17281732
if keep_attrs is None:
17291733
keep_attrs = _get_keep_attrs(default=False)
17301734

@@ -1739,7 +1743,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
17391743

17401744
def _wrapper(npa, **kwargs):
17411745
# move quantile axis to end. required for apply_ufunc
1742-
return np.moveaxis(np.nanquantile(npa, **kwargs), 0, -1)
1746+
return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1)
17431747

17441748
axis = np.arange(-1, -1 * len(dim) - 1, -1)
17451749
result = apply_ufunc(

xarray/tests/test_dataarray.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,13 +2368,15 @@ def test_reduce_out(self):
23682368
with pytest.raises(TypeError):
23692369
orig.mean(out=np.ones(orig.shape))
23702370

2371+
@pytest.mark.parametrize("skipna", [True, False])
23712372
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
23722373
@pytest.mark.parametrize(
23732374
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
23742375
)
2375-
def test_quantile(self, q, axis, dim):
2376-
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True)
2377-
expected = np.nanpercentile(self.dv.values, np.array(q) * 100, axis=axis)
2376+
def test_quantile(self, q, axis, dim, skipna):
2377+
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True, skipna=skipna)
2378+
_percentile_func = np.nanpercentile if skipna else np.percentile
2379+
expected = _percentile_func(self.dv.values, np.array(q) * 100, axis=axis)
23782380
np.testing.assert_allclose(actual.values, expected)
23792381
if is_scalar(q):
23802382
assert "quantile" not in actual.dims

xarray/tests/test_dataset.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4697,25 +4697,41 @@ def test_reduce_keepdims(self):
46974697
)
46984698
assert_identical(expected, actual)
46994699

4700+
@pytest.mark.parametrize("skipna", [True, False])
47004701
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
4701-
def test_quantile(self, q):
4702+
def test_quantile(self, q, skipna):
47024703
ds = create_test_data(seed=123)
47034704

47044705
for dim in [None, "dim1", ["dim1"]]:
4705-
ds_quantile = ds.quantile(q, dim=dim)
4706+
ds_quantile = ds.quantile(q, dim=dim, skipna=skipna)
47064707
if is_scalar(q):
47074708
assert "quantile" not in ds_quantile.dims
47084709
else:
47094710
assert "quantile" in ds_quantile.dims
47104711

47114712
for var, dar in ds.data_vars.items():
47124713
assert var in ds_quantile
4713-
assert_identical(ds_quantile[var], dar.quantile(q, dim=dim))
4714+
assert_identical(
4715+
ds_quantile[var], dar.quantile(q, dim=dim, skipna=skipna)
4716+
)
47144717
dim = ["dim1", "dim2"]
4715-
ds_quantile = ds.quantile(q, dim=dim)
4718+
ds_quantile = ds.quantile(q, dim=dim, skipna=skipna)
47164719
assert "dim3" in ds_quantile.dims
47174720
assert all(d not in ds_quantile.dims for d in dim)
47184721

4722+
@pytest.mark.parametrize("skipna", [True, False])
4723+
def test_quantile_skipna(self, skipna):
4724+
q = 0.1
4725+
dim = "time"
4726+
ds = Dataset({"a": ([dim], np.arange(0, 11))})
4727+
ds = ds.where(ds >= 1)
4728+
4729+
result = ds.quantile(q=q, dim=dim, skipna=skipna)
4730+
4731+
value = 1.9 if skipna else np.nan
4732+
expected = Dataset({"a": value}, coords={"quantile": q})
4733+
assert_identical(result, expected)
4734+
47194735
@requires_bottleneck
47204736
def test_rank(self):
47214737
ds = create_test_data(seed=1234)

xarray/tests/test_variable.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,14 +1511,16 @@ def test_reduce(self):
15111511
with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"):
15121512
v.mean(dim="x", allow_lazy=False)
15131513

1514+
@pytest.mark.parametrize("skipna", [True, False])
15141515
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
15151516
@pytest.mark.parametrize(
15161517
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
15171518
)
1518-
def test_quantile(self, q, axis, dim):
1519+
def test_quantile(self, q, axis, dim, skipna):
15191520
v = Variable(["x", "y"], self.d)
1520-
actual = v.quantile(q, dim=dim)
1521-
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
1521+
actual = v.quantile(q, dim=dim, skipna=skipna)
1522+
_percentile_func = np.nanpercentile if skipna else np.percentile
1523+
expected = _percentile_func(self.d, np.array(q) * 100, axis=axis)
15221524
np.testing.assert_allclose(actual.values, expected)
15231525

15241526
@requires_dask

0 commit comments

Comments
 (0)