Skip to content

Commit 0c08145

Browse files
committed
BUG: incorrect broadcasting that could cause dtype coercion in a groupby-transform
closes #14457
1 parent 794f792 commit 0c08145

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

doc/source/whatsnew/v0.19.1.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Bug Fixes
3838
- Bug in localizing an ambiguous timezone when a boolean is passed (:issue:`14402`)
3939

4040

41-
41+
- Bug in groupby-transform broadcasting that could cause incorrect dtype coercion (:issue:`14457`)
4242

4343

4444

pandas/core/groupby.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3454,7 +3454,6 @@ def _transform_general(self, func, *args, **kwargs):
34543454
from pandas.tools.merge import concat
34553455

34563456
applied = []
3457-
34583457
obj = self._obj_with_exclusions
34593458
gen = self.grouper.get_iterator(obj, axis=self.axis)
34603459
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
@@ -3475,14 +3474,22 @@ def _transform_general(self, func, *args, **kwargs):
34753474
else:
34763475
res = path(group)
34773476

3478-
# broadcasting
34793477
if isinstance(res, Series):
3478+
3479+
# we need to broadcast across the
3480+
# other dimension; this will preserve dtypes
3481+
# GH14457
34803482
if res.index.is_(obj.index):
3481-
group.T.values[:] = res
3483+
r = concat([res] * len(group.columns), axis=1)
3484+
r.columns = group.columns
3485+
r.index = group.index
34823486
else:
3483-
group.values[:] = res
3487+
r = DataFrame(
3488+
np.concatenate([res.values] * len(group.index)
3489+
).reshape(group.shape),
3490+
columns=group.columns, index=group.index)
34843491

3485-
applied.append(group)
3492+
applied.append(r)
34863493
else:
34873494
applied.append(res)
34883495

pandas/tests/test_groupby.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,18 @@ def nsum(x):
13361336
for result in results:
13371337
assert_series_equal(result, expected, check_names=False)
13381338

1339+
def test_transform_coercion(self):
1340+
1341+
# 14457
1342+
# when we are transforming be sure to not coerce
1343+
# via assignment
1344+
df = pd.DataFrame(dict(A=['a', 'a'], B=[0, 1]))
1345+
g = df.groupby('A')
1346+
1347+
expected = g.transform(np.mean)
1348+
result = g.transform(lambda x: np.mean(x))
1349+
assert_frame_equal(result, expected)
1350+
13391351
def test_with_na(self):
13401352
index = Index(np.arange(10))
13411353

0 commit comments

Comments
 (0)