Skip to content

Commit 0da7010

Browse files
committed
ENH: make MAX_ARRAY_SIZE configurable, lower the default to 1024
Changing the default is via the environment variable ARRAY_API_TESTS_MAX_ARRAY_SIZE.
1 parent 6144df4 commit 0da7010

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ entries from the ``xfail.txt`` file instead of xfailing them. Anecdotally, we sa
316316
speed-ups by a factor of 4-5---which allowed us to use 4-5 larger values of
317317
``--max-examples`` within the same time budget.
318318
319+
#### Limiting the array sizes
320+
321+
The test suite generates random arrays as inputs to functions it tests. "unvectorized"
322+
tests iterate over elements of arrays, which might be slow. If the run time becomes
323+
a problem, you can limit the maximum number of elements in generated arrays by
324+
setting the environment variable ``ARRAY_API_TESTS_MAX_ARRAY_SIZE`` to the
325+
desired value. By default, it is set to 1024.
326+
319327
320328
## Contributing
321329

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import re
45
from contextlib import contextmanager
56
from functools import wraps
@@ -232,7 +233,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
232233
lambda i: getattr(xp, i))
233234

234235
# Limit the total size of an array shape
235-
MAX_ARRAY_SIZE = 10000
236+
MAX_ARRAY_SIZE = int(os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE", 1024))
236237
# Size to use for 2-dim arrays
237238
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
238239

array_api_tests/test_creation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def test_meshgrid(dtype, data):
499499
shapes = data.draw(
500500
st.integers(1, 5).flatmap(
501501
lambda n: hh.mutually_broadcastable_shapes(
502-
n, min_dims=1, max_dims=1, max_side=5
502+
n, min_dims=1, max_dims=1, max_side=4
503503
)
504504
),
505505
label="shapes",
@@ -509,7 +509,7 @@ def test_meshgrid(dtype, data):
509509
x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}")
510510
arrays.append(x)
511511
# sanity check
512-
assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
512+
# assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
513513
out = xp.meshgrid(*arrays)
514514
for i, x in enumerate(out):
515515
ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype")

0 commit comments

Comments
 (0)