Skip to content

Commit b379772

Browse files
committed
BUG: fix unsafe dtype changes in putmasking on series
allow boolean indexing of series w/o changing the dtype with a list of the rhs if we can preserver the dtype of the input, if we can't upcast, but can only do this in cases where we won't change the itemsize
1 parent 154252a commit b379772

File tree

3 files changed

+52
-13
lines changed

3 files changed

+52
-13
lines changed

pandas/core/common.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,12 +760,27 @@ def _maybe_upcast_putmask(result, mask, other, dtype=None, change=None):
760760

761761
def changeit():
762762

763-
# our type is wrong here, need to upcast
763+
# try to directly set by expanding our array to full
764+
# length of the boolean
765+
om = other[mask]
766+
om_at = om.astype(result.dtype)
767+
if (om == om_at).all():
768+
new_other = result.values.copy()
769+
new_other[mask] = om_at
770+
result[:] = new_other
771+
return result, False
772+
773+
# we are forced to change the dtype of the result as the input isn't compatible
764774
r, fill_value = _maybe_upcast(result, fill_value=other, dtype=dtype, copy=True)
765775
np.putmask(r, mask, other)
766776

767777
# we need to actually change the dtype here
768778
if change is not None:
779+
780+
# if we are trying to do something unsafe
781+
# like put a bigger dtype in a smaller one, use the smaller one
782+
if change.dtype.itemsize < r.dtype.itemsize:
783+
raise Exception("cannot change dtype of input to smaller size")
769784
change.dtype = r.dtype
770785
change[:] = r
771786

pandas/core/series.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,13 @@ def where(self, cond, other=nan, inplace=False):
739739
if isinstance(other, Series):
740740
other = other.reindex(ser.index)
741741
elif isinstance(other, (tuple,list)):
742-
other = np.array(other)
742+
743+
# try to set the same dtype as ourselves
744+
new_other = np.array(other,dtype=self.dtype)
745+
if not (new_other == np.array(other)).all():
746+
other = np.array(other)
747+
else:
748+
other = new_other
743749

744750
if len(other) != len(ser):
745751
icond = ~cond

pandas/tests/test_series.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,26 +1127,44 @@ def test_where(self):
11271127
self.assertRaises(ValueError, s.__setitem__, tuple([[[True, False]]]), [0,2,3])
11281128
self.assertRaises(ValueError, s.__setitem__, tuple([[[True, False]]]), [])
11291129

1130-
1131-
s = Series(np.arange(10), dtype=np.int32)
1132-
mask = s < 5
1133-
s[mask] = range(5)
1134-
expected = Series(np.arange(10), dtype=np.int32)
1135-
assert_series_equal(s, expected)
1136-
self.assertEquals(s.dtype, expected.dtype)
1130+
# unsafe dtype changes
1131+
for dtype in [ np.int8, np.int16, np.int32, np.int64, np.float16, np.float32, np.float64 ]:
1132+
s = Series(np.arange(10), dtype=dtype)
1133+
mask = s < 5
1134+
s[mask] = range(2,7)
1135+
expected = Series(range(2,7) + range(5,10), dtype=dtype)
1136+
assert_series_equal(s, expected)
1137+
self.assertEquals(s.dtype, expected.dtype)
1138+
1139+
# these are allowed operations, but are upcasted
1140+
for dtype in [ np.int64, np.float64 ]:
1141+
s = Series(np.arange(10), dtype=dtype)
1142+
mask = s < 5
1143+
values = [2.5,3.5,4.5,5.5,6.5]
1144+
s[mask] = values
1145+
expected = Series(values + range(5,10), dtype='float64')
1146+
assert_series_equal(s, expected)
1147+
self.assertEquals(s.dtype, expected.dtype)
1148+
1149+
# can't do these as we are forced to change the itemsize of the input to something we cannot
1150+
for dtype in [ np.int8, np.int16, np.int32, np.float16, np.float32 ]:
1151+
s = Series(np.arange(10), dtype=dtype)
1152+
mask = s < 5
1153+
values = [2.5,3.5,4.5,5.5,6.5]
1154+
self.assertRaises(Exception, s.__setitem__, tuple(mask), values)
11371155

11381156
# GH3235
11391157
s = Series(np.arange(10))
11401158
mask = s < 5
1141-
s[mask] = range(5)
1142-
expected = Series(np.arange(10))
1159+
s[mask] = range(2,7)
1160+
expected = Series(range(2,7) + range(5,10))
11431161
assert_series_equal(s, expected)
11441162
self.assertEquals(s.dtype, expected.dtype)
11451163

11461164
s = Series(np.arange(10))
11471165
mask = s > 5
11481166
s[mask] = [0]*4
1149-
expected = Series([0,1,2,3,4,5] + [0]*4,dtype='float64')
1167+
expected = Series([0,1,2,3,4,5] + [0]*4)
11501168
assert_series_equal(s,expected)
11511169

11521170
s = Series(np.arange(10))
@@ -3174,7 +3192,7 @@ def test_cast_on_putmask(self):
31743192
# need to upcast
31753193
s = Series([1,2],index=[1,2],dtype='int64')
31763194
s[[True, False]] = Series([0],index=[1],dtype='int64')
3177-
expected = Series([0,2],index=[1,2],dtype='float64')
3195+
expected = Series([0,2],index=[1,2],dtype='int64')
31783196

31793197
assert_series_equal(s, expected)
31803198

0 commit comments

Comments
 (0)