
Commit d71ac98: "Update" (parent 71ec6e3)

4 files changed: +119 additions, -55 deletions

jax/_src/dlpack.py (109 additions, 49 deletions)

@@ -15,16 +15,17 @@
 from __future__ import annotations

 import enum
-from typing import Any
+from typing import Any, Optional
 import warnings

+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src.typing import Array
-
+from jax._src.sharding import Sharding

 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,108 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   )  # type: ignore

+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is not None and not copy:
+      raise ValueError(
+          f"Specified {device=} which requires a copy since the source device "
+          f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+          "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
+
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None, copy: Optional[bool] = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+          f"A {str.upper(preferred_platform)} device was specified, however no "
+          f"{str.upper(preferred_platform)} backend was found."
+      )

-def from_dlpack(external_array):
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))  # type: ignore
+
+  return _place_array(_arr, device, _arr.devices().pop(), copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None, copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array, device: xla_client.Device | Sharding | None = None, copy: bool | None = None):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.

-  The returned :class:`~jax.Array` shares memory with ``external_array``.
+  The returned :class:`~jax.Array` shares memory with ``external_array`` if no
+  device transfer or copy was requested.

   Args:
-    external_array: an array object that has __dlpack__ and __dlpack_device__
+    external_array: An array object that has __dlpack__ and __dlpack_device__
       methods, or a DLPack tensor on either CPU or GPU (legacy API).

+    device: The (optional) :py:class:`Device`, representing the device on
+      which the returned array should be placed. If given, then the result is
+      committed to the device. If unspecified, the resulting array will be
+      unpacked onto the same device it originated from. Setting ``device`` to
+      a device different from the source of ``external_array`` will require a
+      copy, meaning ``copy`` must be set to either ``True`` or ``None``.
+
+    copy: An (optional) boolean, controlling whether or not a copy is
+      performed. If ``copy=True`` then a copy is always performed, even when
+      unpacking onto the same device. If ``copy=False`` then a copy is never
+      performed and an error is raised if one would be required. When
+      ``copy=None`` a copy may be performed if needed for a device transfer.
+
   Returns:
     A jax.Array

@@ -102,49 +195,16 @@ def from_dlpack(external_array):
     is later modified in-place, it may lead to undefined behavior when using
     the associated JAX array.
   """
+  if isinstance(device, Sharding):
+    device_set = device.device_set
+    if len(device_set) > 1:
+      raise ValueError(
+          "from_dlpack can only unpack a dlpack tensor onto a single device, "
+          f"but a Sharding with {len(device_set)} devices was provided."
+      )
+    device = device_set.pop()
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
-
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
+    return _from_dlpack(external_array, device, copy)

-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
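
Taken together, these changes give jax.dlpack.from_dlpack optional device and copy arguments. A minimal usage sketch of the new semantics, assuming a CPU-only setup and using NumPy (another DLPack producer) as the source array:

import numpy as np
import jax
import jax.dlpack

x = np.arange(4.0)

# Default: unpack onto the source device; memory is shared when possible.
a = jax.dlpack.from_dlpack(x)

# copy=True forces a copy even without a device transfer.
b = jax.dlpack.from_dlpack(x, copy=True)

# An explicit device commits the result there. If the target differs from
# the source device, copy=False raises ValueError (see _place_array above);
# copy=None permits a copy for the transfer.
c = jax.dlpack.from_dlpack(x, device=jax.devices("cpu")[0], copy=None)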

jax/_src/numpy/lax_numpy.py (2 additions, 2 deletions)

@@ -2436,9 +2436,9 @@ def fromiter(*args, **kwargs):
     is later modified in-place, it may lead to undefined behavior when using
     the associated JAX array.
   """)
-def from_dlpack(x: Any) -> Array:
+def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array:
   from jax.dlpack import from_dlpack  # pylint: disable=g-import-not-at-top
-  return from_dlpack(x)
+  return from_dlpack(x, device=device, copy=copy)


 @util.implements(np.fromfunction)
 def fromfunction(function: Callable[..., Array], shape: Any,
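
In the jax.numpy wrapper the new arguments are keyword-only, and device may also be a Sharding, which from_dlpack resolves to its single device. A sketch under those assumptions, using SingleDeviceSharding:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

s = SingleDeviceSharding(jax.devices()[0])
y = jnp.from_dlpack(np.ones(3), device=s)  # one device: accepted
# A Sharding spanning multiple devices would raise ValueError.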

jax/experimental/array_api/_creation_functions.py (6 additions, 3 deletions)

@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import jax
 import jax.numpy as jnp
-
+from jax._src.lib import xla_client as xc
+from jax._src.sharding import Sharding

 def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
   return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
@@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
 def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
   return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

-def from_dlpack(x, /):
-  return jnp.from_dlpack(x)
+def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
+  return jnp.from_dlpack(x, device=device, copy=copy)

 def full(shape, fill_value, *, dtype=None, device=None):
   return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
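
The array-api wrapper simply forwards to jnp.from_dlpack, matching the standard's keyword-only device and copy parameters. A sketch, assuming the experimental namespace as it existed at the time of this commit:

import numpy as np
from jax.experimental import array_api as xp

w = xp.from_dlpack(np.arange(3.0), copy=True)  # forced copy, same device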

jax/numpy/__init__.pyi (2 additions, 1 deletion)

@@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeV

 from jax._src import core as _core
 from jax._src import dtypes as _dtypes
+from jax._src.lib import xla_client as xc
 from jax._src.lax.lax import PrecisionLike
 from jax._src.lax.slicing import GatherScatterMode
 from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
@@ -353,7 +354,7 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
-def from_dlpack(x: Any) -> Array: ...
+def from_dlpack(x: Any, /, *, device: _Sharding | xc.Device | None = None, copy: builtins.bool | None = None) -> Array: ...
 def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
                count: int = ..., offset: int = ...) -> Array: ...
 def fromfile(*args, **kwargs): ...