From 6344a37c5c8ee839e76d083d6d27ba38efdadf90 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 17:53:16 +0200 Subject: [PATCH] ENH: test `where` with scalars --- ...est_operators_and_elementwise_functions.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 82ab3351..b4400afc 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1880,3 +1880,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2): refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 ) _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2) + +@pytest.mark.unvectorized +@given( + x1x2=hh.array_and_py_scalar([xp.int32]), + data=st.data() +) +def test_where_with_scalars(x1x2, data): + x1, x2 = x1x2 + + if dh.is_scalar(x1): + dtype, shape = x2.dtype, x2.shape + x1_arr, x2_arr = xp.broadcast_to(xp.asarray(x1), shape), x2 + else: + dtype, shape = x1.dtype, x1.shape + x1_arr, x2_arr = x1, xp.broadcast_to(xp.asarray(x2), shape) + + condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool)) + + out = xp.where(condition, x1, x2) + + assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}" + assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}" + + # value test + for idx in sh.ndindex(shape): + if condition[idx]: + assert out[idx] == x1_arr[idx] + else: + assert out[idx] == x2_arr[idx] +