 from __future__ import annotations

 import enum
-from typing import Any
+from typing import Any, Optional
 import warnings

+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src.typing import Array
-
+from jax._src.sharding import Sharding

 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,108 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   )  # type: ignore

+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is not None and not copy:
+      raise ValueError(
+        f"Specified {device=} which requires a copy since the source device "
+        f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+        "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
+
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None, copy: Optional[bool] = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+        f"A {str.upper(preferred_platform)} device was specified, however no "
+        f"{str.upper(preferred_platform)} backend was found."
+      )

-def from_dlpack(external_array):
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))  # type: ignore
+
+  return _place_array(_arr, device, _arr.devices().pop(), copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None, copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array, device: xla_client.Device | Sharding | None = None, copy: bool | None = None):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.

-  The returned :class:`~jax.Array` shares memory with ``external_array``.
+  The returned :class:`~jax.Array` shares memory with ``external_array`` if no
+  device transfer or copy is requested.

   Args:
-    external_array: an array object that has __dlpack__ and __dlpack_device__
+    external_array: An array object that has __dlpack__ and __dlpack_device__
       methods, or a DLPack tensor on either CPU or GPU (legacy API).

+    device: The (optional) :py:class:`Device`, representing the device on which
+      the returned array should be placed. If given, the result is committed to
+      that device. If unspecified, the resulting array is unpacked onto the same
+      device it originated from. Setting ``device`` to a device different from
+      the source of ``external_array`` requires a copy, meaning ``copy`` must be
+      set to either ``True`` or ``None``.
+
+    copy: An (optional) boolean, controlling whether or not a copy is performed.
+      If ``copy=True`` then a copy is always performed, even when unpacking onto
+      the same device. If ``copy=False`` then a copy is never performed and an
+      error is raised if one would be needed. When ``copy=None`` a copy may be
+      performed if needed for a device transfer.
+
   Returns:
     A jax.Array

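A minimal usage sketch of the ``device`` and ``copy`` arguments documented above, assuming the new signature from this change, a CPU-only setup, and NumPy >= 1.22 (whose arrays implement ``__dlpack__``); the exact devices available depend on the local machine:

```python
import numpy as np
import jax
from jax import dlpack as jax_dlpack

x = np.arange(4, dtype=np.float32)            # any object exposing __dlpack__

y = jax_dlpack.from_dlpack(x)                 # unpack onto the source device, sharing memory
y_copy = jax_dlpack.from_dlpack(x, copy=True) # force a copy even on the same device

# Explicit placement: if the requested device differs from the source device,
# a copy is required, so copy=False would raise ValueError in that case.
cpu0 = jax.devices("cpu")[0]
y_cpu = jax_dlpack.from_dlpack(x, device=cpu0, copy=None)
```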
@@ -102,49 +195,16 @@ def from_dlpack(external_array):
   is later modified in-place, it may lead to undefined behavior when using
   the associated JAX array.
   """
+  if isinstance(device, Sharding):
+    device_set = device.device_set
+    if len(device_set) > 1:
+      raise ValueError(
+        "from_dlpack can only unpack a dlpack tensor onto a singular device, but "
+        f"a Sharding with {len(device_set)} devices was provided."
+      )
+    device = device_set.pop()
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
-
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
+    return _from_dlpack(external_array, device, copy)

-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
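A short sketch of the ``Sharding``-typed ``device`` argument accepted by the new ``from_dlpack``; ``SingleDeviceSharding`` is used here purely for illustration, and anything spanning more than one device is rejected:

```python
import numpy as np
import jax
from jax import dlpack as jax_dlpack
from jax.sharding import SingleDeviceSharding

x = np.ones((2, 2), dtype=np.float32)

# A Sharding is accepted in place of a Device as long as it covers exactly one
# device; from_dlpack unpacks the tensor onto that device.
single = SingleDeviceSharding(jax.devices()[0])
y = jax_dlpack.from_dlpack(x, device=single)

# A Sharding spanning several devices raises ValueError, since a DLPack tensor
# can only be unpacked onto a single device.
```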