@@ -12,7 +12,6 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas
12
12
from jax ._src .typing import Array , ArrayLike , DType , DTypeLike , DimSize , DuckTypedArray , Shape
13
13
from jax .numpy import fft as fft , linalg as linalg
14
14
from jax .sharding import Sharding as _Sharding
15
- import jaxlib .xla_client as xc
16
15
import numpy as _np
17
16
18
17
_T = TypeVar ('_T' )
@@ -354,7 +353,7 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
354
353
def fmin (x : ArrayLike , y : ArrayLike , / ) -> Array : ...
355
354
def fmod (x : ArrayLike , y : ArrayLike , / ) -> Array : ...
356
355
def frexp (x : ArrayLike , / ) -> tuple [Array , Array ]: ...
357
- def from_dlpack (x : Any , / , * , device : _Sharding | xc . Device | None = None , copy : builtins .bool | None = None ) -> Array : ...
356
+ def from_dlpack (x : Any , / , * , device : _Sharding | _Device = Any | None = None , copy : builtins .bool | None = None ) -> Array : ...
358
357
def frombuffer (buffer : Union [bytes , Any ], dtype : DTypeLike = ...,
359
358
count : int = ..., offset : int = ...) -> Array : ...
360
359
def fromfile (* args , ** kwargs ): ...
0 commit comments