clarify if `__array_namespace_info__().default_device()` can be None #923
Hi - for what it's worth, we made the deliberate choice to return `None` here. The problem is that JAX's existing device placement model does not entirely align with the model the authors of the spec had in mind. For example, under JIT there is no default device, because an array referenced in the Python API may never physically exist. Here's a silly example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.arange(10)  # never used, so never materialized
    return x
```

What device is `y` on? Let's modify this slightly:

```python
@jax.jit
def f(x):
    y = jnp.arange(len(x))
    return x + y
```

What device will `y` be on? Neither of these situations is compatible with the idea of a global default device, so the very notion of "default_device" as envisioned by the array API specification is flawed, and not applicable to frameworks like JAX. Given that, we thought returning `None` …

If you have other suggestions, I'm open to hearing them! If the specification were changed such that …
Personally, I'm happy for JAX to return `None` as the default device.
It's worth pointing out that this behaviour, while definitely desirable and nice to read, is only possible on lazy backends. In fact, the snippet above will crash on PyTorch if `x` is not on the default device. (CuPy has blocking design issues here.)

```python
from array_api_compat import array_namespace, device

def f(x):
    xp = array_namespace(x)
    # Create the new array on the same device as the input
    y = xp.arange(x.shape[-1], device=device(x))
    return x + y
```

The array-api-compat shims are necessary to support NumPy 1.x, Dask, Sparse, and JAX itself. This pattern follows the guideline of prioritizing input->output propagation over the global and context devices.
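A minimal sketch of that priority order, with the `None` case from this issue folded in. It assumes only the spec's `__array_namespace_info__().default_device()` and the `.device` attribute on arrays; `creation_kwargs` and `_stub` are hypothetical names for illustration, and the stub namespaces stand in for real libraries. In real code you would likely use `array_api_compat.device(x)` rather than reading `.device` directly, since NumPy 1.x arrays lack that attribute.

```python
from types import SimpleNamespace


def creation_kwargs(xp, *arrays):
    """Hypothetical helper: choose the `device=` keyword for a new array.

    Priority, following input->output propagation:
      1. the device of the first input array that reports one;
      2. otherwise the namespace's default_device(), which may be None.
    A None default is left out entirely so the library picks placement;
    for functions such as arange, omitting device= is the same as passing
    the spec's default of None. Mixed-device inputs are not checked here.
    """
    for a in arrays:
        dev = getattr(a, "device", None)
        if dev is not None:
            return {"device": dev}
    dev = xp.__array_namespace_info__().default_device()
    return {} if dev is None else {"device": dev}


def _stub(default):
    """Stand-in namespace whose info object reports `default`; not a real library."""
    info = SimpleNamespace(default_device=lambda: default)
    return SimpleNamespace(__array_namespace_info__=lambda: info)


jax_like = _stub(None)      # mimics jax.numpy returning None
numpy_like = _stub("cpu")   # mimics a library with a concrete default
x_on_gpu = SimpleNamespace(device="gpu:0")  # stand-in for an input array

print(creation_kwargs(jax_like))              # {}
print(creation_kwargs(numpy_like))            # {'device': 'cpu'}
print(creation_kwargs(numpy_like, x_on_gpu))  # {'device': 'gpu:0'}
```

A caller would then write `xp.arange(n, **creation_kwargs(xp, x))`, which never needs to special-case `None`.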
The answer here is that no-one cares.

```python
from array_api_compat import array_namespace, device, to_device

def f(x):
    """Return x+arange, prioritizing the default device over x.device"""
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1])  # created on the default device
    return to_device(x, device(y)) + y  # move x to y's device first
```

Here, we have some peculiar behaviour:
To me, it is also not a problem that the jax-onic code gives the compiler more options than the code you have to write to conform to the array API.

```python
import jax
import jax.numpy as jnp
from array_api_compat import array_namespace, device

# jax-onic code
@jax.jit
def f(x):
    y = jnp.arange(len(x))
    return x + y

# array API compatible
def f(x):
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1], device=device(x))
    return x + y
```

Though I do wonder if in the array API version of … Either way, my main point is that it is fine that "not all valid jax/numpy/cupy code is valid array API code".
It's what it does today.
It looks like everyone is in agreement (as am I) that returning `None` …
While not in the 2024.12 spec, this is generally agreed on in data-apis/array-api#923.
The spec only says that `default_device()` returns an object corresponding to the default device (https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_device.html#array_api.info.default_device). `jax.numpy` returns `None`, so the question is whether `None` corresponds to the default device or not.