diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index b5076d9b..b6e0a4fe 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -63,3 +63,79 @@ def test_any(x, data): expected = any(elements) ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, out=result, expected=expected, kw=kw) + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_diff(x, data): + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {"axis": -1} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + n = data.draw(st.integers(1, min(x.shape[n_axis], 3))) + + out = xp.diff(x, **axis_kw, n=n) + + expected_shape = list(x.shape) + expected_shape[n_axis] -= n + + assert out.shape == tuple(expected_shape) + + # value test + if n == 1: + for idx in sh.ndindex(out.shape): + l = list(idx) + l[n_axis] += 1 + assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" + + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@given( + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_diff_append_prepend(x, data): + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {"axis": -1} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + n = data.draw(st.integers(1, min(x.shape[n_axis], 3))) + + append_shape = list(x.shape) + append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis") + append_shape[n_axis] = append_axis_len + append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append") + + prepend_shape = list(x.shape) + prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis") + prepend_shape[n_axis] = prepend_axis_len + prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend") + + out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend) + + in_1 = xp.concat((prepend, x, append), **axis_kw) + out_1 = xp.diff(in_1, **axis_kw, n=n) + + assert out.shape == out_1.shape + for idx in sh.ndindex(out.shape): + assert out[idx] == out_1[idx], f"{idx = }" +