Skip to content

Commit 3f07983

Browse files
authored
ENH: implement EA.delete (#39405)
1 parent c519389 commit 3f07983

File tree

6 files changed

+39
-27
lines changed

6 files changed

+39
-27
lines changed

pandas/core/arrays/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,12 +997,12 @@ def repeat(self, repeats, axis=None):
997997
# ------------------------------------------------------------------------
998998

999999
def take(
1000-
self,
1000+
self: ExtensionArrayT,
10011001
indices: Sequence[int],
10021002
*,
10031003
allow_fill: bool = False,
10041004
fill_value: Any = None,
1005-
) -> ExtensionArray:
1005+
) -> ExtensionArrayT:
10061006
"""
10071007
Take elements from an array.
10081008
@@ -1261,6 +1261,13 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
12611261
def __hash__(self):
12621262
raise TypeError(f"unhashable type: {repr(type(self).__name__)}")
12631263

1264+
# ------------------------------------------------------------------------
1265+
# Non-Optimized Default Methods
1266+
1267+
def delete(self: ExtensionArrayT, loc) -> ExtensionArrayT:
1268+
indexer = np.delete(np.arange(len(self)), loc)
1269+
return self.take(indexer)
1270+
12641271

12651272
class ExtensionOpsMixin:
12661273
"""

pandas/core/arrays/interval.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,15 @@ def to_tuples(self, na_tuple=True):
14131413

14141414
# ---------------------------------------------------------------------
14151415

1416+
def delete(self: IntervalArrayT, loc) -> IntervalArrayT:
1417+
if isinstance(self._left, np.ndarray):
1418+
new_left = np.delete(self._left, loc)
1419+
new_right = np.delete(self._right, loc)
1420+
else:
1421+
new_left = self._left.delete(loc)
1422+
new_right = self._right.delete(loc)
1423+
return self._shallow_copy(left=new_left, right=new_right)
1424+
14161425
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
14171426
def repeat(self, repeats, axis=None):
14181427
nv.validate_repeat((), {"axis": axis})

pandas/core/arrays/string_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def __setitem__(self, key: Union[int, np.ndarray], value: Any) -> None:
507507

508508
def take(
509509
self, indices: Sequence[int], allow_fill: bool = False, fill_value: Any = None
510-
) -> ExtensionArray:
510+
):
511511
"""
512512
Take elements from an array.
513513

pandas/core/indexes/extension.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,17 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
257257
def _get_engine_target(self) -> np.ndarray:
258258
return np.asarray(self._data)
259259

260+
def delete(self, loc):
261+
"""
262+
Make new Index with passed location(-s) deleted
263+
264+
Returns
265+
-------
266+
new_index : Index
267+
"""
268+
arr = self._data.delete(loc)
269+
return type(self)._simple_new(arr, name=self.name)
270+
260271
def repeat(self, repeats, axis=None):
261272
nv.validate_repeat((), {"axis": axis})
262273
result = self._data.repeat(repeats, axis=axis)
@@ -333,17 +344,6 @@ class NDArrayBackedExtensionIndex(ExtensionIndex):
333344
def _get_engine_target(self) -> np.ndarray:
334345
return self._data._ndarray
335346

336-
def delete(self: _T, loc) -> _T:
337-
"""
338-
Make new Index with passed location(-s) deleted
339-
340-
Returns
341-
-------
342-
new_index : Index
343-
"""
344-
arr = self._data.delete(loc)
345-
return type(self)._simple_new(arr, name=self.name)
346-
347347
def insert(self: _T, loc: int, item) -> _T:
348348
"""
349349
Make new Index inserting new item at location. Follows

pandas/core/indexes/interval.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -822,19 +822,6 @@ def where(self, cond, other=None):
822822
result = IntervalArray(values)
823823
return type(self)._simple_new(result, name=self.name)
824824

825-
def delete(self, loc):
826-
"""
827-
Return a new IntervalIndex with passed location(-s) deleted
828-
829-
Returns
830-
-------
831-
IntervalIndex
832-
"""
833-
new_left = self.left.delete(loc)
834-
new_right = self.right.delete(loc)
835-
result = self._data._shallow_copy(new_left, new_right)
836-
return type(self)._simple_new(result, name=self.name)
837-
838825
def insert(self, loc, item):
839826
"""
840827
Return a new IntervalIndex inserting new item at location. Follows

pandas/tests/extension/base/methods.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,15 @@ def test_repeat_raises(self, data, repeats, kwargs, error, msg, use_numpy):
491491
else:
492492
data.repeat(repeats, **kwargs)
493493

494+
def test_delete(self, data):
495+
result = data.delete(0)
496+
expected = data[1:]
497+
self.assert_extension_array_equal(result, expected)
498+
499+
result = data.delete([1, 3])
500+
expected = data._concat_same_type([data[[0]], data[[2]], data[4:]])
501+
self.assert_extension_array_equal(result, expected)
502+
494503
@pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame])
495504
def test_equals(self, data, na_value, as_series, box):
496505
data2 = type(data)._from_sequence([data[0]] * len(data), dtype=data.dtype)

0 commit comments

Comments
 (0)