@@ -1901,3 +1901,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
1901
1901
refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1902
1902
_check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
1903
1903
1904
+
1905
+ @pytest .mark .unvectorized
1906
+ @given (
1907
+ x1x2 = hh .array_and_py_scalar ([xp .int32 ]),
1908
+ data = st .data ()
1909
+ )
1910
+ def test_where_with_scalars (x1x2 , data ):
1911
+ x1 , x2 = x1x2
1912
+
1913
+ if dh .is_scalar (x1 ):
1914
+ dtype , shape = x2 .dtype , x2 .shape
1915
+ x1_arr , x2_arr = xp .broadcast_to (xp .asarray (x1 ), shape ), x2
1916
+ else :
1917
+ dtype , shape = x1 .dtype , x1 .shape
1918
+ x1_arr , x2_arr = x1 , xp .broadcast_to (xp .asarray (x2 ), shape )
1919
+
1920
+ condition = data .draw (hh .arrays (shape = shape , dtype = xp .bool ))
1921
+
1922
+ out = xp .where (condition , x1 , x2 )
1923
+
1924
+ assert out .dtype == dtype , f"where: got { out .dtype = } for { dtype = } , { x1 = } and { x2 = } "
1925
+ assert out .shape == shape , f"where: got { out .shape = } for { shape = } , { x1 = } and { x2 = } "
1926
+
1927
+ # value test
1928
+ for idx in sh .ndindex (shape ):
1929
+ if condition [idx ]:
1930
+ assert out [idx ] == x1_arr [idx ]
1931
+ else :
1932
+ assert out [idx ] == x2_arr [idx ]
1933
+
0 commit comments