Skip to content

Commit 546fa3d

Browse files
authored
Merge pull request #47 from asmeurer/take-fix
Fix the torch.take() wrapper to make axis optional for ndim = 1
2 parents ea6a9d6 + 4fd4a0d commit 546fa3d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

array_api_compat/torch/_aliases.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,11 @@ def isdtype(
681681
else:
682682
return dtype == kind
683683

684-
def take(x: array, indices: array, /, *, axis: int, **kwargs) -> array:
684+
def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
685+
if axis is None:
686+
if x.ndim != 1:
687+
raise ValueError("axis must be specified when ndim > 1")
688+
axis = 0
685689
return torch.index_select(x, axis, indices, **kwargs)
686690

687691
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',

0 commit comments

Comments
 (0)