Skip to content

Commit 4fd4a0d

Browse files
committed
Fix the torch.take() wrapper to make axis optional for ndim = 1
Closes #34
1 parent dab6775 commit 4fd4a0d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,11 @@ def isdtype(
681681
else:
682682
return dtype == kind
683683

684-
def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
684+
def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
685+
if axis is None:
686+
if x.ndim != 1:
687+
raise ValueError("axis must be specified when ndim > 1")
688+
axis = 0
685689
return torch.index_select(x, axis, indices, **kwargs)
686690

687691
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',

0 commit comments

Comments
 (0)