Skip to content

Commit 09df5ca

Browse files
authored
Allow non-unique and non-monotonic coordinates in get_clean_interp_index and polyfit (#4099)
* Allow non-unique and non-monotonic in get_clean_interp_index and polyfit * black on missing.py * Apply change to polyval, add pr to whats new * Add tests for get_clean_interp_index return values
1 parent 93b2d04 commit 09df5ca

File tree

5 files changed

+26
-9
lines changed

5 files changed

+26
-9
lines changed

doc/whats-new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ New Features
5050
By `Andrew Williams <https://github.com/AndrewWilliams3142>`_
5151
- Added :py:func:`xarray.cov` and :py:func:`xarray.corr` (:issue:`3784`, :pull:`3550`, :pull:`4089`).
5252
By `Andrew Williams <https://github.com/AndrewWilliams3142>`_ and `Robin Beer <https://github.com/r-beer>`_.
53-
- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`)
53+
- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`, :pull:`3733`, :pull:`4099`)
5454
By `Pascal Bourgault <https://github.com/aulemahal>`_.
5555
- Control over attributes of result in :py:func:`merge`, :py:func:`concat`,
5656
:py:func:`combine_by_coords` and :py:func:`combine_nested` using

xarray/core/computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1506,7 +1506,7 @@ def polyval(coord, coeffs, degree_dim="degree"):
15061506
from .dataarray import DataArray
15071507
from .missing import get_clean_interp_index
15081508

1509-
x = get_clean_interp_index(coord, coord.name)
1509+
x = get_clean_interp_index(coord, coord.name, strict=False)
15101510

15111511
deg_coord = coeffs[degree_dim]
15121512

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5839,7 +5839,7 @@ def polyfit(
58395839
variables = {}
58405840
skipna_da = skipna
58415841

5842-
x = get_clean_interp_index(self, dim)
5842+
x = get_clean_interp_index(self, dim, strict=False)
58435843
xname = "{}_".format(self[dim].name)
58445844
order = int(deg) + 1
58455845
lhs = np.vander(x, order)

xarray/core/missing.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):
208208
return ds
209209

210210

211-
def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool] = True):
211+
def get_clean_interp_index(
212+
arr, dim: Hashable, use_coordinate: Union[str, bool] = True, strict: bool = True
213+
):
212214
"""Return index to use for x values in interpolation or curve fitting.
213215
214216
Parameters
@@ -221,6 +223,8 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool]
221223
If use_coordinate is True, the coordinate that shares the name of the
222224
dimension along which interpolation is being performed will be used as the
223225
x values. If False, the x values are set as an equally spaced sequence.
226+
strict : bool
227+
Whether to raise errors if the index is either non-unique or non-monotonic (default).
224228
225229
Returns
226230
-------
@@ -257,11 +261,12 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool]
257261
if isinstance(index, pd.MultiIndex):
258262
index.name = dim
259263

260-
if not index.is_monotonic:
261-
raise ValueError(f"Index {index.name!r} must be monotonically increasing")
264+
if strict:
265+
if not index.is_monotonic:
266+
raise ValueError(f"Index {index.name!r} must be monotonically increasing")
262267

263-
if not index.is_unique:
264-
raise ValueError(f"Index {index.name!r} has duplicate values")
268+
if not index.is_unique:
269+
raise ValueError(f"Index {index.name!r} has duplicate values")
265270

266271
# Special case for non-standard calendar indexes
267272
# Numerical datetime values are defined with respect to 1970-01-01T00:00:00 in units of nanoseconds
@@ -282,7 +287,7 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool]
282287
# xarray/numpy raise a ValueError
283288
raise TypeError(
284289
f"Index {index.name!r} must be castable to float64 to support "
285-
f"interpolation, got {type(index).__name__}."
290+
f"interpolation or curve fitting, got {type(index).__name__}."
286291
)
287292

288293
return index

xarray/tests/test_missing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,18 @@ def test_get_clean_interp_index_potential_overflow():
534534
get_clean_interp_index(da, "time")
535535

536536

537+
@pytest.mark.parametrize("index", ([0, 2, 1], [0, 1, 1]))
538+
def test_get_clean_interp_index_strict(index):
539+
da = xr.DataArray([0, 1, 2], dims=("x",), coords={"x": index})
540+
541+
with pytest.raises(ValueError):
542+
get_clean_interp_index(da, "x")
543+
544+
clean = get_clean_interp_index(da, "x", strict=False)
545+
np.testing.assert_array_equal(index, clean)
546+
assert clean.dtype == np.float64
547+
548+
537549
@pytest.fixture
538550
def da_time():
539551
return xr.DataArray(

0 commit comments

Comments
 (0)