Skip to content

Commit 87b8a6e

Browse files
authored
TST: enable 2D tests for Categorical (#44206)
1 parent 31628b3 commit 87b8a6e

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

pandas/core/arrays/categorical.py

+27-30
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from shutil import get_terminal_size
77
from typing import (
88
TYPE_CHECKING,
9-
Any,
109
Hashable,
1110
Sequence,
1211
TypeVar,
@@ -38,10 +37,6 @@
3837
Dtype,
3938
NpDtype,
4039
Ordered,
41-
PositionalIndexer2D,
42-
PositionalIndexerTuple,
43-
ScalarIndexer,
44-
SequenceIndexer,
4540
Shape,
4641
npt,
4742
type_t,
@@ -102,7 +97,10 @@
10297
take_nd,
10398
unique1d,
10499
)
105-
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
100+
from pandas.core.arrays._mixins import (
101+
NDArrayBackedExtensionArray,
102+
ravel_compat,
103+
)
106104
from pandas.core.base import (
107105
ExtensionArray,
108106
NoNewAttributesMixin,
@@ -113,7 +111,6 @@
113111
extract_array,
114112
sanitize_array,
115113
)
116-
from pandas.core.indexers import deprecate_ndim_indexing
117114
from pandas.core.ops.common import unpack_zerodim_and_defer
118115
from pandas.core.sorting import nargsort
119116
from pandas.core.strings.object_array import ObjectStringArrayMixin
@@ -1484,6 +1481,7 @@ def _validate_scalar(self, fill_value):
14841481

14851482
# -------------------------------------------------------------
14861483

1484+
@ravel_compat
14871485
def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
14881486
"""
14891487
The numpy array interface.
@@ -1934,7 +1932,10 @@ def __iter__(self):
19341932
"""
19351933
Returns an Iterator over the values of this Categorical.
19361934
"""
1937-
return iter(self._internal_get_values().tolist())
1935+
if self.ndim == 1:
1936+
return iter(self._internal_get_values().tolist())
1937+
else:
1938+
return (self[n] for n in range(len(self)))
19381939

19391940
def __contains__(self, key) -> bool:
19401941
"""
@@ -2053,27 +2054,6 @@ def __repr__(self) -> str:
20532054

20542055
# ------------------------------------------------------------------
20552056

2056-
@overload
2057-
def __getitem__(self, key: ScalarIndexer) -> Any:
2058-
...
2059-
2060-
@overload
2061-
def __getitem__(
2062-
self: CategoricalT,
2063-
key: SequenceIndexer | PositionalIndexerTuple,
2064-
) -> CategoricalT:
2065-
...
2066-
2067-
def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any:
2068-
"""
2069-
Return an item.
2070-
"""
2071-
result = super().__getitem__(key)
2072-
if getattr(result, "ndim", 0) > 1:
2073-
result = result._ndarray
2074-
deprecate_ndim_indexing(result)
2075-
return result
2076-
20772057
def _validate_listlike(self, value):
20782058
# NB: here we assume scalar-like tuples have already been excluded
20792059
value = extract_array(value, extract_numpy=True)
@@ -2311,7 +2291,19 @@ def _concat_same_type(
23112291
) -> CategoricalT:
23122292
from pandas.core.dtypes.concat import union_categoricals
23132293

2314-
return union_categoricals(to_concat)
2294+
result = union_categoricals(to_concat)
2295+
2296+
# in case we are concatenating along axis != 0, we need to reshape
2297+
# the result from union_categoricals
2298+
first = to_concat[0]
2299+
if axis >= first.ndim:
2300+
raise ValueError
2301+
if axis == 1:
2302+
if not all(len(x) == len(first) for x in to_concat):
2303+
raise ValueError
2304+
# TODO: Will this get contiguity wrong?
2305+
result = result.reshape(-1, len(to_concat), order="F")
2306+
return result
23152307

23162308
# ------------------------------------------------------------------
23172309

@@ -2699,6 +2691,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
26992691
"""
27002692
dtype_equal = is_dtype_equal(values.dtype, categories.dtype)
27012693

2694+
if values.ndim > 1:
2695+
flat = values.ravel()
2696+
codes = _get_codes_for_values(flat, categories)
2697+
return codes.reshape(values.shape)
2698+
27022699
if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values):
27032700
# Support inferring the correct extension dtype from an array of
27042701
# scalar objects. e.g.

pandas/tests/extension/test_categorical.py

+11
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,14 @@ def test_not_equal_with_na(self, categories):
303303

304304
class TestParsing(base.BaseParsingTests):
305305
pass
306+
307+
308+
class Test2DCompat(base.Dim2CompatTests):
309+
def test_repr_2d(self, data):
310+
# Categorical __repr__ doesn't include "Categorical", so we need
311+
# to special-case
312+
res = repr(data.reshape(1, -1))
313+
assert res.count("\nCategories") == 1
314+
315+
res = repr(data.reshape(-1, 1))
316+
assert res.count("\nCategories") == 1

0 commit comments

Comments
 (0)