Skip to content

Commit 54f9b03

Browse files
authored
BUG/REF: unstack with EA dtypes (#33356)
1 parent 496c982 commit 54f9b03

File tree

5 files changed

+61
-33
lines changed

5 files changed

+61
-33
lines changed

pandas/core/frame.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,16 @@ def _is_homogeneous_type(self) -> bool:
600600
else:
601601
return not self._mgr.is_mixed_type
602602

603+
@property
604+
def _can_fast_transpose(self) -> bool:
605+
"""
606+
Can we transpose this DataFrame without creating any new array objects.
607+
"""
608+
if self._data.any_extension_types:
609+
# TODO(EA2D) special case would be unnecessary with 2D EAs
610+
return False
611+
return len(self._data.blocks) == 1
612+
603613
# ----------------------------------------------------------------------
604614
# Rendering Methods
605615

pandas/core/reshape/reshape.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import pandas.core.algorithms as algos
2525
from pandas.core.arrays import SparseArray
2626
from pandas.core.arrays.categorical import factorize_from_iterable
27-
from pandas.core.construction import extract_array
2827
from pandas.core.frame import DataFrame
2928
from pandas.core.indexes.api import Index, MultiIndex
3029
from pandas.core.series import Series
@@ -413,7 +412,7 @@ def unstack(obj, level, fill_value=None):
413412
level = obj.index._get_level_number(level)
414413

415414
if isinstance(obj, DataFrame):
416-
if isinstance(obj.index, MultiIndex):
415+
if isinstance(obj.index, MultiIndex) or not obj._can_fast_transpose:
417416
return _unstack_frame(obj, level, fill_value=fill_value)
418417
else:
419418
return obj.T.stack(dropna=False)
@@ -429,14 +428,14 @@ def unstack(obj, level, fill_value=None):
429428

430429

431430
def _unstack_frame(obj, level, fill_value=None):
432-
if obj._is_mixed_type:
431+
if not obj._can_fast_transpose:
433432
unstacker = _Unstacker(obj.index, level=level)
434-
blocks = obj._mgr.unstack(unstacker, fill_value=fill_value)
435-
return obj._constructor(blocks)
433+
mgr = obj._mgr.unstack(unstacker, fill_value=fill_value)
434+
return obj._constructor(mgr)
436435
else:
437436
return _Unstacker(
438437
obj.index, level=level, constructor=obj._constructor,
439-
).get_result(obj.values, value_columns=obj.columns, fill_value=fill_value)
438+
).get_result(obj._values, value_columns=obj.columns, fill_value=fill_value)
440439

441440

442441
def _unstack_extension_series(series, level, fill_value):
@@ -462,31 +461,10 @@ def _unstack_extension_series(series, level, fill_value):
462461
Each column of the DataFrame will have the same dtype as
463462
the input Series.
464463
"""
465-
# Implementation note: the basic idea is to
466-
# 1. Do a regular unstack on a dummy array of integers
467-
# 2. Followup with a columnwise take.
468-
# We use the dummy take to discover newly-created missing values
469-
# introduced by the reshape.
470-
from pandas.core.reshape.concat import concat
471-
472-
dummy_arr = np.arange(len(series))
473-
# fill_value=-1, since we will do a series.values.take later
474-
result = _Unstacker(series.index, level=level).get_result(
475-
dummy_arr, value_columns=None, fill_value=-1
476-
)
477-
478-
out = []
479-
values = extract_array(series, extract_numpy=False)
480-
481-
for col, indices in result.items():
482-
out.append(
483-
Series(
484-
values.take(indices.values, allow_fill=True, fill_value=fill_value),
485-
name=col,
486-
index=result.index,
487-
)
488-
)
489-
return concat(out, axis="columns", copy=False, keys=result.columns)
464+
# Defer to the logic in ExtensionBlock._unstack
465+
df = series.to_frame()
466+
result = df.unstack(level=level, fill_value=fill_value)
467+
return result.droplevel(level=0, axis=1)
490468

491469

492470
def stack(frame, level=-1, dropna=True):

pandas/tests/extension/base/casting.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,22 @@ class BaseCastingTests(BaseExtensionTests):
1010
"""Casting to and from ExtensionDtypes"""
1111

1212
def test_astype_object_series(self, all_data):
13-
ser = pd.Series({"A": all_data})
13+
ser = pd.Series(all_data, name="A")
1414
result = ser.astype(object)
1515
assert isinstance(result._mgr.blocks[0], ObjectBlock)
1616

17+
def test_astype_object_frame(self, all_data):
18+
df = pd.DataFrame({"A": all_data})
19+
20+
result = df.astype(object)
21+
blk = result._data.blocks[0]
22+
assert isinstance(blk, ObjectBlock), type(blk)
23+
24+
# FIXME: these currently fail; dont leave commented-out
25+
# check that we can compare the dtypes
26+
# cmp = result.dtypes.equals(df.dtypes)
27+
# assert not cmp.any()
28+
1729
def test_tolist(self, data):
1830
result = pd.Series(data).tolist()
1931
expected = list(data)

pandas/tests/extension/base/reshaping.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,14 @@ def test_unstack(self, data, index, obj):
295295
assert all(
296296
isinstance(result[col].array, type(data)) for col in result.columns
297297
)
298+
299+
if obj == "series":
300+
# We should get the same result with to_frame+unstack+droplevel
301+
df = ser.to_frame()
302+
303+
alt = df.unstack(level=level).droplevel(0, axis=1)
304+
self.assert_frame_equal(result, alt)
305+
298306
expected = ser.astype(object).unstack(level=level)
299307
result = result.astype(object)
300308

pandas/tests/extension/test_sparse.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from pandas.errors import PerformanceWarning
55

6+
from pandas.core.dtypes.common import is_object_dtype
7+
68
import pandas as pd
79
from pandas import SparseDtype
810
import pandas._testing as tm
@@ -309,7 +311,25 @@ def test_searchsorted(self, data_for_sorting, as_series):
309311

310312

311313
class TestCasting(BaseSparseTests, base.BaseCastingTests):
312-
pass
314+
def test_astype_object_series(self, all_data):
315+
# Unlike the base class, we do not expect the resulting Block
316+
# to be ObjectBlock
317+
ser = pd.Series(all_data, name="A")
318+
result = ser.astype(object)
319+
assert is_object_dtype(result._data.blocks[0].dtype)
320+
321+
def test_astype_object_frame(self, all_data):
322+
# Unlike the base class, we do not expect the resulting Block
323+
# to be ObjectBlock
324+
df = pd.DataFrame({"A": all_data})
325+
326+
result = df.astype(object)
327+
assert is_object_dtype(result._data.blocks[0].dtype)
328+
329+
# FIXME: these currently fail; dont leave commented-out
330+
# check that we can compare the dtypes
331+
# comp = result.dtypes.equals(df.dtypes)
332+
# assert not comp.any()
313333

314334

315335
class TestArithmeticOps(BaseSparseTests, base.BaseArithmeticOpsTests):

0 commit comments

Comments
 (0)