@@ -1791,6 +1791,10 @@ def _check_binary_with_scalars(func_data, x1x2):
1791
1791
)
1792
1792
1793
1793
1794
+ def _filter_zero (x ):
1795
+ return x != 0 if dh .is_scalar (x ) else (not xp .any (x == 0 ))
1796
+
1797
+
1794
1798
@pytest .mark .min_version ("2024.12" )
1795
1799
@pytest .mark .parametrize ('func_data' ,
1796
1800
# xp_func, name, refimpl, kwargs, expected_dtype
@@ -1812,11 +1816,19 @@ def _check_binary_with_scalars(func_data, x1x2):
1812
1816
(xp .less_equal , "les_equal" , operator .le , {}, xp .bool ),
1813
1817
(xp .greater , "greater" , operator .gt , {}, xp .bool ),
1814
1818
(xp .greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1819
+ (xp .remainder , "remainder" , operator .mod , {}, None ),
1820
+ (xp .floor_divide , "floor_divide" , operator .floordiv , {}, None ),
1815
1821
],
1816
1822
ids = lambda func_data : func_data [1 ] # use names for test IDs
1817
1823
)
1818
1824
@given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
1819
1825
def test_binary_with_scalars_real (func_data , x1x2 ):
1826
+
1827
+ if func_data [1 ] == "remainder" :
1828
+ assume (_filter_zero (x1x2 [1 ]))
1829
+ if func_data [1 ] == "floor_divide" :
1830
+ assume (_filter_zero (x1x2 [0 ]) and _filter_zero (x1x2 [1 ]))
1831
+
1820
1832
_check_binary_with_scalars (func_data , x1x2 )
1821
1833
1822
1834
0 commit comments