From e9f516efec1ce807a7a45cc396da1d46700af9ee Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 15:28:44 +0200 Subject: [PATCH 1/4] ENH: Add a basic `diff` test --- array_api_tests/test_utility_functions.py | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index b5076d9b..14dc6bb2 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -63,3 +63,40 @@ def test_any(x, data): expected = any(elements) ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, out=result, expected=expected, kw=kw) + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_diff(x, data): + # TODO: + # 1. append/prepend + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {"axis": -1} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + n = data.draw(st.integers(1, min(x.shape[n_axis], 3))) + + out = xp.diff(x, **axis_kw, n=n) + + expected_shape = list(x.shape) + expected_shape[n_axis] -= n + assert out.shape == tuple(expected_shape) + + # value test + if n == 1: + for idx in sh.ndindex(out.shape): + l = list(idx) + l[n_axis] += 1 + assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" + From a8121e11ac9a32301a2ee5f0670275d15dbbc9aa Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 16:18:20 +0200 Subject: [PATCH 2/4] ENH: test `diff`'s append and prepend arguments --- array_api_tests/test_utility_functions.py | 43 +++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 14dc6bb2..b6e0a4fe 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -72,8 +72,6 @@ def test_any(x, data): data=st.data(), ) def test_diff(x, data): - # TODO: - # 1. append/prepend axis = data.draw( st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), label="axis" @@ -91,6 +89,7 @@ def test_diff(x, data): expected_shape = list(x.shape) expected_shape[n_axis] -= n + assert out.shape == tuple(expected_shape) # value test @@ -100,3 +99,43 @@ def test_diff(x, data): l[n_axis] += 1 assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@given( + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_diff_append_prepend(x, data): + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {"axis": -1} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + n = data.draw(st.integers(1, min(x.shape[n_axis], 3))) + + append_shape = list(x.shape) + append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis") + append_shape[n_axis] = append_axis_len + append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append") + + prepend_shape = list(x.shape) + prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis") + prepend_shape[n_axis] = prepend_axis_len + prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend") + + out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend) + + in_1 = xp.concat((prepend, x, append), **axis_kw) + out_1 = xp.diff(in_1, **axis_kw, n=n) + + assert out.shape == out_1.shape + for idx in sh.ndindex(out.shape): + assert out[idx] == out_1[idx], f"{idx = }" + From 1b79605c2c9198959cbfbec7adce6c2a2c4f0520 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 15:03:25 +0000 Subject: [PATCH 3/4] BUG: work around jax/hypothesis combination being unreasonable --- array_api_tests/test_utility_functions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index b6e0a4fe..ec92dd51 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -65,10 +65,12 @@ def test_any(x, data): out=result, expected=expected, kw=kw) +# NB: hh.int_dtypes instead of hh.numeric_dtypes because of +# https://github.com/data-apis/array-api-tests/issues/368 @pytest.mark.unvectorized @pytest.mark.min_version("2024.12") @given( - x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + x=hh.arrays_no_scalars(hh.int_dtypes, hh.shapes(min_dims=1, min_side=1)), data=st.data(), ) def test_diff(x, data): @@ -103,7 +105,7 @@ def test_diff(x, data): @pytest.mark.min_version("2024.12") @pytest.mark.unvectorized @given( - x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + x=hh.arrays(hh.int_dtypes, hh.shapes(min_dims=1, min_side=1)), data=st.data(), ) def test_diff_append_prepend(x, data): From f5f9543cb8e05563cb2ab0c2aa2935ea8b906e27 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 17:31:50 +0200 Subject: [PATCH 4/4] Revert "BUG: work around jax/hypothesis combination being unreasonable" This reverts commit 1b79605c2c9198959cbfbec7adce6c2a2c4f0520. cf a discussion at https://github.com/data-apis/array-api-tests/issues/368 --- array_api_tests/test_utility_functions.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index ec92dd51..b6e0a4fe 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -65,12 +65,10 @@ def test_any(x, data): out=result, expected=expected, kw=kw) -# NB: hh.int_dtypes instead of hh.numeric_dtypes because of -# https://github.com/data-apis/array-api-tests/issues/368 @pytest.mark.unvectorized @pytest.mark.min_version("2024.12") @given( - x=hh.arrays_no_scalars(hh.int_dtypes, hh.shapes(min_dims=1, min_side=1)), + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), data=st.data(), ) def test_diff(x, data): @@ -105,7 +103,7 @@ def test_diff(x, data): @pytest.mark.min_version("2024.12") @pytest.mark.unvectorized @given( - x=hh.arrays(hh.int_dtypes, hh.shapes(min_dims=1, min_side=1)), + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), data=st.data(), ) def test_diff_append_prepend(x, data):