Skip to content

Commit b63b89b

Browse files
committed
ENH: test floor_div, remainder with scalars
1 parent 889d1dc commit b63b89b

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,6 +1791,10 @@ def _check_binary_with_scalars(func_data, x1x2):
17911791
)
17921792

17931793

1794+
def _filter_zero(x):
1795+
return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0))
1796+
1797+
17941798
@pytest.mark.min_version("2024.12")
17951799
@pytest.mark.parametrize('func_data',
17961800
# xp_func, name, refimpl, kwargs, expected_dtype
@@ -1812,11 +1816,19 @@ def _check_binary_with_scalars(func_data, x1x2):
18121816
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
18131817
(xp.greater, "greater", operator.gt, {}, xp.bool),
18141818
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
1819+
(xp.remainder, "remainder", operator.mod, {}, None),
1820+
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
18151821
],
18161822
ids=lambda func_data: func_data[1] # use names for test IDs
18171823
)
18181824
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
18191825
def test_binary_with_scalars_real(func_data, x1x2):
1826+
1827+
if func_data[1] == "remainder":
1828+
assume(_filter_zero(x1x2[1]))
1829+
if func_data[1] == "floor_divide":
1830+
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
1831+
18201832
_check_binary_with_scalars(func_data, x1x2)
18211833

18221834

0 commit comments

Comments
 (0)