Skip to content

Commit 6e884aa

Browse files
committed
implement additional request checks
1 parent 3e6b71f commit 6e884aa

File tree

3 files changed

+92
-31
lines changed

3 files changed

+92
-31
lines changed

xarray/core/computation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,7 @@ def _calc_idxminmax(
13221322
"""Apply common operations for idxmin and idxmax."""
13231323
# This function doesn't make sense for scalars so don't try
13241324
if not array.ndim:
1325-
ValueError("This function does not apply for scalars")
1325+
raise ValueError("This function does not apply for scalars")
13261326

13271327
if dim is not None:
13281328
pass # Use the dim if available
@@ -1333,13 +1333,13 @@ def _calc_idxminmax(
13331333
# The dim is not specified and ambiguous. Don't guess.
13341334
raise ValueError("Must supply 'dim' argument for multidimensional arrays")
13351335

1336-
if dim in array.coords:
1337-
pass # This is okay
1338-
else:
1336+
if dim not in array.dims:
1337+
raise KeyError(f'Dimension "{dim}" not in dimension')
1338+
if dim not in array.coords:
13391339
raise KeyError(f'Dimension "{dim}" does not have coordinates')
13401340

13411341
# These are dtypes with NaN values argmin and argmax can handle
1342-
na_dtypes = "cf0"
1342+
na_dtypes = "cfO"
13431343

13441344
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
13451345
# Need to skip NaN values since argmin and argmax can't handle them

xarray/core/dataset.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict
77
from html import escape
88
from numbers import Number
9+
from operator import methodcaller
910
from pathlib import Path
1011
from typing import (
1112
TYPE_CHECKING,
@@ -6003,12 +6004,14 @@ def idxmin(
60036004
DataArray.idxmin, Dataset.idxmax, Dataset.min, Dataset.argmin
60046005
"""
60056006
return self.map(
6006-
"idxmin",
6007-
dim=dim,
6008-
skipna=skipna,
6009-
promote=promote,
6010-
keep_attrs=keep_attrs,
6011-
**kwargs,
6007+
methodcaller(
6008+
"idxmin",
6009+
dim=dim,
6010+
skipna=skipna,
6011+
promote=promote,
6012+
keep_attrs=keep_attrs,
6013+
**kwargs,
6014+
),
60126015
)
60136016

60146017
def idxmax(
@@ -6062,12 +6065,14 @@ def idxmax(
60626065
DataArray.idxmax, Dataset.idxmin, Dataset.max, Dataset.argmax
60636066
"""
60646067
return self.map(
6065-
"idxmax",
6066-
dim=dim,
6067-
skipna=skipna,
6068-
promote=promote,
6069-
keep_attrs=keep_attrs,
6070-
**kwargs,
6068+
methodcaller(
6069+
"idxmax",
6070+
dim=dim,
6071+
skipna=skipna,
6072+
promote=promote,
6073+
keep_attrs=keep_attrs,
6074+
**kwargs,
6075+
),
60716076
)
60726077

60736078

xarray/tests/test_dataarray.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4298,6 +4298,12 @@ def setup(self):
42984298
(np.array([0, 1, 2, 0, -2, -4, 2]), 5, 2, None),
42994299
(np.array([0.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0]), 5, 2, None),
43004300
(np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]), 5, 2, 1),
4301+
(
4302+
np.array([1.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0]).astype("object"),
4303+
5,
4304+
2,
4305+
1,
4306+
),
43014307
(np.array([np.NaN, np.NaN]), np.NaN, np.NaN, 0),
43024308
(
43034309
np.array(
@@ -4329,7 +4335,7 @@ def test_min(self, x, minindex, maxindex, nanindex):
43294335
assert_identical(result1, expected1)
43304336

43314337
result2 = ar.min(skipna=False)
4332-
if nanindex is not None:
4338+
if nanindex is not None and ar.dtype.kind != "O":
43334339
expected2 = ar.isel(x=nanindex, drop=True)
43344340
expected2.attrs = {}
43354341
else:
@@ -4355,7 +4361,7 @@ def test_max(self, x, minindex, maxindex, nanindex):
43554361
assert_identical(result1, expected1)
43564362

43574363
result2 = ar.max(skipna=False)
4358-
if nanindex is not None:
4364+
if nanindex is not None and ar.dtype.kind != "O":
43594365
expected2 = ar.isel(x=nanindex, drop=True)
43604366
expected2.attrs = {}
43614367
else:
@@ -4384,7 +4390,7 @@ def test_argmin(self, x, minindex, maxindex, nanindex):
43844390
assert_identical(result1, expected1)
43854391

43864392
result2 = ar.argmin(skipna=False)
4387-
if nanindex is not None:
4393+
if nanindex is not None and ar.dtype.kind != "O":
43884394
expected2 = indarr.isel(x=nanindex, drop=True)
43894395
expected2.attrs = {}
43904396
else:
@@ -4413,7 +4419,7 @@ def test_argmax(self, x, minindex, maxindex, nanindex):
44134419
assert_identical(result1, expected1)
44144420

44154421
result2 = ar.argmax(skipna=False)
4416-
if nanindex is not None:
4422+
if nanindex is not None and ar.dtype.kind != "O":
44174423
expected2 = indarr.isel(x=nanindex, drop=True)
44184424
expected2.attrs = {}
44194425
else:
@@ -4448,7 +4454,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
44484454
assert_identical(result2, expected1)
44494455

44504456
result3 = ar.idxmin(skipna=False)
4451-
if nanindex is not None:
4457+
if nanindex is not None and ar.dtype.kind != "O":
44524458
expected2 = coordarr.isel(x=nanindex, drop=True).astype("float")
44534459
expected2.name = "x"
44544460
expected2.attrs = {}
@@ -4461,7 +4467,7 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
44614467
with pytest.raises(TypeError):
44624468
ar.idxmin(promote=False)
44634469

4464-
if nanindex is None:
4470+
if nanindex is None or ar.dtype.kind == "O":
44654471
expected3 = coordarr.isel(x=minindex, drop=True)
44664472
expected3.name = "x"
44674473

@@ -4480,6 +4486,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
44804486
result7 = ar.idxmin(skipna=False, promote=None)
44814487
assert_identical(result7, expected4)
44824488

4489+
with pytest.raises(KeyError):
4490+
ar.idxmin(dim="spam")
4491+
4492+
ar2 = xr.DataArray(5)
4493+
with pytest.raises(ValueError):
4494+
ar2.idxmin()
4495+
44834496
def test_idxmax(self, x, minindex, maxindex, nanindex):
44844497
ar = xr.DataArray(
44854498
x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs,
@@ -4507,7 +4520,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
45074520
assert_identical(result2, expected1)
45084521

45094522
result3 = ar.idxmax(skipna=False)
4510-
if nanindex is not None:
4523+
if nanindex is not None and ar.dtype.kind != "O":
45114524
expected2 = coordarr.isel(x=nanindex, drop=True).astype("float")
45124525
expected2.name = "x"
45134526
expected2.attrs = {}
@@ -4529,7 +4542,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
45294542
result5 = ar2.idxmax(promote=None)
45304543
assert_identical(result5, expected1)
45314544

4532-
if nanindex is None:
4545+
if nanindex is None or ar.dtype.kind == "O":
45334546
expected3 = coordarr.isel(x=maxindex, drop=True)
45344547
expected3.name = "x"
45354548

@@ -4548,6 +4561,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
45484561
result9 = ar.idxmax(skipna=False, promote=None)
45494562
assert_identical(result9, expected4)
45504563

4564+
with pytest.raises(KeyError):
4565+
ar.idxmin(dim="spam")
4566+
4567+
ar2 = xr.DataArray(5)
4568+
with pytest.raises(ValueError):
4569+
ar2.idxmin()
4570+
45514571

45524572
@pytest.mark.parametrize(
45534573
"x, minindex, maxindex, nanindex",
@@ -4576,6 +4596,18 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
45764596
[0, 2, np.NaN],
45774597
[None, 1, 0],
45784598
),
4599+
(
4600+
np.array(
4601+
[
4602+
[2.0, 1.0, 2.0, 0.0, -2.0, -4.0, 2.0],
4603+
[-4.0, np.NaN, 2.0, np.NaN, -2.0, -4.0, 2.0],
4604+
[np.NaN] * 7,
4605+
]
4606+
).astype("object"),
4607+
[5, 0, np.NaN],
4608+
[0, 2, np.NaN],
4609+
[None, 1, 0],
4610+
),
45794611
(
45804612
np.array(
45814613
[
@@ -4617,7 +4649,10 @@ def test_min(self, x, minindex, maxindex, nanindex):
46174649
result2 = ar.min(axis=1)
46184650
assert_identical(result2, expected1)
46194651

4620-
minindex = [x if y is None else y for x, y in zip(minindex, nanindex)]
4652+
minindex = [
4653+
x if y is None or ar.dtype.kind == "O" else y
4654+
for x, y in zip(minindex, nanindex)
4655+
]
46214656
expected2 = [
46224657
ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex)
46234658
]
@@ -4653,7 +4688,10 @@ def test_max(self, x, minindex, maxindex, nanindex):
46534688
result2 = ar.max(axis=1)
46544689
assert_identical(result2, expected1)
46554690

4656-
maxindex = [x if y is None else y for x, y in zip(maxindex, nanindex)]
4691+
maxindex = [
4692+
x if y is None or ar.dtype.kind == "O" else y
4693+
for x, y in zip(maxindex, nanindex)
4694+
]
46574695
expected2 = [
46584696
ar.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex)
46594697
]
@@ -4696,7 +4734,10 @@ def test_argmin(self, x, minindex, maxindex, nanindex):
46964734
expected1.attrs = self.attrs
46974735
assert_identical(result2, expected1)
46984736

4699-
minindex = [x if y is None else y for x, y in zip(minindex, nanindex)]
4737+
minindex = [
4738+
x if y is None or ar.dtype.kind == "O" else y
4739+
for x, y in zip(minindex, nanindex)
4740+
]
47004741
expected2 = [
47014742
indarr.isel(y=yi).isel(x=indi, drop=True)
47024743
for yi, indi in enumerate(minindex)
@@ -4740,7 +4781,10 @@ def test_argmax(self, x, minindex, maxindex, nanindex):
47404781
expected1.attrs = self.attrs
47414782
assert_identical(result2, expected1)
47424783

4743-
maxindex = [x if y is None else y for x, y in zip(maxindex, nanindex)]
4784+
maxindex = [
4785+
x if y is None or ar.dtype.kind == "O" else y
4786+
for x, y in zip(maxindex, nanindex)
4787+
]
47444788
expected2 = [
47454789
indarr.isel(y=yi).isel(x=indi, drop=True)
47464790
for yi, indi in enumerate(maxindex)
@@ -4782,7 +4826,10 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
47824826
expected1.attrs = self.attrs
47834827
assert_identical(result1, expected1)
47844828

4785-
minindex = [x if y is None else y for x, y in zip(minindex, nanindex)]
4829+
minindex = [
4830+
x if y is None or ar.dtype.kind == "O" else y
4831+
for x, y in zip(minindex, nanindex)
4832+
]
47864833
expected2 = [
47874834
coordarr.isel(y=yi).isel(x=indi, drop=True)
47884835
for yi, indi in enumerate(minindex)
@@ -4835,6 +4882,9 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
48354882
with pytest.raises(KeyError):
48364883
ar.idxmin(dim="y")
48374884

4885+
with pytest.raises(KeyError):
4886+
ar.idxmax(dim="spam")
4887+
48384888
def test_idxmax(self, x, minindex, maxindex, nanindex):
48394889
ar = xr.DataArray(
48404890
x,
@@ -4865,7 +4915,10 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
48654915
expected1.attrs = self.attrs
48664916
assert_identical(result1, expected1)
48674917

4868-
maxindex = [x if y is None else y for x, y in zip(maxindex, nanindex)]
4918+
maxindex = [
4919+
x if y is None or ar.dtype.kind == "O" else y
4920+
for x, y in zip(maxindex, nanindex)
4921+
]
48694922
expected2 = [
48704923
coordarr.isel(y=yi).isel(x=indi, drop=True)
48714924
for yi, indi in enumerate(maxindex)
@@ -4918,6 +4971,9 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
49184971
with pytest.raises(KeyError):
49194972
ar.idxmax(dim="y")
49204973

4974+
with pytest.raises(KeyError):
4975+
ar.idxmax(dim="spam")
4976+
49214977

49224978
@pytest.fixture(params=[1])
49234979
def da(request):

0 commit comments

Comments
 (0)