Skip to content

Commit 188b1e9

Browse files
authored
Merge pull request #370 from ev-br/test_where_with_scalars
ENH: test `where` with scalars
2 parents 165d95f + 6344a37 commit 188b1e9

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1901,3 +1901,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
19011901
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
19021902
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
19031903

1904+
1905+
@pytest.mark.unvectorized
1906+
@given(
1907+
x1x2=hh.array_and_py_scalar([xp.int32]),
1908+
data=st.data()
1909+
)
1910+
def test_where_with_scalars(x1x2, data):
1911+
x1, x2 = x1x2
1912+
1913+
if dh.is_scalar(x1):
1914+
dtype, shape = x2.dtype, x2.shape
1915+
x1_arr, x2_arr = xp.broadcast_to(xp.asarray(x1), shape), x2
1916+
else:
1917+
dtype, shape = x1.dtype, x1.shape
1918+
x1_arr, x2_arr = x1, xp.broadcast_to(xp.asarray(x2), shape)
1919+
1920+
condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool))
1921+
1922+
out = xp.where(condition, x1, x2)
1923+
1924+
assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}"
1925+
assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}"
1926+
1927+
# value test
1928+
for idx in sh.ndindex(shape):
1929+
if condition[idx]:
1930+
assert out[idx] == x1_arr[idx]
1931+
else:
1932+
assert out[idx] == x2_arr[idx]
1933+

0 commit comments

Comments
 (0)