Skip to content

Commit 165d95f

Browse files
authored
ENH: test diff (#369)
- Add a basic `diff` test - test `diff`'s append and prepend arguments in a separate test
1 parent fc72a0c commit 165d95f

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

array_api_tests/test_utility_functions.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,79 @@ def test_any(x, data):
6363
expected = any(elements)
6464
ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
6565
out=result, expected=expected, kw=kw)
66+
67+
68+
@pytest.mark.unvectorized
69+
@pytest.mark.min_version("2024.12")
70+
@given(
71+
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
72+
data=st.data(),
73+
)
74+
def test_diff(x, data):
75+
axis = data.draw(
76+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
77+
label="axis"
78+
)
79+
if axis is None:
80+
axis_kw = {"axis": -1}
81+
n_axis = x.ndim - 1
82+
else:
83+
axis_kw = {"axis": axis}
84+
n_axis = axis + x.ndim if axis < 0 else axis
85+
86+
n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
87+
88+
out = xp.diff(x, **axis_kw, n=n)
89+
90+
expected_shape = list(x.shape)
91+
expected_shape[n_axis] -= n
92+
93+
assert out.shape == tuple(expected_shape)
94+
95+
# value test
96+
if n == 1:
97+
for idx in sh.ndindex(out.shape):
98+
l = list(idx)
99+
l[n_axis] += 1
100+
assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
101+
102+
103+
@pytest.mark.min_version("2024.12")
104+
@pytest.mark.unvectorized
105+
@given(
106+
x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
107+
data=st.data(),
108+
)
109+
def test_diff_append_prepend(x, data):
110+
axis = data.draw(
111+
st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
112+
label="axis"
113+
)
114+
if axis is None:
115+
axis_kw = {"axis": -1}
116+
n_axis = x.ndim - 1
117+
else:
118+
axis_kw = {"axis": axis}
119+
n_axis = axis + x.ndim if axis < 0 else axis
120+
121+
n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
122+
123+
append_shape = list(x.shape)
124+
append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis")
125+
append_shape[n_axis] = append_axis_len
126+
append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append")
127+
128+
prepend_shape = list(x.shape)
129+
prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis")
130+
prepend_shape[n_axis] = prepend_axis_len
131+
prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend")
132+
133+
out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
134+
135+
in_1 = xp.concat((prepend, x, append), **axis_kw)
136+
out_1 = xp.diff(in_1, **axis_kw, n=n)
137+
138+
assert out.shape == out_1.shape
139+
for idx in sh.ndindex(out.shape):
140+
assert out[idx] == out_1[idx], f"{idx = }"
141+

0 commit comments

Comments
 (0)