Skip to content

Commit e5266af

Browse files
authored
REF: simplify maybe_upcast_putmask (#38487)
1 parent 8488ee2 commit e5266af

File tree

3 files changed

+19
-86
lines changed

3 files changed

+19
-86
lines changed

pandas/core/dtypes/cast.py

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,7 @@ def maybe_cast_to_extension_array(
418418
return result
419419

420420

421-
def maybe_upcast_putmask(
422-
result: np.ndarray, mask: np.ndarray, other: Scalar
423-
) -> Tuple[np.ndarray, bool]:
421+
def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray) -> np.ndarray:
424422
"""
425423
A safe version of putmask that potentially upcasts the result.
426424
@@ -434,69 +432,38 @@ def maybe_upcast_putmask(
434432
The destination array. This will be mutated in-place if no upcasting is
435433
necessary.
436434
mask : boolean ndarray
437-
other : scalar
438-
The source value.
439435
440436
Returns
441437
-------
442438
result : ndarray
443-
changed : bool
444-
Set to true if the result array was upcasted.
445439
446440
Examples
447441
--------
448442
>>> arr = np.arange(1, 6)
449443
>>> mask = np.array([False, True, False, True, True])
450-
>>> result, _ = maybe_upcast_putmask(arr, mask, False)
444+
>>> result = maybe_upcast_putmask(arr, mask)
451445
>>> result
452-
array([1, 0, 3, 0, 0])
446+
array([ 1., nan, 3., nan, nan])
453447
"""
454448
if not isinstance(result, np.ndarray):
455449
raise ValueError("The result input must be a ndarray.")
456-
if not is_scalar(other):
457-
# We _could_ support non-scalar other, but until we have a compelling
458-
# use case, we assume away the possibility.
459-
raise ValueError("other must be a scalar")
450+
451+
# NB: we never get here with result.dtype.kind in ["m", "M"]
460452

461453
if mask.any():
462-
# Two conversions for date-like dtypes that can't be done automatically
463-
# in np.place:
464-
# NaN -> NaT
465-
# integer or integer array -> date-like array
466-
if result.dtype.kind in ["m", "M"]:
467-
if isna(other):
468-
other = result.dtype.type("nat")
469-
elif is_integer(other):
470-
other = np.array(other, dtype=result.dtype)
471-
472-
def changeit():
473-
# we are forced to change the dtype of the result as the input
474-
# isn't compatible
475-
r, _ = maybe_upcast(result, fill_value=other, copy=True)
476-
np.place(r, mask, other)
477-
478-
return r, True
479454

480455
# we want to decide whether place will work
481456
# if we have nans in the False portion of our mask then we need to
482457
# upcast (possibly), otherwise we DON'T want to upcast (e.g. if we
483458
# have values, say integers, in the success portion then it's ok to not
484459
# upcast)
485-
new_dtype, _ = maybe_promote(result.dtype, other)
460+
new_dtype, _ = maybe_promote(result.dtype, np.nan)
486461
if new_dtype != result.dtype:
462+
result = result.astype(new_dtype, copy=True)
487463

488-
# we have a scalar or len 0 ndarray
489-
# and its nan and we are changing some values
490-
if isna(other):
491-
return changeit()
492-
493-
try:
494-
np.place(result, mask, other)
495-
except TypeError:
496-
# e.g. int-dtype result and float-dtype other
497-
return changeit()
464+
np.place(result, mask, np.nan)
498465

499-
return result, False
466+
return result
500467

501468

502469
def maybe_promote(dtype, fill_value=np.nan):

pandas/core/ops/array_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _masked_arith_op(x: np.ndarray, y, op):
110110
with np.errstate(all="ignore"):
111111
result[mask] = op(xrav[mask], y)
112112

113-
result, _ = maybe_upcast_putmask(result, ~mask, np.nan)
113+
result = maybe_upcast_putmask(result, ~mask)
114114
result = result.reshape(x.shape) # 2D compat
115115
return result
116116

pandas/tests/dtypes/cast/test_upcast.py

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,61 +11,27 @@
1111
def test_upcast_error(result):
1212
# GH23823 require result arg to be ndarray
1313
mask = np.array([False, True, False])
14-
other = np.array([61, 62, 63])
1514
with pytest.raises(ValueError, match="The result input must be a ndarray"):
16-
result, _ = maybe_upcast_putmask(result, mask, other)
17-
18-
19-
@pytest.mark.parametrize(
20-
"arr, other",
21-
[
22-
(np.arange(1, 6), np.array([61, 62, 63])),
23-
(np.arange(1, 6), np.array([61.1, 62.2, 63.3])),
24-
(np.arange(10, 15), np.array([61, 62])),
25-
(np.arange(10, 15), np.array([61, np.nan])),
26-
(
27-
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
28-
np.arange("2018-01-01", "2018-01-04", dtype="datetime64[D]"),
29-
),
30-
(
31-
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
32-
np.arange("2018-01-01", "2018-01-03", dtype="datetime64[D]"),
33-
),
34-
],
35-
)
36-
def test_upcast_scalar_other(arr, other):
37-
# for now we do not support non-scalar `other`
38-
mask = np.array([False, True, False, True, True])
39-
with pytest.raises(ValueError, match="other must be a scalar"):
40-
maybe_upcast_putmask(arr, mask, other)
15+
result = maybe_upcast_putmask(result, mask)
4116

4217

4318
def test_upcast():
4419
# GH23823
4520
arr = np.arange(1, 6)
4621
mask = np.array([False, True, False, True, True])
47-
result, changed = maybe_upcast_putmask(arr, mask, other=np.nan)
22+
result = maybe_upcast_putmask(arr, mask)
4823

4924
expected = np.array([1, np.nan, 3, np.nan, np.nan])
50-
assert changed
5125
tm.assert_numpy_array_equal(result, expected)
5226

5327

54-
def test_upcast_datetime():
55-
# GH23823
56-
arr = np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]")
28+
def test_maybe_upcast_putmask_bool():
29+
# a case where maybe_upcast_putmask is *not* equivalent to
30+
# try: np.putmask(result, mask, np.nan)
31+
# except (ValueError, TypeError): result = np.where(mask, result, np.nan)
32+
arr = np.array([True, False, True, False, True], dtype=bool)
5733
mask = np.array([False, True, False, True, True])
58-
result, changed = maybe_upcast_putmask(arr, mask, other=np.nan)
34+
result = maybe_upcast_putmask(arr, mask)
5935

60-
expected = np.array(
61-
[
62-
"2019-01-01",
63-
np.datetime64("NaT"),
64-
"2019-01-03",
65-
np.datetime64("NaT"),
66-
np.datetime64("NaT"),
67-
],
68-
dtype="datetime64[D]",
69-
)
70-
assert not changed
36+
expected = np.array([True, np.nan, True, np.nan, np.nan], dtype=object)
7137
tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)