Skip to content

Commit 6fbeb13

Browse files
headtr1ckpre-commit-ci[bot]max-sixtydcherian
authored
polyval: Use Horner's algorithm + support chunked inputs (#6548)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 39bda44 commit 6fbeb13

File tree

4 files changed

+220
-45
lines changed

4 files changed

+220
-45
lines changed

asv_bench/benchmarks/polyfit.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
3+
import xarray as xr
4+
5+
from . import parameterized, randn, requires_dask
6+
7+
NDEGS = (2, 5, 20)
8+
NX = (10**2, 10**6)
9+
10+
11+
class Polyval:
12+
def setup(self, *args, **kwargs):
13+
self.xs = {nx: xr.DataArray(randn((nx,)), dims="x", name="x") for nx in NX}
14+
self.coeffs = {
15+
ndeg: xr.DataArray(
16+
randn((ndeg,)), dims="degree", coords={"degree": np.arange(ndeg)}
17+
)
18+
for ndeg in NDEGS
19+
}
20+
21+
@parameterized(["nx", "ndeg"], [NX, NDEGS])
22+
def time_polyval(self, nx, ndeg):
23+
x = self.xs[nx]
24+
c = self.coeffs[ndeg]
25+
xr.polyval(x, c).compute()
26+
27+
@parameterized(["nx", "ndeg"], [NX, NDEGS])
28+
def peakmem_polyval(self, nx, ndeg):
29+
x = self.xs[nx]
30+
c = self.coeffs[ndeg]
31+
xr.polyval(x, c).compute()
32+
33+
34+
class PolyvalDask(Polyval):
35+
def setup(self, *args, **kwargs):
36+
requires_dask()
37+
super().setup(*args, **kwargs)
38+
self.xs = {k: v.chunk({"x": 10000}) for k, v in self.xs.items()}

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ New Features
4141
- Allow passing chunks in ``**kwargs`` form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and
4242
:py:meth:`Variable.chunk`. (:pull:`6471`)
4343
By `Tom Nicholas <https://github.com/TomNicholas>`_.
44+
- :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape,
45+
is faster and requires less memory. (:pull:`6548`)
46+
By `Michael Niklas <https://github.com/headtr1ck>`_.
4447

4548
Breaking changes
4649
~~~~~~~~~~~~~~~~
@@ -74,6 +77,10 @@ Breaking changes
7477
- Xarray's ufuncs have been removed, now that they can be replaced by numpy's ufuncs in all
7578
supported versions of numpy.
7679
By `Maximilian Roos <https://github.com/max-sixty>`_.
80+
- :py:meth:`xr.polyval` now uses the ``coord`` argument directly instead of its index coordinate.
81+
(:pull:`6548`)
82+
By `Michael Niklas <https://github.com/headtr1ck>`_.
83+
7784

7885
Deprecations
7986
~~~~~~~~~~~~

xarray/core/computation.py

Lines changed: 84 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
Iterable,
1818
Mapping,
1919
Sequence,
20+
overload,
2021
)
2122

2223
import numpy as np
2324

2425
from . import dtypes, duck_array_ops, utils
2526
from .alignment import align, deep_align
27+
from .common import zeros_like
28+
from .duck_array_ops import datetime_to_numeric
2629
from .indexes import Index, filter_indexes_from_coords
2730
from .merge import merge_attrs, merge_coordinates_without_align
2831
from .options import OPTIONS, _get_keep_attrs
@@ -1843,36 +1846,100 @@ def where(cond, x, y, keep_attrs=None):
18431846
)
18441847

18451848

1846-
def polyval(coord, coeffs, degree_dim="degree"):
1849+
@overload
1850+
def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
1851+
...
1852+
1853+
1854+
@overload
1855+
def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
1856+
...
1857+
1858+
1859+
@overload
1860+
def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset:
1861+
...
1862+
1863+
1864+
def polyval(
1865+
coord: T_Xarray, coeffs: T_Xarray, degree_dim: Hashable = "degree"
1866+
) -> T_Xarray:
18471867
"""Evaluate a polynomial at specific values
18481868
18491869
Parameters
18501870
----------
1851-
coord : DataArray
1852-
The 1D coordinate along which to evaluate the polynomial.
1853-
coeffs : DataArray
1854-
Coefficients of the polynomials.
1855-
degree_dim : str, default: "degree"
1871+
coord : DataArray or Dataset
1872+
Values at which to evaluate the polynomial.
1873+
coeffs : DataArray or Dataset
1874+
Coefficients of the polynomial.
1875+
degree_dim : Hashable, default: "degree"
18561876
Name of the polynomial degree dimension in `coeffs`.
18571877
1878+
Returns
1879+
-------
1880+
DataArray or Dataset
1881+
Evaluated polynomial.
1882+
18581883
See Also
18591884
--------
18601885
xarray.DataArray.polyfit
1861-
numpy.polyval
1886+
numpy.polynomial.polynomial.polyval
18621887
"""
1863-
from .dataarray import DataArray
1864-
from .missing import get_clean_interp_index
18651888

1866-
x = get_clean_interp_index(coord, coord.name, strict=False)
1889+
if degree_dim not in coeffs._indexes:
1890+
raise ValueError(
1891+
f"Dimension `{degree_dim}` should be a coordinate variable with labels."
1892+
)
1893+
if not np.issubdtype(coeffs[degree_dim].dtype, int):
1894+
raise ValueError(
1895+
f"Dimension `{degree_dim}` should be of integer dtype. Received {coeffs[degree_dim].dtype} instead."
1896+
)
1897+
max_deg = coeffs[degree_dim].max().item()
1898+
coeffs = coeffs.reindex(
1899+
{degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False
1900+
)
1901+
coord = _ensure_numeric(coord)
1902+
1903+
# using Horner's method
1904+
# https://en.wikipedia.org/wiki/Horner%27s_method
1905+
res = coeffs.isel({degree_dim: max_deg}, drop=True) + zeros_like(coord)
1906+
for deg in range(max_deg - 1, -1, -1):
1907+
res *= coord
1908+
res += coeffs.isel({degree_dim: deg}, drop=True)
18671909

1868-
deg_coord = coeffs[degree_dim]
1910+
return res
18691911

1870-
lhs = DataArray(
1871-
np.vander(x, int(deg_coord.max()) + 1),
1872-
dims=(coord.name, degree_dim),
1873-
coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
1874-
)
1875-
return (lhs * coeffs).sum(degree_dim)
1912+
1913+
def _ensure_numeric(data: T_Xarray) -> T_Xarray:
1914+
"""Converts all datetime64 variables to float64
1915+
1916+
Parameters
1917+
----------
1918+
data : DataArray or Dataset
1919+
Variables with possible datetime dtypes.
1920+
1921+
Returns
1922+
-------
1923+
DataArray or Dataset
1924+
Variables with datetime64 dtypes converted to float64.
1925+
"""
1926+
from .dataset import Dataset
1927+
1928+
def to_floatable(x: DataArray) -> DataArray:
1929+
if x.dtype.kind in "mM":
1930+
return x.copy(
1931+
data=datetime_to_numeric(
1932+
x.data,
1933+
offset=np.datetime64("1970-01-01"),
1934+
datetime_unit="ns",
1935+
),
1936+
)
1937+
return x
1938+
1939+
if isinstance(data, Dataset):
1940+
return data.map(to_floatable)
1941+
else:
1942+
return to_floatable(data)
18761943

18771944

18781945
def _calc_idxminmax(

xarray/tests/test_computation.py

Lines changed: 91 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,37 +1933,100 @@ def test_where_attrs() -> None:
19331933
assert actual.attrs == {}
19341934

19351935

1936-
@pytest.mark.parametrize("use_dask", [True, False])
1937-
@pytest.mark.parametrize("use_datetime", [True, False])
1938-
def test_polyval(use_dask, use_datetime) -> None:
1939-
if use_dask and not has_dask:
1940-
pytest.skip("requires dask")
1941-
1942-
if use_datetime:
1943-
xcoord = xr.DataArray(
1944-
pd.date_range("2000-01-01", freq="D", periods=10), dims=("x",), name="x"
1945-
)
1946-
x = xr.core.missing.get_clean_interp_index(xcoord, "x")
1947-
else:
1948-
x = np.arange(10)
1949-
xcoord = xr.DataArray(x, dims=("x",), name="x")
1950-
1951-
da = xr.DataArray(
1952-
np.stack((1.0 + x + 2.0 * x**2, 1.0 + 2.0 * x + 3.0 * x**2)),
1953-
dims=("d", "x"),
1954-
coords={"x": xcoord, "d": [0, 1]},
1955-
)
1956-
coeffs = xr.DataArray(
1957-
[[2, 1, 1], [3, 2, 1]],
1958-
dims=("d", "degree"),
1959-
coords={"d": [0, 1], "degree": [2, 1, 0]},
1960-
)
1936+
@pytest.mark.parametrize("use_dask", [False, True])
1937+
@pytest.mark.parametrize(
1938+
["x", "coeffs", "expected"],
1939+
[
1940+
pytest.param(
1941+
xr.DataArray([1, 2, 3], dims="x"),
1942+
xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}),
1943+
xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"),
1944+
id="simple",
1945+
),
1946+
pytest.param(
1947+
xr.DataArray([1, 2, 3], dims="x"),
1948+
xr.DataArray(
1949+
[[0, 1], [0, 1]], dims=("y", "degree"), coords={"degree": [0, 1]}
1950+
),
1951+
xr.DataArray([[1, 2, 3], [1, 2, 3]], dims=("y", "x")),
1952+
id="broadcast-x",
1953+
),
1954+
pytest.param(
1955+
xr.DataArray([1, 2, 3], dims="x"),
1956+
xr.DataArray(
1957+
[[0, 1], [1, 0], [1, 1]],
1958+
dims=("x", "degree"),
1959+
coords={"degree": [0, 1]},
1960+
),
1961+
xr.DataArray([1, 1, 1 + 3], dims="x"),
1962+
id="shared-dim",
1963+
),
1964+
pytest.param(
1965+
xr.DataArray([1, 2, 3], dims="x"),
1966+
xr.DataArray([1, 0, 0], dims="degree", coords={"degree": [2, 1, 0]}),
1967+
xr.DataArray([1, 2**2, 3**2], dims="x"),
1968+
id="reordered-index",
1969+
),
1970+
pytest.param(
1971+
xr.DataArray([1, 2, 3], dims="x"),
1972+
xr.DataArray([5], dims="degree", coords={"degree": [3]}),
1973+
xr.DataArray([5, 5 * 2**3, 5 * 3**3], dims="x"),
1974+
id="sparse-index",
1975+
),
1976+
pytest.param(
1977+
xr.DataArray([1, 2, 3], dims="x"),
1978+
xr.Dataset(
1979+
{"a": ("degree", [0, 1]), "b": ("degree", [1, 0])},
1980+
coords={"degree": [0, 1]},
1981+
),
1982+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [1, 1, 1])}),
1983+
id="array-dataset",
1984+
),
1985+
pytest.param(
1986+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}),
1987+
xr.DataArray([1, 1], dims="degree", coords={"degree": [0, 1]}),
1988+
xr.Dataset({"a": ("x", [2, 3, 4]), "b": ("x", [3, 4, 5])}),
1989+
id="dataset-array",
1990+
),
1991+
pytest.param(
1992+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [2, 3, 4])}),
1993+
xr.Dataset(
1994+
{"a": ("degree", [0, 1]), "b": ("degree", [1, 1])},
1995+
coords={"degree": [0, 1]},
1996+
),
1997+
xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [3, 4, 5])}),
1998+
id="dataset-dataset",
1999+
),
2000+
pytest.param(
2001+
xr.DataArray(pd.date_range("1970-01-01", freq="s", periods=3), dims="x"),
2002+
xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}),
2003+
xr.DataArray(
2004+
[0, 1e9, 2e9],
2005+
dims="x",
2006+
coords={"x": pd.date_range("1970-01-01", freq="s", periods=3)},
2007+
),
2008+
id="datetime",
2009+
),
2010+
],
2011+
)
2012+
def test_polyval(use_dask, x, coeffs, expected) -> None:
19612013
if use_dask:
1962-
coeffs = coeffs.chunk({"d": 2})
2014+
if not has_dask:
2015+
pytest.skip("requires dask")
2016+
coeffs = coeffs.chunk({"degree": 2})
2017+
x = x.chunk({"x": 2})
2018+
with raise_if_dask_computes():
2019+
actual = xr.polyval(x, coeffs)
2020+
xr.testing.assert_allclose(actual, expected)
19632021

1964-
da_pv = xr.polyval(da.x, coeffs)
19652022

1966-
xr.testing.assert_allclose(da, da_pv.T)
2023+
def test_polyval_degree_dim_checks():
2024+
x = (xr.DataArray([1, 2, 3], dims="x"),)
2025+
coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]})
2026+
with pytest.raises(ValueError):
2027+
xr.polyval(x, coeffs.drop_vars("degree"))
2028+
with pytest.raises(ValueError):
2029+
xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float)))
19672030

19682031

19692032
@pytest.mark.parametrize("use_dask", [False, True])

0 commit comments

Comments
 (0)