|
6 | 6 | from shutil import get_terminal_size
|
7 | 7 | from typing import (
|
8 | 8 | TYPE_CHECKING,
|
9 |
| - Any, |
10 | 9 | Hashable,
|
11 | 10 | Sequence,
|
12 | 11 | TypeVar,
|
|
38 | 37 | Dtype,
|
39 | 38 | NpDtype,
|
40 | 39 | Ordered,
|
41 |
| - PositionalIndexer2D, |
42 |
| - PositionalIndexerTuple, |
43 |
| - ScalarIndexer, |
44 |
| - SequenceIndexer, |
45 | 40 | Shape,
|
46 | 41 | npt,
|
47 | 42 | type_t,
|
|
102 | 97 | take_nd,
|
103 | 98 | unique1d,
|
104 | 99 | )
|
105 |
| -from pandas.core.arrays._mixins import NDArrayBackedExtensionArray |
| 100 | +from pandas.core.arrays._mixins import ( |
| 101 | + NDArrayBackedExtensionArray, |
| 102 | + ravel_compat, |
| 103 | +) |
106 | 104 | from pandas.core.base import (
|
107 | 105 | ExtensionArray,
|
108 | 106 | NoNewAttributesMixin,
|
|
113 | 111 | extract_array,
|
114 | 112 | sanitize_array,
|
115 | 113 | )
|
116 |
| -from pandas.core.indexers import deprecate_ndim_indexing |
117 | 114 | from pandas.core.ops.common import unpack_zerodim_and_defer
|
118 | 115 | from pandas.core.sorting import nargsort
|
119 | 116 | from pandas.core.strings.object_array import ObjectStringArrayMixin
|
@@ -1484,6 +1481,7 @@ def _validate_scalar(self, fill_value):
|
1484 | 1481 |
|
1485 | 1482 | # -------------------------------------------------------------
|
1486 | 1483 |
|
| 1484 | + @ravel_compat |
1487 | 1485 | def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
|
1488 | 1486 | """
|
1489 | 1487 | The numpy array interface.
|
@@ -1934,7 +1932,10 @@ def __iter__(self):
|
1934 | 1932 | """
|
1935 | 1933 | Returns an Iterator over the values of this Categorical.
|
1936 | 1934 | """
|
1937 |
| - return iter(self._internal_get_values().tolist()) |
| 1935 | + if self.ndim == 1: |
| 1936 | + return iter(self._internal_get_values().tolist()) |
| 1937 | + else: |
| 1938 | + return (self[n] for n in range(len(self))) |
1938 | 1939 |
|
1939 | 1940 | def __contains__(self, key) -> bool:
|
1940 | 1941 | """
|
@@ -2053,27 +2054,6 @@ def __repr__(self) -> str:
|
2053 | 2054 |
|
2054 | 2055 | # ------------------------------------------------------------------
|
2055 | 2056 |
|
2056 |
| - @overload |
2057 |
| - def __getitem__(self, key: ScalarIndexer) -> Any: |
2058 |
| - ... |
2059 |
| - |
2060 |
| - @overload |
2061 |
| - def __getitem__( |
2062 |
| - self: CategoricalT, |
2063 |
| - key: SequenceIndexer | PositionalIndexerTuple, |
2064 |
| - ) -> CategoricalT: |
2065 |
| - ... |
2066 |
| - |
2067 |
| - def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any: |
2068 |
| - """ |
2069 |
| - Return an item. |
2070 |
| - """ |
2071 |
| - result = super().__getitem__(key) |
2072 |
| - if getattr(result, "ndim", 0) > 1: |
2073 |
| - result = result._ndarray |
2074 |
| - deprecate_ndim_indexing(result) |
2075 |
| - return result |
2076 |
| - |
2077 | 2057 | def _validate_listlike(self, value):
|
2078 | 2058 | # NB: here we assume scalar-like tuples have already been excluded
|
2079 | 2059 | value = extract_array(value, extract_numpy=True)
|
@@ -2311,7 +2291,19 @@ def _concat_same_type(
|
2311 | 2291 | ) -> CategoricalT:
|
2312 | 2292 | from pandas.core.dtypes.concat import union_categoricals
|
2313 | 2293 |
|
2314 |
| - return union_categoricals(to_concat) |
| 2294 | + result = union_categoricals(to_concat) |
| 2295 | + |
| 2296 | + # in case we are concatenating along axis != 0, we need to reshape |
| 2297 | + # the result from union_categoricals |
| 2298 | + first = to_concat[0] |
| 2299 | + if axis >= first.ndim: |
| 2300 | + raise ValueError |
| 2301 | + if axis == 1: |
| 2302 | + if not all(len(x) == len(first) for x in to_concat): |
| 2303 | + raise ValueError |
| 2304 | + # TODO: Will this get contiguity wrong? |
| 2305 | + result = result.reshape(-1, len(to_concat), order="F") |
| 2306 | + return result |
2315 | 2307 |
|
2316 | 2308 | # ------------------------------------------------------------------
|
2317 | 2309 |
|
@@ -2699,6 +2691,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
|
2699 | 2691 | """
|
2700 | 2692 | dtype_equal = is_dtype_equal(values.dtype, categories.dtype)
|
2701 | 2693 |
|
| 2694 | + if values.ndim > 1: |
| 2695 | + flat = values.ravel() |
| 2696 | + codes = _get_codes_for_values(flat, categories) |
| 2697 | + return codes.reshape(values.shape) |
| 2698 | + |
2702 | 2699 | if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values):
|
2703 | 2700 | # Support inferring the correct extension dtype from an array of
|
2704 | 2701 | # scalar objects. e.g.
|
|
0 commit comments