
Commit 27fbd9c

ENH: support PyTorch device='meta' (#300)
1 parent e2762f5 commit 27fbd9c

8 files changed, +111 -29 lines changed
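For orientation: a PyTorch tensor on device='meta' carries shape and dtype but allocates no storage, so shape/dtype propagation works while any attempt to read its values fails. A minimal illustrative sketch, not part of this commit (assumes torch is installed):

import torch

x = torch.empty((1000, 1000), device="meta")  # shape and dtype only, no data allocated
y = (x + 1).sum()                             # shape/dtype propagation still works
assert y.device.type == "meta"
# x.cpu() / x.numpy() raise: there are no values to materialize, which is why the
# changes below skip value comparisons and data-dependent code paths on 'meta'.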

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 2 deletions
@@ -153,7 +153,7 @@ def _apply_where(  # type: ignore[explicit-any]  # numpydoc ignore=PR01,RT01
 ) -> Array:
     """Helper of `apply_where`. On Dask, this runs on a single chunk."""
 
-    if not capabilities(xp)["boolean indexing"]:
+    if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
         # jax.jit does not support assignment by boolean mask
         return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
 
@@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
     # 2. backend has unique_counts and it returns a None-sized array;
     #    e.g. Dask, ndonnx
     # 3. backend does not have unique_counts; e.g. wrapped JAX
-    if capabilities(xp)["data-dependent shapes"]:
+    if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]:
         # xp has unique_counts; O(n) complexity
         _, counts = xp.unique_counts(x)
         n = _compat.size(counts)
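The point of threading device= into capabilities() here: torch as a namespace supports boolean-mask assignment and data-dependent output shapes, but arrays on the 'meta' device do not, so the check must look at the array's device rather than the namespace alone. A hedged sketch of the failure modes (assumes torch; not part of the diff):

import torch

cond = torch.empty((4,), dtype=torch.bool, device="meta")
x = torch.zeros((4,), device="meta")
# x[cond] = 1.0     # boolean indexing needs the mask's values; typically fails on 'meta'
# torch.unique(x)   # output shape depends on the data; typically fails on 'meta'
y = torch.where(cond, torch.ones((), device="meta"), x)  # value-free operation, works on 'meta'

With the per-device check, `_apply_where` falls back to `xp.where` and `nunique` avoids `unique_counts` for 'meta' arrays while keeping the fast paths elsewhere.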

src/array_api_extra/_lib/_testing.py

Lines changed: 35 additions & 9 deletions
@@ -22,6 +22,7 @@
     is_jax_namespace,
     is_numpy_namespace,
     is_pydata_sparse_namespace,
+    is_torch_array,
     is_torch_namespace,
     to_device,
 )
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
     msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
     assert actual_xp == desired_xp, msg
 
-    if check_shape:
-        actual_shape = actual.shape
-        desired_shape = desired.shape
-        if is_dask_namespace(desired_xp):
-            # Dask uses nan instead of None for unknown shapes
-            if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
-                actual_shape = actual.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-            if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
-                desired_shape = desired.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+    # Dask uses nan instead of None for unknown shapes
+    actual_shape = cast(tuple[float, ...], actual.shape)
+    desired_shape = cast(tuple[float, ...], desired.shape)
+    assert None not in actual_shape  # Requires explicit support
+    assert None not in desired_shape
+    if is_dask_namespace(desired_xp):
+        if any(math.isnan(i) for i in actual_shape):
+            actual_shape = actual.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+        if any(math.isnan(i) for i in desired_shape):
+            desired_shape = desired.compute().shape  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
 
+    if check_shape:
         msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
         assert actual_shape == desired_shape, msg
+    else:
+        # Ignore shape, but check flattened size. This is normally done by
+        # np.testing.assert_array_equal etc. even when strict=False, but not for
+        # non-materializable arrays.
+        actual_size = math.prod(actual_shape)  # pyright: ignore[reportUnknownArgumentType]
+        desired_size = math.prod(desired_shape)  # pyright: ignore[reportUnknownArgumentType]
+        msg = f"sizes do not match: {actual_size} != f{desired_size}"
+        assert actual_size == desired_size, msg
 
     if check_dtype:
         msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
@@ -90,6 +101,15 @@ def _check_ns_shape_dtype(
     return desired_xp
 
 
+def _is_materializable(x: Array) -> bool:
+    """
+    Return True if you can call `as_numpy_array(x)`; False otherwise.
+    """
+    # Important: here we assume that we're not tracing -
+    # e.g. we're not inside `jax.jit` nor `cupy.cuda.Stream.begin_capture`.
+    return not is_torch_array(x) or x.device.type != "meta"  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+
+
 def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:  # type: ignore[explicit-any]
     """
     Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -146,6 +166,8 @@ def xp_assert_equal(
     numpy.testing.assert_array_equal : Similar function for NumPy arrays.
     """
     xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
+    if not _is_materializable(actual):
+        return
     actual_np = as_numpy_array(actual, xp=xp)
     desired_np = as_numpy_array(desired, xp=xp)
     np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
@@ -181,6 +203,8 @@ def xp_assert_less(
     numpy.testing.assert_array_equal : Similar function for NumPy arrays.
     """
     xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
+    if not _is_materializable(x):
+        return
     x_np = as_numpy_array(x, xp=xp)
     y_np = as_numpy_array(y, xp=xp)
     np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
@@ -229,6 +253,8 @@ def xp_assert_close(
     The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
     """
     xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
+    if not _is_materializable(actual):
+        return
 
     if rtol is None:
         if xp.isdtype(actual.dtype, ("real floating", "complex floating")):
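Net effect on the assertion helpers: for a non-materializable array they still validate namespace, dtype, and shape (or flattened size), then return before converting to NumPy. A hedged usage sketch (assumes torch; `xp_assert_equal` is this repo's private test helper):

import torch
from array_api_extra._lib._testing import xp_assert_equal

a = torch.empty((3,), device="meta")
b = torch.empty((3,), device="meta")
xp_assert_equal(a, b)  # passes: namespace/dtype/shape checked, value comparison skipped
# Against torch.empty((4,), device="meta") it still fails with "shapes do not match",
# or with "sizes do not match" when check_shape=False.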

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 12 additions & 2 deletions
@@ -29,8 +29,9 @@
     is_jax_namespace,
     is_numpy_array,
     is_pydata_sparse_namespace,
+    is_torch_namespace,
 )
-from ._typing import Array
+from ._typing import Array, Device
 
 if TYPE_CHECKING:  # pragma: no cover
     # TODO import from typing (requires Python >=3.12 and >=3.13)
@@ -300,7 +301,7 @@ def meta_namespace(
     return array_namespace(*metas)
 
 
-def capabilities(xp: ModuleType) -> dict[str, int]:
+def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
     """
     Return patched ``xp.__array_namespace_info__().capabilities()``.
 
@@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
     ----------
     xp : array_namespace
         The standard-compatible namespace.
+    device : Device, optional
+        The device to use.
 
     Returns
     -------
@@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
         # Fixed in jax >=0.6.0
         out = out.copy()
         out["boolean indexing"] = False
+    if is_torch_namespace(xp):
+        # FIXME https://github.com/data-apis/array-api/issues/945
+        device = xp.get_default_device() if device is None else xp.device(device)
+        if device.type == "meta":  # type: ignore[union-attr]  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+            out = out.copy()
+            out["boolean indexing"] = False
+            out["data-dependent shapes"] = False
     return out
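A hedged usage sketch of the patched capabilities() helper (a private utility of this repo; the wrapper namespace comes from array-api-compat):

import array_api_compat.torch as xp
from array_api_extra._lib._utils._helpers import capabilities

assert capabilities(xp, device="cpu")["data-dependent shapes"]
meta_caps = capabilities(xp, device="meta")
assert not meta_caps["boolean indexing"]
assert not meta_caps["data-dependent shapes"]
# With device=None the helper falls back to xp.get_default_device(), so the
# answer tracks whatever default device torch is currently configured with.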

tests/conftest.py

Lines changed: 5 additions & 3 deletions
@@ -211,7 +211,9 @@ def device(
     Where possible, return a device that is not the default one.
     """
     if library == Backend.ARRAY_API_STRICT:
-        d = xp.Device("device1")
-        assert get_device(xp.empty(0)) != d
-        return d
+        return xp.Device("device1")
+    if library == Backend.TORCH:
+        return xp.device("meta")
+    if library == Backend.TORCH_GPU:
+        return xp.device("cpu")
     return get_device(xp.empty(0))

tests/test_funcs.py

Lines changed: 3 additions & 3 deletions
@@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
         b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
         res = isclose(a, b, equal_nan=equal_nan)
         assert get_device(res) == device
-        xp_assert_equal(
-            isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
-        )
 
 
 class TestKron:
@@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool):
             _ = setdiff1d(0, 0, assume_unique=assume_unique)
 
     @assume_unique
+    @pytest.mark.skip_xp_backend(
+        Backend.TORCH, reason="device='meta' does not support unknown shapes"
+    )
     def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
         x1 = xp.asarray([3, 8, 20], device=device)
         x2 = xp.asarray([2, 3, 4], device=device)

tests/test_helpers.py

Lines changed: 25 additions & 5 deletions
@@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
         assert meta_namespace(*args, xp=xp) in (xp, np_compat)
 
 
-def test_capabilities(xp: ModuleType):
-    expect = {"boolean indexing", "data-dependent shapes"}
-    if xp.__array_api_version__ >= "2024.12":
-        expect.add("max dimensions")
-    assert capabilities(xp).keys() == expect
+class TestCapabilities:
+    def test_basic(self, xp: ModuleType):
+        expect = {"boolean indexing", "data-dependent shapes"}
+        if xp.__array_api_version__ >= "2024.12":
+            expect.add("max dimensions")
+        assert capabilities(xp).keys() == expect
+
+    def test_device(self, xp: ModuleType, library: Backend, device: Device):
+        expect_keys = {"boolean indexing", "data-dependent shapes"}
+        if xp.__array_api_version__ >= "2024.12":
+            expect_keys.add("max dimensions")
+        assert capabilities(xp, device=device).keys() == expect_keys
+
+        if library.like(Backend.TORCH):
+            # The output of capabilities is device-specific.
+
+            # Test that device=None gets the current default device.
+            expect = capabilities(xp, device=device)
+            with xp.device(device):
+                actual = capabilities(xp)
+            assert actual == expect
+
+            # Test that we're accepting anything that is accepted by the
+            # device= parameter in other functions
+            actual = capabilities(xp, device=device.type)  # type: ignore[attr-defined]  # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
 
 
 class Wrapper(Generic[T]):

tests/test_lazy.py

Lines changed: 3 additions & 0 deletions
@@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
             Backend.ARRAY_API_STRICT, reason="device->host copy"
         ),
         pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+        pytest.mark.skip_xp_backend(
+            Backend.TORCH, reason="materialize 'meta' device"
+        ),
         pytest.mark.skip_xp_backend(
             Backend.TORCH_GPU, reason="device->host copy"
         ),

tests/test_testing.py

Lines changed: 26 additions & 5 deletions
@@ -24,10 +24,17 @@
 # pyright: reportUnknownParameterType=false,reportMissingParameterType=false
 
 
-def test_as_numpy_array(xp: ModuleType, device: Device):
-    x = xp.asarray([1, 2, 3], device=device)
-    y = as_numpy_array(x, xp=xp)
-    assert isinstance(y, np.ndarray)
+class TestAsNumPyArray:
+    def test_basic(self, xp: ModuleType):
+        x = xp.asarray([1, 2, 3])
+        y = as_numpy_array(x, xp=xp)
+        xp_assert_equal(y, np.asarray([1, 2, 3]))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+
+    @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
+    def test_device(self, xp: ModuleType, device: Device):
+        x = xp.asarray([1, 2, 3], device=device)
+        y = as_numpy_array(x, xp=xp)
+        xp_assert_equal(y, np.asarray([1, 2, 3]))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
 
 
 class TestAssertEqualCloseLess:
@@ -80,7 +87,7 @@ def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
         func(a, b, check_shape=False)
         with pytest.raises(AssertionError, match="Mismatched elements"):
             func(a, c, check_shape=False)
-        with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
+        with pytest.raises(AssertionError, match="sizes do not match"):
             func(a, d, check_shape=False)
 
     @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
@@ -169,6 +176,20 @@ def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
         with pytest.raises(AssertionError, match="Mismatched elements"):
             func(xp.asarray([4]), a)
 
+    @pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
+    def test_device(self, xp: ModuleType, device: Device, func: Callable[..., None]):
+        a = xp.asarray([1] if func is xp_assert_less else [2], device=device)
+        b = xp.asarray([2], device=device)
+        c = xp.asarray([2, 2], device=device)
+
+        func(a, b)
+        with pytest.raises(AssertionError, match="shapes do not match"):
+            func(a, c)
+        # This is normally performed by np.testing.assert_array_equal etc.
+        # but in case of torch device='meta' we have to do it manually
+        with pytest.raises(AssertionError, match="sizes do not match"):
+            func(a, c, check_shape=False)
+
 
 def good_lazy(x: Array) -> Array:
     """A function that behaves well in Dask and jax.jit"""
