Skip to content

Commit 0170b35

Browse files
committed
Add support for device/copy kw in from_dlpack to match Array API
1 parent 2583757 commit 0170b35

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

jax/numpy/__init__.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas
1212
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
1313
from jax.numpy import fft as fft, linalg as linalg
1414
from jax.sharding import Sharding as _Sharding
15-
import jaxlib.xla_client as xc
1615
import numpy as _np
1716

1817
_T = TypeVar('_T')
@@ -354,7 +353,7 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
354353
def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
355354
def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
356355
def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
357-
def from_dlpack(x: Any, /, *, device: _Sharding | xc.Device | None = None, copy: builtins.bool | None = None) -> Array: ...
356+
def from_dlpack(x: Any, /, *, device: _Sharding | _Device = Any | None = None, copy: builtins.bool | None = None) -> Array: ...
358357
def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
359358
count: int = ..., offset: int = ...) -> Array: ...
360359
def fromfile(*args, **kwargs): ...

0 commit comments

Comments
 (0)