Skip to content

Commit 925b5c8

Browse files
committed
implement idxmax and idxmin
1 parent 3f34b95 commit 925b5c8

File tree

5 files changed

+983
-0
lines changed

5 files changed

+983
-0
lines changed

doc/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ Computation
178178
:py:attr:`~Dataset.any`
179179
:py:attr:`~Dataset.argmax`
180180
:py:attr:`~Dataset.argmin`
181+
:py:attr:`~Dataset.idxmax`
182+
:py:attr:`~Dataset.idxmin`
181183
:py:attr:`~Dataset.max`
182184
:py:attr:`~Dataset.mean`
183185
:py:attr:`~Dataset.median`
@@ -359,6 +361,8 @@ Computation
359361
:py:attr:`~DataArray.any`
360362
:py:attr:`~DataArray.argmax`
361363
:py:attr:`~DataArray.argmin`
364+
:py:attr:`~DataArray.idxmax`
365+
:py:attr:`~DataArray.idxmin`
362366
:py:attr:`~DataArray.max`
363367
:py:attr:`~DataArray.mean`
364368
:py:attr:`~DataArray.median`

xarray/core/dataarray.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@
5555
from .indexes import Indexes, default_indexes, propagate_indexes
5656
from .indexing import is_fancy_indexer
5757
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
58+
from .nanops import dask_array
5859
from .options import OPTIONS
60+
from .pycompat import dask_array_type
5961
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
6062
from .variable import (
6163
IndexVariable,
@@ -3430,6 +3432,225 @@ def pad(
34303432
)
34313433
return self._from_temp_dataset(ds)
34323434

3435+
def _calc_idxminmax(
    self,
    *,
    func: str,
    dim: Optional[Hashable],
    axis: Optional[int],
    skipna: Optional[bool],
    promote: Optional[bool],
    keep_attrs: Optional[bool],
    **kwargs: Any,
) -> "DataArray":
    """Apply the operations shared by ``idxmin`` and ``idxmax``.

    The reduction named by `func` is run along one dimension and the
    resulting integer positions are translated into the values of that
    dimension's coordinate.

    Parameters
    ----------
    func : str
        Name of the method on this object that computes the positions;
        the public wrappers pass ``"argmin"`` or ``"argmax"``.
    dim : hashable, optional
        Name of the dimension to reduce over. Takes precedence over `axis`.
    axis : int, optional
        Position of the dimension to reduce over, used when `dim` is None.
        Both may be None only for one-dimensional arrays.
    skipna : bool, optional
        Forwarded to `func`. If True, or if None and this array's dtype
        kind is complex, float or object, slices that are entirely null
        are filled with 0 before the reduction and masked again in the
        result.
    promote : bool, optional
        If True, cast the result to a dtype able to hold the fill value.
        If None, behave as True only when an all-null slice is present.
    keep_attrs : bool, optional
        If True, copy this object's ``attrs`` onto the result.
    **kwargs : dict
        Additional keyword arguments forwarded to `func`.

    Returns
    -------
    DataArray
        Coordinate values at the selected positions, with `dim` removed.

    Raises
    ------
    ValueError
        If this array is a scalar, or if it is multidimensional and
        neither `dim` nor `axis` is given.
    IndexError
        If `axis` was given and the selected dimension has no coordinate.
    KeyError
        If the selected dimension has no coordinate and `axis` is None.
    TypeError
        If all-null slices are present, promotion is disabled, and the
        coordinate dtype cannot represent missing values.
    """
    # A scalar has no dimension to reduce over, so fail early.
    if not self.ndim:
        raise ValueError("This function does not apply for scalars")

    if dim is not None:
        pass  # An explicit dimension name wins over ``axis``.
    elif axis is not None:
        dim = self.dims[axis]
    elif self.ndim == 1:
        # With a single dimension there is nothing ambiguous to guess.
        dim = self.dims[0]
    else:
        # The dimension is not specified and ambiguous. Don't guess.
        raise ValueError(
            "Must supply either 'dim' or 'axis' argument "
            "for multidimensional arrays"
        )

    if dim not in self.coords:
        # Report the failure in terms of whichever selector the caller used.
        if axis is not None:
            raise IndexError(f'Axis "{axis}" does not have coordinates')
        raise KeyError(f'Dimension "{dim}" does not have coordinates')

    # dtype kinds able to hold NaN: complex ("c"), float ("f"), object ("O").
    na_dtypes = "cfO"

    if skipna or (skipna is None and self.dtype.kind in na_dtypes):
        # Slices that are null everywhere along ``dim`` have no valid
        # position, so fill them with a placeholder and mask them later.
        allna = self.isnull().all(dim)
        array = self.where(~allna, 0)
        # Collapse to a plain bool once instead of truth-testing a 0-d
        # array (and possibly re-evaluating it) at every use below.
        hasna = bool(allna.any())
    else:
        array = self
        allna = None
        hasna = False

    # If promote is None we only promote if there are NaN values.
    if promote is None:
        promote = hasna

    if not promote and hasna and array.coords[dim].dtype.kind not in na_dtypes:
        raise TypeError(
            "NaN values present for NaN-incompatible dtype and Promote=False"
        )

    # This will run argmin or argmax.
    indx = getattr(array, func)(
        dim=dim, axis=None, keep_attrs=False, skipna=skipna, **kwargs
    )

    # Get the coordinate we want.
    coordarray = array[dim]

    # ``array`` is always a DataArray, so the check has to look at the
    # wrapped data to detect dask-backed input.
    if isinstance(array.data, dask_array_type):
        # Index a dask version of the coordinate with the dask positions.
        # NOTE(review): relies on dask supporting indexing of a 1-D array by
        # a 1-D integer dask array, hence the ravel/reshape -- confirm
        # against the minimum supported dask version.
        coord_chunks = dict(zip(array.dims, array.chunks))[dim]
        dask_coord = dask_array.from_array(coordarray.data, chunks=coord_chunks)
        res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
        # ``indx`` does not carry the coordinate's name or attributes.
        res.name = dim
        res.attrs = coordarray.attrs
    else:
        res = coordarray[
            indx,
        ]
        # The dim is gone but we need to remove the corresponding coordinate.
        del res.coords[dim]

    # Promote to a dtype that can handle NaN values if needed.
    newdtype, fill_value = dtypes.maybe_promote(res.dtype)
    if promote and newdtype != res.dtype:
        res = res.astype(newdtype)

    # Put the NaN values back in after removing them, if necessary.
    if hasna and allna is not None:
        res = res.where(~allna, fill_value)

    # Put the attrs back in if needed
    if keep_attrs:
        res.attrs = self.attrs

    return res
3527+
3528+
def idxmin(
    self,
    dim: Optional[Hashable] = None,
    axis: Optional[int] = None,
    skipna: Optional[bool] = None,
    promote: Optional[bool] = None,
    keep_attrs: Optional[bool] = False,
    **kwargs: Any,
) -> "DataArray":
    """Return the coordinate label of the minimum value along a dimension.

    Where `argmin` reports the integer position of the minimum, `idxmin`
    reports the value of the dimension's coordinate at that position.

    Parameters
    ----------
    dim : hashable, optional
        Name of the dimension to reduce over.
    axis : int, optional
        Position of the dimension to reduce over; consulted only when
        `dim` is not given. Both may be omitted for one-dimensional
        arrays; for multidimensional arrays one of them is required.
    skipna : bool, optional
        If True, skip missing values (as marked by NaN). By default this
        is only done for dtypes that can hold NaN, such as float.
    promote : bool, optional
        If True, the result is cast to a dtype that can represent missing
        values. If False, a TypeError is raised when missing values are
        present and the coordinate dtype cannot represent them. If None
        (default), the result is only promoted when a missing value is
        actually present.
    keep_attrs : bool, optional
        If True, the attributes (`attrs`) are copied from the original
        object onto the result. If False (default), they are not copied.
    **kwargs : dict
        Additional keyword arguments forwarded to `argmin`.

    Returns
    -------
    reduced : DataArray
        New DataArray holding the coordinate values of the minima, with
        the indicated dimension removed.

    See also
    --------
    Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin
    """
    # Gather the caller-facing options once, then hand off to the helper
    # shared with ``idxmax``; only the reduction name differs.
    shared_options = {
        "dim": dim,
        "axis": axis,
        "skipna": skipna,
        "promote": promote,
        "keep_attrs": keep_attrs,
    }
    return self._calc_idxminmax(func="argmin", **shared_options, **kwargs)
3590+
3591+
def idxmax(
    self,
    dim: Optional[Hashable] = None,
    axis: Optional[int] = None,
    skipna: Optional[bool] = None,
    promote: Optional[bool] = None,
    keep_attrs: Optional[bool] = False,
    **kwargs: Any,
) -> "DataArray":
    """Return the coordinate label of the maximum value along a dimension.

    Where `argmax` reports the integer position of the maximum, `idxmax`
    reports the value of the dimension's coordinate at that position.

    Parameters
    ----------
    dim : hashable, optional
        Name of the dimension to reduce over.
    axis : int, optional
        Position of the dimension to reduce over; consulted only when
        `dim` is not given. Both may be omitted for one-dimensional
        arrays; for multidimensional arrays one of them is required.
    skipna : bool, optional
        If True, skip missing values (as marked by NaN). By default this
        is only done for dtypes that can hold NaN, such as float.
    promote : bool, optional
        If True, the result is cast to a dtype that can represent missing
        values. If False, a TypeError is raised when missing values are
        present and the coordinate dtype cannot represent them. If None
        (default), the result is only promoted when a missing value is
        actually present.
    keep_attrs : bool, optional
        If True, the attributes (`attrs`) are copied from the original
        object onto the result. If False (default), they are not copied.
    **kwargs : dict
        Additional keyword arguments forwarded to `argmax`.

    Returns
    -------
    reduced : DataArray
        New DataArray holding the coordinate values of the maxima, with
        the indicated dimension removed.

    See also
    --------
    Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax
    """
    # Gather the caller-facing options once, then hand off to the helper
    # shared with ``idxmin``; only the reduction name differs.
    shared_options = {
        "dim": dim,
        "axis": axis,
        "skipna": skipna,
        "promote": promote,
        "keep_attrs": keep_attrs,
    }
    return self._calc_idxminmax(func="argmax", **shared_options, **kwargs)
3653+
34333654
# this needs to be at the end, or mypy will confuse with `str`
34343655
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
34353656
str = property(StringAccessor)

xarray/core/dataset.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5921,5 +5921,131 @@ def pad(
59215921

59225922
return self._replace_vars_and_dims(variables)
59235923

5924+
def idxmin(
    self,
    dim: Optional[Hashable] = None,
    axis: Optional[int] = None,
    skipna: Optional[bool] = None,
    promote: Optional[bool] = None,
    keep_attrs: Optional[bool] = False,
    **kwargs: Any,
) -> "Dataset":
    """Return the coordinate label of the minimum value along a dimension.

    The `idxmin` method of the contained arrays is applied through
    `Dataset.map`. Where `argmin` reports the integer position of the
    minimum, `idxmin` reports the value of the dimension's coordinate at
    that position.

    Parameters
    ----------
    dim : hashable, optional
        Name of the dimension to reduce over.
    axis : int, optional
        Position of the dimension to reduce over; consulted only when
        `dim` is not given.
    skipna : bool, optional
        If True, skip missing values (as marked by NaN). By default this
        is only done for dtypes that can hold NaN, such as float.
    promote : bool, optional
        If True, results are cast to a dtype that can represent missing
        values. If False, a TypeError is raised when missing values are
        present and the coordinate dtype cannot represent them. If None
        (default), results are only promoted when a missing value is
        actually present.
    keep_attrs : bool, optional
        If True, the attributes (`attrs`) are copied from the original
        object onto the result. If False (default), they are not copied.
    **kwargs : dict
        Additional keyword arguments forwarded to each array's `idxmin`.

    Returns
    -------
    reduced : Dataset
        New Dataset with `idxmin` applied to its data and the indicated
        dimension removed.

    See also
    --------
    DataArray.idxmin, Dataset.idxmax, Dataset.min, Dataset.argmin
    """

    # ``Dataset.map`` calls its first argument, so it must be a callable
    # rather than a method name; wrap the per-array call in a closure.
    def _array_idxmin(array):
        return array.idxmin(
            dim=dim,
            axis=axis,
            skipna=skipna,
            promote=promote,
            keep_attrs=keep_attrs,
            **kwargs,
        )

    # ``keep_attrs`` is also given to ``map`` itself, as in the original call.
    return self.map(_array_idxmin, keep_attrs=keep_attrs)
5986+
5987+
def idxmax(
    self,
    dim: Optional[Hashable] = None,
    axis: Optional[int] = None,
    skipna: Optional[bool] = None,
    promote: Optional[bool] = None,
    keep_attrs: Optional[bool] = False,
    **kwargs: Any,
) -> "Dataset":
    """Return the coordinate label of the maximum value along a dimension.

    The `idxmax` method of the contained arrays is applied through
    `Dataset.map`. Where `argmax` reports the integer position of the
    maximum, `idxmax` reports the value of the dimension's coordinate at
    that position.

    Parameters
    ----------
    dim : hashable, optional
        Name of the dimension to reduce over.
    axis : int, optional
        Position of the dimension to reduce over; consulted only when
        `dim` is not given.
    skipna : bool, optional
        If True, skip missing values (as marked by NaN). By default this
        is only done for dtypes that can hold NaN, such as float.
    promote : bool, optional
        If True, results are cast to a dtype that can represent missing
        values. If False, a TypeError is raised when missing values are
        present and the coordinate dtype cannot represent them. If None
        (default), results are only promoted when a missing value is
        actually present.
    keep_attrs : bool, optional
        If True, the attributes (`attrs`) are copied from the original
        object onto the result. If False (default), they are not copied.
    **kwargs : dict
        Additional keyword arguments forwarded to each array's `idxmax`.

    Returns
    -------
    reduced : Dataset
        New Dataset with `idxmax` applied to its data and the indicated
        dimension removed.

    See also
    --------
    DataArray.idxmax, Dataset.idxmin, Dataset.max, Dataset.argmax
    """

    # ``Dataset.map`` calls its first argument, so it must be a callable
    # rather than a method name; wrap the per-array call in a closure.
    def _array_idxmax(array):
        return array.idxmax(
            dim=dim,
            axis=axis,
            skipna=skipna,
            promote=promote,
            keep_attrs=keep_attrs,
            **kwargs,
        )

    # ``keep_attrs`` is also given to ``map`` itself, as in the original call.
    return self.map(_array_idxmax, keep_attrs=keep_attrs)
6049+
59246050

59256051
ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)

0 commit comments

Comments
 (0)