Skip to content

Commit 6344a37

Browse files
committed
ENH: test where with scalars
1 parent 31eec9d commit 6344a37

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,3 +1880,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
18801880
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
18811881
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
18821882

1883+
1884+
@pytest.mark.unvectorized
1885+
@given(
1886+
x1x2=hh.array_and_py_scalar([xp.int32]),
1887+
data=st.data()
1888+
)
1889+
def test_where_with_scalars(x1x2, data):
1890+
x1, x2 = x1x2
1891+
1892+
if dh.is_scalar(x1):
1893+
dtype, shape = x2.dtype, x2.shape
1894+
x1_arr, x2_arr = xp.broadcast_to(xp.asarray(x1), shape), x2
1895+
else:
1896+
dtype, shape = x1.dtype, x1.shape
1897+
x1_arr, x2_arr = x1, xp.broadcast_to(xp.asarray(x2), shape)
1898+
1899+
condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool))
1900+
1901+
out = xp.where(condition, x1, x2)
1902+
1903+
assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}"
1904+
assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}"
1905+
1906+
# value test
1907+
for idx in sh.ndindex(shape):
1908+
if condition[idx]:
1909+
assert out[idx] == x1_arr[idx]
1910+
else:
1911+
assert out[idx] == x2_arr[idx]
1912+

0 commit comments

Comments
 (0)