Skip to content

Commit e28f171

Browse files
kmuehlbauer, spencerkclark, dcherian
authored
fix mean for datetime-like using the respective time resolution unit (#9977)
* fix mean for datetime-like by using the respective dtype time resolution unit, adapting tests * fix mypy * add PR to existing entry for non-nanosecond datetimes * Update xarray/core/duck_array_ops.py Co-authored-by: Spencer Clark <[email protected]> * cast to "int64" in calculation of datetime-like mean * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Spencer Clark <[email protected]> --------- Co-authored-by: Spencer Clark <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent e432479 commit e28f171

File tree

3 files changed

+51
-38
lines changed

3 files changed

+51
-38
lines changed

doc/whats-new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ eventually be deprecated.
5050

5151
New Features
5252
~~~~~~~~~~~~
53-
- Relax nanosecond datetime restriction in CF time decoding (:issue:`7493`, :pull:`9618`).
53+
- Relax nanosecond datetime restriction in CF time decoding (:issue:`7493`, :pull:`9618`, :pull:`9977`).
5454
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Spencer Clark <https://github.com/spencerkclark>`_.
5555
- Enable the ``compute=False`` option in :py:meth:`DataTree.to_zarr`. (:pull:`9958`).
5656
By `Sam Levang <https://github.com/slevang>`_.

xarray/core/duck_array_ops.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -662,16 +662,10 @@ def _to_pytimedelta(array, unit="us"):
662662

663663

664664
def np_timedelta64_to_float(array, datetime_unit):
665-
"""Convert numpy.timedelta64 to float.
666-
667-
Notes
668-
-----
669-
The array is first converted to microseconds, which is less likely to
670-
cause overflow errors.
671-
"""
672-
array = array.astype("timedelta64[ns]").astype(np.float64)
673-
conversion_factor = np.timedelta64(1, "ns") / np.timedelta64(1, datetime_unit)
674-
return conversion_factor * array
665+
"""Convert numpy.timedelta64 to float, possibly at a loss of resolution."""
666+
unit, _ = np.datetime_data(array.dtype)
667+
conversion_factor = np.timedelta64(1, unit) / np.timedelta64(1, datetime_unit)
668+
return conversion_factor * array.astype(np.float64)
675669

676670

677671
def pd_timedelta_to_float(value, datetime_unit):
@@ -715,12 +709,15 @@ def mean(array, axis=None, skipna=None, **kwargs):
715709
if dtypes.is_datetime_like(array.dtype):
716710
offset = _datetime_nanmin(array)
717711

718-
# xarray always uses np.datetime64[ns] for np.datetime64 data
719-
dtype = "timedelta64[ns]"
712+
# From version 2025.01.2 xarray uses np.datetime64[unit], where unit
713+
# is one of "s", "ms", "us", "ns".
714+
# To not have to worry about the resolution, we just convert the output
715+
# to "timedelta64" (without unit) and let the dtype of offset take precedence.
716+
# This is fully backwards compatible with datetime64[ns].
720717
return (
721718
_mean(
722719
datetime_to_numeric(array, offset), axis=axis, skipna=skipna, **kwargs
723-
).astype(dtype)
720+
).astype("timedelta64")
724721
+ offset
725722
)
726723
elif _contains_cftime_datetimes(array):

xarray/tests/test_duck_array_ops.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from numpy import array, nan
1010

1111
from xarray import DataArray, Dataset, cftime_range, concat
12+
from xarray.coding.times import _NS_PER_TIME_DELTA
1213
from xarray.core import dtypes, duck_array_ops
1314
from xarray.core.duck_array_ops import (
1415
array_notnull_equiv,
@@ -28,6 +29,7 @@
2829
where,
2930
)
3031
from xarray.core.extension_array import PandasExtensionArray
32+
from xarray.core.types import NPDatetimeUnitOptions, PDDatetimeUnitOptions
3133
from xarray.namedarray.pycompat import array_type
3234
from xarray.testing import assert_allclose, assert_equal, assert_identical
3335
from xarray.tests import (
@@ -411,10 +413,11 @@ def assert_dask_array(da, dask):
411413
@arm_xfail
412414
@pytest.mark.filterwarnings("ignore:All-NaN .* encountered:RuntimeWarning")
413415
@pytest.mark.parametrize("dask", [False, True] if has_dask else [False])
414-
def test_datetime_mean(dask: bool) -> None:
416+
def test_datetime_mean(dask: bool, time_unit: PDDatetimeUnitOptions) -> None:
415417
# Note: only testing numpy, as dask is broken upstream
418+
dtype = f"M8[{time_unit}]"
416419
da = DataArray(
417-
np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype="M8[ns]"),
420+
np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype=dtype),
418421
dims=["time"],
419422
)
420423
if dask:
@@ -846,11 +849,11 @@ def test_multiple_dims(dtype, dask, skipna, func):
846849

847850

848851
@pytest.mark.parametrize("dask", [True, False])
849-
def test_datetime_to_numeric_datetime64(dask):
852+
def test_datetime_to_numeric_datetime64(dask, time_unit: PDDatetimeUnitOptions):
850853
if dask and not has_dask:
851854
pytest.skip("requires dask")
852855

853-
times = pd.date_range("2000", periods=5, freq="7D").values
856+
times = pd.date_range("2000", periods=5, freq="7D").as_unit(time_unit).values
854857
if dask:
855858
import dask.array
856859

@@ -874,8 +877,8 @@ def test_datetime_to_numeric_datetime64(dask):
874877
result = duck_array_ops.datetime_to_numeric(
875878
times, datetime_unit="h", dtype=dtype
876879
)
877-
expected = 24 * np.arange(0, 35, 7).astype(dtype)
878-
np.testing.assert_array_equal(result, expected)
880+
expected2 = 24 * np.arange(0, 35, 7).astype(dtype)
881+
np.testing.assert_array_equal(result, expected2)
879882

880883

881884
@requires_cftime
@@ -923,15 +926,18 @@ def test_datetime_to_numeric_cftime(dask):
923926

924927

925928
@requires_cftime
926-
def test_datetime_to_numeric_potential_overflow():
929+
def test_datetime_to_numeric_potential_overflow(time_unit: PDDatetimeUnitOptions):
927930
import cftime
928931

929-
times = pd.date_range("2000", periods=5, freq="7D").values.astype("datetime64[us]")
932+
if time_unit == "ns":
933+
pytest.skip("out-of-bounds datetime64 overflow")
934+
dtype = f"M8[{time_unit}]"
935+
times = pd.date_range("2000", periods=5, freq="7D").values.astype(dtype)
930936
cftimes = cftime_range(
931937
"2000", periods=5, freq="7D", calendar="proleptic_gregorian"
932938
).values
933939

934-
offset = np.datetime64("0001-01-01")
940+
offset = np.datetime64("0001-01-01", time_unit)
935941
cfoffset = cftime.DatetimeProlepticGregorian(1, 1, 1)
936942

937943
result = duck_array_ops.datetime_to_numeric(
@@ -957,35 +963,45 @@ def test_py_timedelta_to_float():
957963
assert py_timedelta_to_float(dt.timedelta(days=1e6), "D") == 1e6
958964

959965

960-
@pytest.mark.parametrize(
961-
"td, expected",
962-
([np.timedelta64(1, "D"), 86400 * 1e9], [np.timedelta64(1, "ns"), 1.0]),
963-
)
964-
def test_np_timedelta64_to_float(td, expected):
965-
out = np_timedelta64_to_float(td, datetime_unit="ns")
966+
@pytest.mark.parametrize("np_dt_unit", ["D", "h", "m", "s", "ms", "us", "ns"])
967+
def test_np_timedelta64_to_float(
968+
np_dt_unit: NPDatetimeUnitOptions, time_unit: PDDatetimeUnitOptions
969+
):
970+
# tests any combination of source np.timedelta64 (NPDatetimeUnitOptions) with
971+
# np_timedelta_to_float with dedicated target unit (PDDatetimeUnitOptions)
972+
td = np.timedelta64(1, np_dt_unit)
973+
expected = _NS_PER_TIME_DELTA[np_dt_unit] / _NS_PER_TIME_DELTA[time_unit]
974+
975+
out = np_timedelta64_to_float(td, datetime_unit=time_unit)
966976
np.testing.assert_allclose(out, expected)
967977
assert isinstance(out, float)
968978

969-
out = np_timedelta64_to_float(np.atleast_1d(td), datetime_unit="ns")
979+
out = np_timedelta64_to_float(np.atleast_1d(td), datetime_unit=time_unit)
970980
np.testing.assert_allclose(out, expected)
971981

972982

973-
@pytest.mark.parametrize(
974-
"td, expected", ([pd.Timedelta(1, "D"), 86400 * 1e9], [pd.Timedelta(1, "ns"), 1.0])
975-
)
976-
def test_pd_timedelta_to_float(td, expected):
977-
out = pd_timedelta_to_float(td, datetime_unit="ns")
983+
@pytest.mark.parametrize("np_dt_unit", ["D", "h", "m", "s", "ms", "us", "ns"])
984+
def test_pd_timedelta_to_float(
985+
np_dt_unit: NPDatetimeUnitOptions, time_unit: PDDatetimeUnitOptions
986+
):
987+
# tests any combination of source pd.Timedelta (NPDatetimeUnitOptions) with
988+
# np_timedelta_to_float with dedicated target unit (PDDatetimeUnitOptions)
989+
td = pd.Timedelta(1, np_dt_unit)
990+
expected = _NS_PER_TIME_DELTA[np_dt_unit] / _NS_PER_TIME_DELTA[time_unit]
991+
992+
out = pd_timedelta_to_float(td, datetime_unit=time_unit)
978993
np.testing.assert_allclose(out, expected)
979994
assert isinstance(out, float)
980995

981996

982997
@pytest.mark.parametrize(
983998
"td", [dt.timedelta(days=1), np.timedelta64(1, "D"), pd.Timedelta(1, "D"), "1 day"]
984999
)
985-
def test_timedelta_to_numeric(td):
1000+
def test_timedelta_to_numeric(td, time_unit: PDDatetimeUnitOptions):
9861001
# Scalar input
987-
out = timedelta_to_numeric(td, "ns")
988-
np.testing.assert_allclose(out, 86400 * 1e9)
1002+
out = timedelta_to_numeric(td, time_unit)
1003+
expected = _NS_PER_TIME_DELTA["D"] / _NS_PER_TIME_DELTA[time_unit]
1004+
np.testing.assert_allclose(out, expected)
9891005
assert isinstance(out, float)
9901006

9911007

0 commit comments

Comments
 (0)