Skip to content

Commit 9fbb417

Browse files
max-sixtykeewis
andauthored
Allow where to receive a callable (#3827)
* allow where to receive a callable * Update xarray/core/common.py Co-Authored-By: keewis <[email protected]> * docstring * whatsnew Co-authored-by: keewis <[email protected]>
1 parent 00e5b36 commit 9fbb417

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ New Features
4343
in 0.14.1) is now on by default. To disable, use
4444
``xarray.set_options(display_style="text")``.
4545
By `Julia Signell <https://github.com/jsignell>`_.
46-
46+
- :py:meth:`Dataset.where` and :py:meth:`DataArray.where` accept a lambda as a
47+
first argument, which is then called on the input; replicating pandas' behavior.
48+
By `Maximilian Roos <https://github.com/max-sixty>`_
4749

4850
Bug fixes
4951
~~~~~~~~~

xarray/core/common.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,15 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
11191119
11201120
>>> import numpy as np
11211121
>>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=('x', 'y'))
1122+
>>> a
1123+
<xarray.DataArray (x: 5, y: 5)>
1124+
array([[ 0, 1, 2, 3, 4],
1125+
[ 5, 6, 7, 8, 9],
1126+
[10, 11, 12, 13, 14],
1127+
[15, 16, 17, 18, 19],
1128+
[20, 21, 22, 23, 24]])
1129+
Dimensions without coordinates: x, y
1130+
11221131
>>> a.where(a.x + a.y < 4)
11231132
<xarray.DataArray (x: 5, y: 5)>
11241133
array([[ 0., 1., 2., 3., nan],
@@ -1127,6 +1136,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
11271136
[ 15., nan, nan, nan, nan],
11281137
[ nan, nan, nan, nan, nan]])
11291138
Dimensions without coordinates: x, y
1139+
11301140
>>> a.where(a.x + a.y < 5, -1)
11311141
<xarray.DataArray (x: 5, y: 5)>
11321142
array([[ 0, 1, 2, 3, 4],
@@ -1135,6 +1145,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
11351145
[15, 16, -1, -1, -1],
11361146
[20, -1, -1, -1, -1]])
11371147
Dimensions without coordinates: x, y
1148+
11381149
>>> a.where(a.x + a.y < 4, drop=True)
11391150
<xarray.DataArray (x: 4, y: 4)>
11401151
array([[ 0., 1., 2., 3.],
@@ -1143,6 +1154,14 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
11431154
[ 15., nan, nan, nan]])
11441155
Dimensions without coordinates: x, y
11451156
1157+
>>> a.where(lambda x: x.x + x.y < 4, drop=True)
1158+
<xarray.DataArray (x: 4, y: 4)>
1159+
array([[ 0., 1., 2., 3.],
1160+
[ 5., 6., 7., nan],
1161+
[ 10., 11., nan, nan],
1162+
[ 15., nan, nan, nan]])
1163+
Dimensions without coordinates: x, y
1164+
11461165
See also
11471166
--------
11481167
numpy.where : corresponding numpy function
@@ -1152,6 +1171,9 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
11521171
from .dataarray import DataArray
11531172
from .dataset import Dataset
11541173

1174+
if callable(cond):
1175+
cond = cond(self)
1176+
11551177
if drop:
11561178
if other is not dtypes.NA:
11571179
raise ValueError("cannot set `other` if drop=True")

xarray/tests/test_dataarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2215,6 +2215,12 @@ def test_where(self):
22152215
actual = arr.where(arr.x < 2, drop=True)
22162216
assert_identical(actual, expected)
22172217

2218+
def test_where_lambda(self):
2219+
arr = DataArray(np.arange(4), dims="y")
2220+
expected = arr.sel(y=slice(2))
2221+
actual = arr.where(lambda x: x.y < 2, drop=True)
2222+
assert_identical(actual, expected)
2223+
22182224
def test_where_string(self):
22192225
array = DataArray(["a", "b"])
22202226
expected = DataArray(np.array(["a", np.nan], dtype=object))

xarray/tests/test_dataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4349,13 +4349,22 @@ def test_where(self):
43494349
assert actual.a.name == "a"
43504350
assert actual.a.attrs == ds.a.attrs
43514351

4352+
# lambda
4353+
ds = Dataset({"a": ("x", range(5))})
4354+
expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])})
4355+
actual = ds.where(lambda x: x > 1)
4356+
assert_identical(expected, actual)
4357+
43524358
def test_where_other(self):
43534359
ds = Dataset({"a": ("x", range(5))}, {"x": range(5)})
43544360
expected = Dataset({"a": ("x", [-1, -1, 2, 3, 4])}, {"x": range(5)})
43554361
actual = ds.where(ds > 1, -1)
43564362
assert_equal(expected, actual)
43574363
assert actual.a.dtype == int
43584364

4365+
actual = ds.where(lambda x: x > 1, -1)
4366+
assert_equal(expected, actual)
4367+
43594368
with raises_regex(ValueError, "cannot set"):
43604369
ds.where(ds > 1, other=0, drop=True)
43614370

0 commit comments

Comments
 (0)