Implement idxmax and idxmin functions #3871

Merged
merged 3 commits into from
Mar 29, 2020
4 changes: 4 additions & 0 deletions doc/api.rst
@@ -180,6 +180,8 @@ Computation
:py:attr:`~Dataset.any`
:py:attr:`~Dataset.argmax`
:py:attr:`~Dataset.argmin`
:py:attr:`~Dataset.idxmax`
:py:attr:`~Dataset.idxmin`
:py:attr:`~Dataset.max`
:py:attr:`~Dataset.mean`
:py:attr:`~Dataset.median`
@@ -362,6 +364,8 @@ Computation
:py:attr:`~DataArray.any`
:py:attr:`~DataArray.argmax`
:py:attr:`~DataArray.argmin`
:py:attr:`~DataArray.idxmax`
:py:attr:`~DataArray.idxmin`
:py:attr:`~DataArray.max`
:py:attr:`~DataArray.mean`
:py:attr:`~DataArray.median`
6 changes: 6 additions & 0 deletions doc/whats-new.rst
@@ -38,10 +38,16 @@ New Features
- Limited the length of array items with long string reprs to a
reasonable width (:pull:`3900`)
By `Maximilian Roos <https://github.com/max-sixty>`_
- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
:py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
By `Todd Jennings <https://github.com/toddrjen>`_


Bug fixes
~~~~~~~~~
- Fix a regression where deleting a coordinate from a copied :py:class:`DataArray`
can affect the original :py:class:`DataArray`. (:issue:`3899`, :pull:`3871`)
By `Todd Jennings <https://github.com/toddrjen>`_


Documentation
66 changes: 65 additions & 1 deletion xarray/core/computation.py
@@ -23,9 +23,10 @@

 import numpy as np

-from . import duck_array_ops, utils
+from . import dtypes, duck_array_ops, utils
 from .alignment import deep_align
 from .merge import merge_coordinates_without_align
+from .nanops import dask_array
 from .options import OPTIONS
 from .pycompat import dask_array_type
 from .utils import is_dict_like
@@ -1338,3 +1339,66 @@ def polyval(coord, coeffs, degree_dim="degree"):
        coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
    )
    return (lhs * coeffs).sum(degree_dim)


def _calc_idxminmax(
    *,
    array,
    func: Callable,
    dim: Hashable = None,
    skipna: bool = None,
    fill_value: Any = dtypes.NA,
    keep_attrs: bool = None,
):
    """Apply common operations for idxmin and idxmax."""
    # This function doesn't make sense for scalars so don't try
    if not array.ndim:
        raise ValueError("This function does not apply for scalars")

    if dim is not None:
        pass  # Use the dim if available
    elif array.ndim == 1:
        # it is okay to guess the dim if there is only 1
        dim = array.dims[0]
    else:
        # The dim is not specified and ambiguous. Don't guess.
        raise ValueError("Must supply 'dim' argument for multidimensional arrays")

    if dim not in array.dims:
        raise KeyError(f'Dimension "{dim}" not in dimension')
    if dim not in array.coords:
        raise KeyError(f'Dimension "{dim}" does not have coordinates')

    # These are dtypes with NaN values argmin and argmax can handle
    na_dtypes = "cfO"

    if skipna or (skipna is None and array.dtype.kind in na_dtypes):
        # Need to skip NaN values since argmin and argmax can't handle them
        allna = array.isnull().all(dim)
        array = array.where(~allna, 0)

    # This will run argmin or argmax.
    indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

    # Get the coordinate we want.
    coordarray = array[dim]

    # Handle dask arrays.
    if isinstance(array, dask_array_type):
        res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
Contributor

@toddrjen @dcherian

Sorry, I might be wrong, but it seems that the func argument to map_blocks is missing. I tried lambda a, b: a[b], which seems to work.

Contributor

Two more things:

  • The isinstance check needs to use array.data.
  • res needs to be computed; otherwise subsequent operations on res will fail.

Member

You’re right, this is definitely broken. Anyone up for putting together a fix in a follow-up PR?

    else:
        res = coordarray[
            indx,
        ]

    if skipna or (skipna is None and array.dtype.kind in na_dtypes):
        # Put the NaN values back in after removing them
        res = res.where(~allna, fill_value)

    # The dim is gone but we need to remove the corresponding coordinate.
    del res.coords[dim]

    # Copy attributes from argmin/argmax, if any
    res.attrs = indx.attrs

    return res
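
The order of operations in `_calc_idxminmax` can be illustrated with a dependency-free sketch. This is a plain-Python model, not xarray code: `idxmin_rows` and its parameters are hypothetical names, nested lists stand in for the data, and a flat list stands in for the coordinate. It follows the same sequence as the function above: flag slices that are entirely NaN, find the position of the minimum among the valid values, translate that position into a coordinate label, and use the fill value where no valid value exists. The first three rows and the labels are taken from the `idxmin` docstring example; the all-NaN row is added here to exercise the fill value.

```python
import math


def idxmin_rows(rows, labels, fill_value=math.nan):
    """Return, for each row, the label at the position of its minimum."""
    results = []
    for row in rows:
        # Flag rows with no valid value (the role of `allna` in the diff).
        allna = all(math.isnan(v) for v in row)
        if allna:
            # Nothing to locate, so the fill value is used instead.
            results.append(fill_value)
            continue
        # Position of the minimum among non-NaN values (the argmin step).
        valid = [i for i, v in enumerate(row) if not math.isnan(v)]
        position = min(valid, key=lambda i: row[i])
        # Translate the integer position into a coordinate label.
        results.append(labels[position])
    return results


nan = math.nan
rows = [
    [2.0, 1.0, 2.0, 0.0, -2.0],
    [-4.0, nan, 2.0, nan, -2.0],
    [nan, nan, 1.0, nan, nan],
    [nan, nan, nan, nan, nan],
]
labels = [0.0, 1.0, 4.0, 9.0, 16.0]
result = idxmin_rows(rows, labels)
print(result)  # [16.0, 0.0, 4.0, nan]
```

The real function works on whole arrays at once, which is why it temporarily replaces all-NaN slices with 0 before calling `argmin`/`argmax` and restores the fill value afterwards, rather than branching per row as this model does.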
193 changes: 192 additions & 1 deletion xarray/core/dataarray.py
@@ -930,7 +930,10 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
         """
         variable = self.variable.copy(deep=deep, data=data)
         coords = {k: v.copy(deep=deep) for k, v in self._coords.items()}
-        indexes = self._indexes
+        if self._indexes is None:
+            indexes = self._indexes
+        else:
+            indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
         return self._replace(variable, coords, indexes=indexes)

     def __copy__(self) -> "DataArray":
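
The change to `copy` above appears to address the regression listed in `whats-new.rst`: if a copy reuses the same `indexes` dict as the original, deleting an entry through the copy also removes it from the original. The toy model below illustrates that aliasing effect under stated assumptions. `Labeled`, `copy_sharing`, and `copy_separate` are hypothetical names, not xarray internals, and the model only rebuilds the dict, whereas the real change also copies each index with `v.copy(deep=deep)`.

```python
class Labeled:
    """Toy stand-in for an object that carries a dict of indexes."""

    def __init__(self, indexes):
        self.indexes = indexes

    def copy_sharing(self):
        # Behaviour before the change: the copy reuses the same dict object.
        return Labeled(self.indexes)

    def copy_separate(self):
        # Behaviour after the change: build a fresh dict, or keep None.
        if self.indexes is None:
            return Labeled(None)
        return Labeled({k: v for k, v in self.indexes.items()})


original = Labeled({"x": ("a", "b", "c")})
shared = original.copy_sharing()
del shared.indexes["x"]
leaked = "x" not in original.indexes  # the deletion reached the original

original = Labeled({"x": ("a", "b", "c")})
separate = original.copy_separate()
del separate.indexes["x"]
isolated = "x" in original.indexes  # the original is untouched

print(leaked, isolated)  # True True
```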
@@ -3505,6 +3508,194 @@ def pad(
        )
        return self._from_temp_dataset(ds)

    def idxmin(
        self,
        dim: Hashable = None,
        skipna: bool = None,
        fill_value: Any = dtypes.NA,
        keep_attrs: bool = None,
    ) -> "DataArray":
        """Return the coordinate label of the minimum value along a dimension.

        Returns a new `DataArray` named after the dimension with the values of
        the coordinate labels along that dimension corresponding to minimum
        values along that dimension.

        In comparison to :py:meth:`~DataArray.argmin`, this returns the
        coordinate label while :py:meth:`~DataArray.argmin` returns the index.

        Parameters
        ----------
        dim : str, optional
            Dimension over which to apply `idxmin`. This is optional for 1D
            arrays, but required for arrays with 2 or more dimensions.
        skipna : bool or None, default None
            If True, skip missing values (as marked by NaN). By default, only
            skips missing values for ``float``, ``complex``, and ``object``
            dtypes; other dtypes either do not have a sentinel missing value
            (``int``) or ``skipna=True`` has not been implemented
            (``datetime64`` or ``timedelta64``).
        fill_value : Any, default NaN
            Value to be filled in case all of the values along a dimension are
            null. By default this is NaN. The fill value and result are
            automatically converted to a compatible dtype if possible.
            Ignored if ``skipna`` is False.
        keep_attrs : bool, default False
            If True, the attributes (``attrs``) will be copied from the
            original object to the new one. If False (default), the new object
            will be returned without attributes.

        Returns
        -------
        reduced : DataArray
            New `DataArray` object with `idxmin` applied to its data and the
            indicated dimension removed.

        See also
        --------
        Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin

        Examples
        --------

        >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
        ...                      coords={"x": ['a', 'b', 'c', 'd', 'e']})
        >>> array.min()
        <xarray.DataArray ()>
        array(-2)
        >>> array.argmin()
        <xarray.DataArray ()>
        array(4)
        >>> array.idxmin()
        <xarray.DataArray 'x' ()>
        array('e', dtype='<U1')

        >>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
        ...                       [-4.0, np.nan, 2.0, np.nan, -2.0],
        ...                       [np.nan, np.nan, 1., np.nan, np.nan]],
        ...                      dims=["y", "x"],
        ...                      coords={"y": [-1, 0, 1],
        ...                              "x": np.arange(5.)**2}
        ...                      )
        >>> array.min(dim="x")
        <xarray.DataArray (y: 3)>
        array([-2., -4., 1.])
        Coordinates:
          * y        (y) int64 -1 0 1
        >>> array.argmin(dim="x")
        <xarray.DataArray (y: 3)>
        array([4, 0, 2])
        Coordinates:
          * y        (y) int64 -1 0 1
        >>> array.idxmin(dim="x")
        <xarray.DataArray 'x' (y: 3)>
        array([16., 0., 4.])
        Coordinates:
          * y        (y) int64 -1 0 1
        """
        return computation._calc_idxminmax(
            array=self,
            func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
            dim=dim,
            skipna=skipna,
            fill_value=fill_value,
            keep_attrs=keep_attrs,
        )

    def idxmax(
        self,
        dim: Hashable = None,
        skipna: bool = None,
        fill_value: Any = dtypes.NA,
        keep_attrs: bool = None,
    ) -> "DataArray":
        """Return the coordinate label of the maximum value along a dimension.

        Returns a new `DataArray` named after the dimension with the values of
        the coordinate labels along that dimension corresponding to maximum
        values along that dimension.

        In comparison to :py:meth:`~DataArray.argmax`, this returns the
        coordinate label while :py:meth:`~DataArray.argmax` returns the index.

        Parameters
        ----------
        dim : str, optional
            Dimension over which to apply `idxmax`. This is optional for 1D
            arrays, but required for arrays with 2 or more dimensions.
        skipna : bool or None, default None
            If True, skip missing values (as marked by NaN). By default, only
            skips missing values for ``float``, ``complex``, and ``object``
            dtypes; other dtypes either do not have a sentinel missing value
            (``int``) or ``skipna=True`` has not been implemented
            (``datetime64`` or ``timedelta64``).
        fill_value : Any, default NaN
            Value to be filled in case all of the values along a dimension are
            null. By default this is NaN. The fill value and result are
            automatically converted to a compatible dtype if possible.
            Ignored if ``skipna`` is False.
        keep_attrs : bool, default False
            If True, the attributes (``attrs``) will be copied from the
            original object to the new one. If False (default), the new object
            will be returned without attributes.

        Returns
        -------
        reduced : DataArray
            New `DataArray` object with `idxmax` applied to its data and the
            indicated dimension removed.

        See also
        --------
        Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax

        Examples
        --------

        >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
        ...                      coords={"x": ['a', 'b', 'c', 'd', 'e']})
        >>> array.max()
        <xarray.DataArray ()>
        array(2)
        >>> array.argmax()
        <xarray.DataArray ()>
        array(1)
        >>> array.idxmax()
        <xarray.DataArray 'x' ()>
        array('b', dtype='<U1')

        >>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
        ...                       [-4.0, np.nan, 2.0, np.nan, -2.0],
        ...                       [np.nan, np.nan, 1., np.nan, np.nan]],
        ...                      dims=["y", "x"],
        ...                      coords={"y": [-1, 0, 1],
        ...                              "x": np.arange(5.)**2}
        ...                      )
        >>> array.max(dim="x")
        <xarray.DataArray (y: 3)>
        array([2., 2., 1.])
        Coordinates:
          * y        (y) int64 -1 0 1
        >>> array.argmax(dim="x")
        <xarray.DataArray (y: 3)>
        array([0, 2, 2])
        Coordinates:
          * y        (y) int64 -1 0 1
        >>> array.idxmax(dim="x")
        <xarray.DataArray 'x' (y: 3)>
        array([0., 4., 4.])
        Coordinates:
          * y        (y) int64 -1 0 1
        """
        return computation._calc_idxminmax(
            array=self,
            func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
            dim=dim,
            skipna=skipna,
            fill_value=fill_value,
            keep_attrs=keep_attrs,
        )
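
Both docstrings state that `dim` is optional for 1D arrays and required otherwise. A small sketch of that rule, assuming only a tuple of dimension names as input, may help when reading the error cases. `resolve_dim` is a hypothetical helper written for illustration; it reuses the error messages from `_calc_idxminmax` but is not part of xarray.

```python
def resolve_dim(dims, dim=None):
    """Pick the dimension to reduce over, or raise if it is ambiguous."""
    if len(dims) == 0:
        raise ValueError("This function does not apply for scalars")
    if dim is None:
        if len(dims) != 1:
            # More than one candidate: refuse to guess.
            raise ValueError("Must supply 'dim' argument for multidimensional arrays")
        dim = dims[0]  # only one candidate, so guessing is safe
    if dim not in dims:
        raise KeyError(f'Dimension "{dim}" not in dimension')
    return dim


print(resolve_dim(("x",)))           # x
print(resolve_dim(("y", "x"), "x"))  # x

try:
    resolve_dim(("y", "x"))
except ValueError as err:
    print(err)  # Must supply 'dim' argument for multidimensional arrays
```

Refusing to guess for multidimensional input keeps the reduction explicit, at the cost of one extra argument at the call site.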

    # this needs to be at the end, or mypy will confuse with `str`
    # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
    str = property(StringAccessor)