Skip to content

Commit cc63484

Browse files
BUG: fix concat of Sparse with non-sparse dtypes (#34338)
1 parent 6e69ca4 commit cc63484

File tree

5 files changed

+54
-2
lines changed

5 files changed

+54
-2
lines changed

pandas/core/arrays/sparse/array.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,9 @@ def astype(self, dtype=None, copy=True):
10631063
"""
10641064
dtype = self.dtype.update_dtype(dtype)
10651065
subtype = dtype._subtype_with_str
1066-
sp_values = astype_nansafe(self.sp_values, subtype, copy=copy)
1066+
# TODO copy=False is broken for astype_nansafe with int -> float, so cannot
1067+
# passthrough copy keyword: https://github.com/pandas-dev/pandas/issues/34456
1068+
sp_values = astype_nansafe(self.sp_values, subtype, copy=True)
10671069
if sp_values is self.sp_values and copy:
10681070
sp_values = sp_values.copy()
10691071

pandas/core/arrays/sparse/dtype.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,13 @@ def _subtype_with_str(self):
360360
return self.subtype
361361

362362
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
363+
# TODO for now only handle SparseDtypes and numpy dtypes => extend
364+
# with other compatibtle extension dtypes
365+
if any(
366+
isinstance(x, ExtensionDtype) and not isinstance(x, SparseDtype)
367+
for x in dtypes
368+
):
369+
return None
363370

364371
fill_values = [x.fill_value for x in dtypes if isinstance(x, SparseDtype)]
365372
fill_value = fill_values[0]
@@ -375,6 +382,5 @@ def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
375382
stacklevel=6,
376383
)
377384

378-
# TODO also handle non-numpy other dtypes
379385
np_dtypes = [x.subtype if isinstance(x, SparseDtype) else x for x in dtypes]
380386
return SparseDtype(np.find_common_type(np_dtypes, []), fill_value=fill_value)

pandas/core/dtypes/concat.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Utility functions related to concat.
33
"""
4+
from typing import cast
45

56
import numpy as np
67

@@ -21,6 +22,7 @@
2122
from pandas.core.dtypes.generic import ABCCategoricalIndex, ABCRangeIndex, ABCSeries
2223

2324
from pandas.core.arrays import ExtensionArray
25+
from pandas.core.arrays.sparse import SparseArray
2426
from pandas.core.construction import array
2527

2628

@@ -81,6 +83,13 @@ def _cast_to_common_type(arr: ArrayLike, dtype: DtypeObj) -> ArrayLike:
8183
except ValueError:
8284
return arr.astype(object, copy=False)
8385

86+
if is_sparse(arr) and not is_sparse(dtype):
87+
# problem case: SparseArray.astype(dtype) doesn't follow the specified
88+
# dtype exactly, but converts this to Sparse[dtype] -> first manually
89+
# convert to dense array
90+
arr = cast(SparseArray, arr)
91+
return arr.to_dense().astype(dtype, copy=False)
92+
8493
if (
8594
isinstance(arr, np.ndarray)
8695
and arr.dtype.kind in ["m", "M"]

pandas/core/dtypes/dtypes.py

+4
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,8 @@ def _is_boolean(self) -> bool:
642642
return is_bool_dtype(self.categories)
643643

644644
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
645+
from pandas.core.arrays.sparse import SparseDtype
646+
645647
# check if we have all categorical dtype with identical categories
646648
if all(isinstance(x, CategoricalDtype) for x in dtypes):
647649
first = dtypes[0]
@@ -658,6 +660,8 @@ def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
658660
elif any(non_init_cats):
659661
return None
660662

663+
# categorical is aware of Sparse -> extract sparse subdtypes
664+
dtypes = [x.subtype if isinstance(x, SparseDtype) else x for x in dtypes]
661665
# extract the categories' dtype
662666
non_cat_dtypes = [
663667
x.categories.dtype if isinstance(x, CategoricalDtype) else x for x in dtypes

pandas/tests/arrays/sparse/test_combine_concat.py

+31
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
import pandas as pd
45
import pandas._testing as tm
56
from pandas.core.arrays.sparse import SparseArray
67

@@ -29,3 +30,33 @@ def test_uses_first_kind(self, kind):
2930
expected = np.array([1, 2, 1, 2, 2], dtype="int64")
3031
tm.assert_numpy_array_equal(result.sp_values, expected)
3132
assert result.kind == kind
33+
34+
35+
@pytest.mark.parametrize(
36+
"other, expected_dtype",
37+
[
38+
# compatible dtype -> preserve sparse
39+
(pd.Series([3, 4, 5], dtype="int64"), pd.SparseDtype("int64", 0)),
40+
# (pd.Series([3, 4, 5], dtype="Int64"), pd.SparseDtype("int64", 0)),
41+
# incompatible dtype -> Sparse[common dtype]
42+
(pd.Series([1.5, 2.5, 3.5], dtype="float64"), pd.SparseDtype("float64", 0)),
43+
# incompatible dtype -> Sparse[object] dtype
44+
(pd.Series(["a", "b", "c"], dtype=object), pd.SparseDtype(object, 0)),
45+
# categorical with compatible categories -> dtype of the categories
46+
(pd.Series([3, 4, 5], dtype="category"), np.dtype("int64")),
47+
(pd.Series([1.5, 2.5, 3.5], dtype="category"), np.dtype("float64")),
48+
# categorical with incompatible categories -> object dtype
49+
(pd.Series(["a", "b", "c"], dtype="category"), np.dtype(object)),
50+
],
51+
)
52+
def test_concat_with_non_sparse(other, expected_dtype):
53+
# https://github.com/pandas-dev/pandas/issues/34336
54+
s_sparse = pd.Series([1, 0, 2], dtype=pd.SparseDtype("int64", 0))
55+
56+
result = pd.concat([s_sparse, other], ignore_index=True)
57+
expected = pd.Series(list(s_sparse) + list(other)).astype(expected_dtype)
58+
tm.assert_series_equal(result, expected)
59+
60+
result = pd.concat([other, s_sparse], ignore_index=True)
61+
expected = pd.Series(list(other) + list(s_sparse)).astype(expected_dtype)
62+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)