diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index d5344039..754b507d 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -467,10 +467,24 @@ def test_stack(shape, dtypes, kw, data): @pytest.mark.min_version("2023.12") @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data()) def test_tile(x, data): - repetitions = data.draw(st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), label="repetitions") + repetitions = data.draw( + st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), + label="repetitions" + ) out = xp.tile(x, repetitions) ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype) - # TODO: shapes and values testing + # TODO: values testing + + # shape check; the notation is from the Array API docs + N, M = len(x.shape), len(repetitions) + if N > M: + S = x.shape + R = (1,)*(N - M) + repetitions + else: + S = (1,)*(M - N) + x.shape + R = repetitions + + assert out.shape == tuple(r*s for r, s in zip(R, S)) @pytest.mark.min_version("2023.12")