Skip to content

Commit 3754e7c

Browse files
authored
ENH: pad: pad_width can be any sequence (#114)
1 parent 48fb66a commit 3754e7c

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

src/array_api_extra/_delegation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Delegation to existing implementations for Public API Functions."""
22

3+
from collections.abc import Sequence
34
from types import ModuleType
45
from typing import Literal
56

@@ -31,7 +32,7 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
3132

3233
def pad(
3334
x: Array,
34-
pad_width: int | tuple[int, int] | list[tuple[int, int]],
35+
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
3536
mode: Literal["constant"] = "constant",
3637
*,
3738
constant_values: bool | int | float | complex = 0,
@@ -44,9 +45,9 @@ def pad(
4445
----------
4546
x : array
4647
Input array.
47-
pad_width : int or tuple of ints or list of pairs of ints
48+
pad_width : int or tuple of ints or sequence of pairs of ints
4849
Pad the input array with this many elements from each side.
49-
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
50+
If a sequence of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
5051
each pair applies to the corresponding axis of ``x``.
5152
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
5253
copies of this tuple.

src/array_api_extra/_lib/_funcs.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import math
77
import warnings
8+
from collections.abc import Sequence
89
from types import ModuleType
910
from typing import cast
1011

@@ -448,23 +449,30 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
448449

449450
def pad(
450451
x: Array,
451-
pad_width: int | tuple[int, int] | list[tuple[int, int]],
452+
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
452453
*,
453454
constant_values: bool | int | float | complex = 0,
454455
xp: ModuleType,
455456
) -> Array: # numpydoc ignore=PR01,RT01
456457
"""See docstring in `array_api_extra._delegation.py`."""
457458
# make pad_width a list of length-2 tuples of ints
458459
x_ndim = cast(int, x.ndim)
460+
459461
if isinstance(pad_width, int):
460-
pad_width = [(pad_width, pad_width)] * x_ndim
461-
if isinstance(pad_width, tuple):
462-
pad_width = [pad_width] * x_ndim
462+
pad_width_seq = [(pad_width, pad_width)] * x_ndim
463+
elif (
464+
isinstance(pad_width, tuple)
465+
and len(pad_width) == 2
466+
and all(isinstance(i, int) for i in pad_width)
467+
):
468+
pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim
469+
else:
470+
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
463471

464472
# https://github.com/python/typeshed/issues/13376
465473
slices: list[slice] = [] # type: ignore[no-any-explicit]
466474
newshape: list[int] = []
467-
for ax, w_tpl in enumerate(pad_width):
475+
for ax, w_tpl in enumerate(pad_width_seq):
468476
if len(w_tpl) != 2:
469477
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
470478
raise ValueError(msg)

tests/test_funcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,12 +390,12 @@ def test_tuple_width(self, xp: ModuleType):
390390
with pytest.raises((ValueError, RuntimeError)):
391391
pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
392392

393-
def test_list_of_tuples_width(self, xp: ModuleType):
393+
def test_sequence_of_tuples_width(self, xp: ModuleType):
394394
a = xp.reshape(xp.arange(12), (3, 4))
395-
padded = pad(a, [(1, 0), (0, 2)])
396-
assert padded.shape == (4, 6)
397395

398-
padded = pad(a, [(1, 0), (0, 0)])
396+
padded = pad(a, ((1, 0), (0, 2)))
397+
assert padded.shape == (4, 6)
398+
padded = pad(a, ((1, 0), (0, 0)))
399399
assert padded.shape == (4, 4)
400400

401401

0 commit comments

Comments
 (0)