@@ -1880,3 +1880,33 @@ def test_binary_with_scalars_bitwise(func_data, x1x2):
1880
1880
refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1881
1881
_check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
1882
1882
1883
+
1884
+ @pytest .mark .unvectorized
1885
+ @given (
1886
+ x1x2 = hh .array_and_py_scalar ([xp .int32 ]),
1887
+ data = st .data ()
1888
+ )
1889
+ def test_where_with_scalars (x1x2 , data ):
1890
+ x1 , x2 = x1x2
1891
+
1892
+ if dh .is_scalar (x1 ):
1893
+ dtype , shape = x2 .dtype , x2 .shape
1894
+ x1_arr , x2_arr = xp .broadcast_to (xp .asarray (x1 ), shape ), x2
1895
+ else :
1896
+ dtype , shape = x1 .dtype , x1 .shape
1897
+ x1_arr , x2_arr = x1 , xp .broadcast_to (xp .asarray (x2 ), shape )
1898
+
1899
+ condition = data .draw (hh .arrays (shape = shape , dtype = xp .bool ))
1900
+
1901
+ out = xp .where (condition , x1 , x2 )
1902
+
1903
+ assert out .dtype == dtype , f"where: got { out .dtype = } for { dtype = } , { x1 = } and { x2 = } "
1904
+ assert out .shape == shape , f"where: got { out .shape = } for { shape = } , { x1 = } and { x2 = } "
1905
+
1906
+ # value test
1907
+ for idx in sh .ndindex (shape ):
1908
+ if condition [idx ]:
1909
+ assert out [idx ] == x1_arr [idx ]
1910
+ else :
1911
+ assert out [idx ] == x2_arr [idx ]
1912
+
0 commit comments