Skip to content

Commit 224d2e8

Browse files
authored
REF: de-duplicate ndarray[datetimelike] wrapping (#38129)
1 parent eaa45cf commit 224d2e8

File tree

4 files changed

+35
-51
lines changed

4 files changed

+35
-51
lines changed

pandas/core/arrays/interval.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@
4444
from pandas.core.arrays.base import ExtensionArray, _extension_array_shared_docs
4545
from pandas.core.arrays.categorical import Categorical
4646
import pandas.core.common as com
47-
from pandas.core.construction import array, extract_array
47+
from pandas.core.construction import (
48+
array,
49+
ensure_wrapped_if_datetimelike,
50+
extract_array,
51+
)
4852
from pandas.core.indexers import check_array_indexer
4953
from pandas.core.indexes.base import ensure_index
5054
from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer
@@ -251,11 +255,9 @@ def _simple_new(
251255
raise ValueError(msg)
252256

253257
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
254-
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array
255-
256-
left = maybe_upcast_datetimelike_array(left)
258+
left = ensure_wrapped_if_datetimelike(left)
257259
left = extract_array(left, extract_numpy=True)
258-
right = maybe_upcast_datetimelike_array(right)
260+
right = ensure_wrapped_if_datetimelike(right)
259261
right = extract_array(right, extract_numpy=True)
260262

261263
lbase = getattr(left, "_ndarray", left).base

pandas/core/construction.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,24 @@ def extract_array(obj: object, extract_numpy: bool = False) -> Union[Any, ArrayL
402402
return obj
403403

404404

405+
def ensure_wrapped_if_datetimelike(arr):
406+
"""
407+
Wrap datetime64 and timedelta64 ndarrays in DatetimeArray/TimedeltaArray.
408+
"""
409+
if isinstance(arr, np.ndarray):
410+
if arr.dtype.kind == "M":
411+
from pandas.core.arrays import DatetimeArray
412+
413+
return DatetimeArray._from_sequence(arr)
414+
415+
elif arr.dtype.kind == "m":
416+
from pandas.core.arrays import TimedeltaArray
417+
418+
return TimedeltaArray._from_sequence(arr)
419+
420+
return arr
421+
422+
405423
def sanitize_array(
406424
data,
407425
index: Optional[Index],

pandas/core/dtypes/concat.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pandas.core.arrays import ExtensionArray
2020
from pandas.core.arrays.sparse import SparseArray
21-
from pandas.core.construction import array
21+
from pandas.core.construction import array, ensure_wrapped_if_datetimelike
2222

2323

2424
def _get_dtype_kinds(arrays) -> Set[str]:
@@ -360,12 +360,14 @@ def _concat_datetime(to_concat, axis=0):
360360
-------
361361
a single array, preserving the combined dtypes
362362
"""
363-
to_concat = [_wrap_datetimelike(x) for x in to_concat]
363+
to_concat = [ensure_wrapped_if_datetimelike(x) for x in to_concat]
364+
364365
single_dtype = len({x.dtype for x in to_concat}) == 1
365366

366367
# multiple types, need to coerce to object
367368
if not single_dtype:
368-
# wrap_datetimelike ensures that astype(object) wraps in Timestamp/Timedelta
369+
# ensure_wrapped_if_datetimelike ensures that astype(object) wraps
370+
# in Timestamp/Timedelta
369371
return _concatenate_2d([x.astype(object) for x in to_concat], axis=axis)
370372

371373
if axis == 1:
@@ -379,17 +381,3 @@ def _concat_datetime(to_concat, axis=0):
379381
assert result.shape[0] == 1
380382
result = result[0]
381383
return result
382-
383-
384-
def _wrap_datetimelike(arr):
385-
"""
386-
Wrap datetime64 and timedelta64 ndarrays in DatetimeArray/TimedeltaArray.
387-
388-
DTA/TDA handle .astype(object) correctly.
389-
"""
390-
from pandas.core.construction import array as pd_array, extract_array
391-
392-
arr = extract_array(arr, extract_numpy=True)
393-
if isinstance(arr, np.ndarray) and arr.dtype.kind in ["m", "M"]:
394-
arr = pd_array(arr)
395-
return arr

pandas/core/ops/array_ops.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pandas.core.dtypes.generic import ABCExtensionArray, ABCIndexClass, ABCSeries
3131
from pandas.core.dtypes.missing import isna, notna
3232

33+
from pandas.core.construction import ensure_wrapped_if_datetimelike
3334
from pandas.core.ops import missing
3435
from pandas.core.ops.dispatch import should_extension_dispatch
3536
from pandas.core.ops.invalid import invalid_comparison
@@ -175,8 +176,8 @@ def arithmetic_op(left: ArrayLike, right: Any, op):
175176

176177
# NB: We assume that extract_array has already been called
177178
# on `left` and `right`.
178-
lvalues = maybe_upcast_datetimelike_array(left)
179-
rvalues = maybe_upcast_datetimelike_array(right)
179+
lvalues = ensure_wrapped_if_datetimelike(left)
180+
rvalues = ensure_wrapped_if_datetimelike(right)
180181
rvalues = _maybe_upcast_for_op(rvalues, lvalues.shape)
181182

182183
if should_extension_dispatch(lvalues, rvalues) or isinstance(rvalues, Timedelta):
@@ -206,7 +207,7 @@ def comparison_op(left: ArrayLike, right: Any, op) -> ArrayLike:
206207
ndarray or ExtensionArray
207208
"""
208209
# NB: We assume extract_array has already been called on left and right
209-
lvalues = maybe_upcast_datetimelike_array(left)
210+
lvalues = ensure_wrapped_if_datetimelike(left)
210211
rvalues = right
211212

212213
rvalues = lib.item_from_zerodim(rvalues)
@@ -331,7 +332,7 @@ def fill_bool(x, left=None):
331332
right = construct_1d_object_array_from_listlike(right)
332333

333334
# NB: We assume extract_array has already been called on left and right
334-
lvalues = maybe_upcast_datetimelike_array(left)
335+
lvalues = ensure_wrapped_if_datetimelike(left)
335336
rvalues = right
336337

337338
if should_extension_dispatch(lvalues, rvalues):
@@ -400,31 +401,6 @@ def get_array_op(op):
400401
raise NotImplementedError(op_name)
401402

402403

403-
def maybe_upcast_datetimelike_array(obj: ArrayLike) -> ArrayLike:
404-
"""
405-
If we have an ndarray that is either datetime64 or timedelta64, wrap in EA.
406-
407-
Parameters
408-
----------
409-
obj : ndarray or ExtensionArray
410-
411-
Returns
412-
-------
413-
ndarray or ExtensionArray
414-
"""
415-
if isinstance(obj, np.ndarray):
416-
if obj.dtype.kind == "m":
417-
from pandas.core.arrays import TimedeltaArray
418-
419-
return TimedeltaArray._from_sequence(obj)
420-
if obj.dtype.kind == "M":
421-
from pandas.core.arrays import DatetimeArray
422-
423-
return DatetimeArray._from_sequence(obj)
424-
425-
return obj
426-
427-
428404
def _maybe_upcast_for_op(obj, shape: Shape):
429405
"""
430406
Cast non-pandas objects to pandas types to unify behavior of arithmetic

0 commit comments

Comments
 (0)