Skip to content

Commit 0308672

Browse files
committed
refactor timedelta decoding to _numbers_to_timedelta and res-use it within decode_cf_timedelta
1 parent 2bbf0ff commit 0308672

File tree

1 file changed

+53
-38
lines changed

1 file changed

+53
-38
lines changed

xarray/coding/times.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -458,41 +458,12 @@ def _decode_datetime_with_pandas(
458458
elif flat_num_dates.dtype.kind in "f":
459459
flat_num_dates = flat_num_dates.astype(np.float64)
460460

461-
# keep NaT/nan mask
462-
nan = np.isnan(flat_num_dates) | (flat_num_dates == np.iinfo(np.int64).min)
463-
464-
# in case we need to change the unit, we fix the numbers here
465-
# this should be safe, as errors would have been raised above
466-
ns_time_unit = _NS_PER_TIME_DELTA[time_unit]
467-
ns_ref_date_unit = _NS_PER_TIME_DELTA[ref_date.unit]
468-
if ns_time_unit > ns_ref_date_unit:
469-
flat_num_dates *= np.int64(ns_time_unit / ns_ref_date_unit)
470-
time_unit = ref_date.unit
471-
472-
# estimate fitting resolution for floating point values
473-
# this iterates until all floats are fractionless or time_unit == "ns"
474-
if flat_num_dates.dtype.kind == "f" and time_unit != "ns":
475-
flat_num_dates, new_time_unit = _check_higher_resolution(
476-
flat_num_dates, time_unit
477-
)
478-
if time_unit != new_time_unit:
479-
msg = (
480-
f"Can't decode floating point datetime to {time_unit!r} without "
481-
f"precision loss, decoding to {new_time_unit!r} instead. "
482-
f"To silence this warning use time_unit={new_time_unit!r} in call to "
483-
f"decoding function."
484-
)
485-
emit_user_level_warning(msg, SerializationWarning)
486-
time_unit = new_time_unit
487-
488-
# Cast input ordinals to integers and properly handle NaN/NaT
489-
# to prevent casting NaN to int
490-
flat_num_dates_int = np.zeros_like(flat_num_dates, dtype=np.int64)
491-
flat_num_dates_int[nan] = np.iinfo(np.int64).min
492-
flat_num_dates_int[~nan] = flat_num_dates[~nan].astype(np.int64)
461+
timedeltas = _numbers_to_timedelta(
462+
flat_num_dates, time_unit, ref_date.unit, "datetime"
463+
)
493464

494-
# cast to timedelta64[time_unit] and add to ref_date
495-
return ref_date + flat_num_dates_int.astype(f"timedelta64[{time_unit}]")
465+
# add timedeltas to ref_date
466+
return ref_date + timedeltas
496467

497468

498469
def decode_cf_datetime(
@@ -590,21 +561,65 @@ def to_datetime_unboxed(value, **kwargs):
590561
return result
591562

592563

564+
def _numbers_to_timedelta(
565+
flat_num: np.ndarray,
566+
time_unit: NPDatetimeUnitOptions,
567+
ref_unit: PDDatetimeUnitOptions,
568+
datatype: str,
569+
) -> np.ndarray:
570+
"""Transform numbers to np.timedelta64."""
571+
# keep NaT/nan mask
572+
nan = np.isnan(flat_num) | (flat_num == np.iinfo(np.int64).min)
573+
574+
# in case we need to change the unit, we fix the numbers here
575+
# this should be safe, as errors would have been raised above
576+
ns_time_unit = _NS_PER_TIME_DELTA[time_unit]
577+
ns_ref_date_unit = _NS_PER_TIME_DELTA[ref_unit]
578+
if ns_time_unit > ns_ref_date_unit:
579+
flat_num *= np.int64(ns_time_unit / ns_ref_date_unit)
580+
time_unit = ref_unit
581+
582+
# estimate fitting resolution for floating point values
583+
# this iterates until all floats are fractionless or time_unit == "ns"
584+
if flat_num.dtype.kind == "f" and time_unit != "ns":
585+
flat_num_dates, new_time_unit = _check_higher_resolution(flat_num, time_unit)
586+
if time_unit != new_time_unit:
587+
msg = (
588+
f"Can't decode floating point {datatype} to {time_unit!r} without "
589+
f"precision loss, decoding to {new_time_unit!r} instead. "
590+
f"To silence this warning use time_unit={new_time_unit!r} in call to "
591+
f"decoding function."
592+
)
593+
emit_user_level_warning(msg, SerializationWarning)
594+
time_unit = new_time_unit
595+
596+
# Cast input ordinals to integers and properly handle NaN/NaT
597+
# to prevent casting NaN to int
598+
with warnings.catch_warnings():
599+
warnings.simplefilter("ignore", RuntimeWarning)
600+
flat_num = flat_num.astype(np.int64)
601+
flat_num[nan] = np.iinfo(np.int64).min
602+
603+
# cast to wanted type
604+
return flat_num.astype(f"timedelta64[{time_unit}]")
605+
606+
593607
def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray:
594608
# todo: check, if this works as intended
595609
"""Given an array of numeric timedeltas in netCDF format, convert it into a
596610
numpy timedelta64 ["s", "ms", "us", "ns"] array.
597611
"""
598612
num_timedeltas = np.asarray(num_timedeltas)
599613
unit = _netcdf_to_numpy_timeunit(units)
614+
615+
timedeltas = _numbers_to_timedelta(num_timedeltas, unit, "s", "timedelta")
616+
600617
as_unit = unit
601618
if unit not in {"s", "ms", "us", "ns"}:
602619
# default to ns, when not specified
603620
as_unit = "ns"
604-
result = (
605-
pd.to_timedelta(ravel(num_timedeltas), unit=unit).as_unit(as_unit).to_numpy()
606-
)
607-
return reshape(result, num_timedeltas.shape)
621+
result = pd.to_timedelta(ravel(timedeltas)).as_unit(as_unit).to_numpy()
622+
return reshape(result, timedeltas.shape)
608623

609624

610625
def _unit_timedelta_cftime(units: str) -> timedelta:

0 commit comments

Comments
 (0)