Skip to content

Commit 236d89b

Browse files
mutricylLaurent Mutricy
and
Laurent Mutricy
authored
update algo.take to solve #59177 (#59181)
* update algo.take to solve #59177 * forgot to update TestExtensionTake::test_take_coerces_list * fixing pandas/tests/dtypes/test_generic.py::TestABCClasses::test_abc_hierarchy * ABCExtensionArray set formatting --------- Co-authored-by: Laurent Mutricy <[email protected]>
1 parent f6d06b8 commit 236d89b

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

pandas/core/algorithms.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
ABCExtensionArray,
6969
ABCIndex,
7070
ABCMultiIndex,
71+
ABCNumpyExtensionArray,
7172
ABCSeries,
7273
ABCTimedeltaArray,
7374
)
@@ -1161,11 +1162,14 @@ def take(
11611162
... )
11621163
array([ 10, 10, -10])
11631164
"""
1164-
if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)):
1165+
if not isinstance(
1166+
arr,
1167+
(np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries, ABCNumpyExtensionArray),
1168+
):
11651169
# GH#52981
11661170
raise TypeError(
1167-
"pd.api.extensions.take requires a numpy.ndarray, "
1168-
f"ExtensionArray, Index, or Series, got {type(arr).__name__}."
1171+
"pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, "
1172+
f"Index, Series, or NumpyExtensionArray got {type(arr).__name__}."
11691173
)
11701174

11711175
indices = ensure_platform_int(indices)

pandas/tests/test_take.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pandas._libs import iNaT
77

8+
from pandas import array
89
import pandas._testing as tm
910
import pandas.core.algorithms as algos
1011

@@ -303,7 +304,14 @@ def test_take_coerces_list(self):
303304
arr = [1, 2, 3]
304305
msg = (
305306
"pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, "
306-
"Index, or Series, got list"
307+
"Index, Series, or NumpyExtensionArray got list"
307308
)
308309
with pytest.raises(TypeError, match=msg):
309310
algos.take(arr, [0, 0])
311+
312+
def test_take_NumpyExtensionArray(self):
313+
# GH#59177
314+
arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray)
315+
assert algos.take(arr, [2]) == 2
316+
arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray)
317+
assert algos.take(arr, [2]) == 2

0 commit comments

Comments
 (0)