From a3f255581f4c7bec3ef71c270aae7a50ce1d1c4b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 18:18:49 +0200 Subject: [PATCH 1/3] test pow() with scalars --- array_api_tests/test_operators_and_elementwise_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 82ab3351..1f3c3a26 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -724,7 +724,9 @@ def _assert_correctness_binary( x1a, x2a = in_arrs ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype) ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape) - binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs) + check_values = kwargs.pop('check_values', None) + if check_values: + binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs) @pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes)) @@ -1824,6 +1826,7 @@ def _filter_zero(x): ("less_equal", operator.le, {}, xp.bool), ("greater", operator.gt, {}, xp.bool), ("greater_equal", operator.ge, {}, xp.bool), + ("pow", operator.pow, {'check_values': False}, None) # too finicky for pow ], ids=lambda func_data: func_data[0] # use names for test IDs ) From 72de19babf3518d6d923f15d69fc3ff9751709b5 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 5 May 2025 22:02:34 +0200 Subject: [PATCH 2/3] ENH: test bitwise_{left,right}_shift with scalars --- array_api_tests/hypothesis_helpers.py | 16 +++++++++++----- .../test_operators_and_elementwise_functions.py | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 1bab8d34..3d1a3ec9 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -456,8 +456,12 @@ def scalars(draw, dtypes, finite=False, **kwds): dtypes should be one of the shared_* dtypes strategies. """ dtype = draw(dtypes) + mM = kwds.pop('mM', None) if dh.is_int_dtype(dtype): - m, M = dh.dtype_ranges[dtype] + if mM is None: + m, M = dh.dtype_ranges[dtype] + else: + m, M = mM return draw(integers(m, M)) elif dtype == bool_dtype: return draw(booleans()) @@ -588,18 +592,20 @@ def two_mutual_arrays( @composite -def array_and_py_scalar(draw, dtypes): +def array_and_py_scalar(draw, dtypes, mM=None, positive=False): """Draw a pair: (array, scalar) or (scalar, array).""" dtype = draw(sampled_from(dtypes)) - scalar_var = draw(scalars(just(dtype), finite=True, - **{'min_value': 1/ (2<<5), 'max_value': 2<<5} - )) + scalar_var = draw(scalars(just(dtype), finite=True, mM=mM)) + if positive: + assume (scalar_var > 0) elements={} if dtype in dh.real_float_dtypes: elements = {'allow_nan': False, 'allow_infinity': False, 'min_value': 1.0 / (2<<5), 'max_value': 2<<5} + if positive: + elements = {'min_value': 0} array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements)) if draw(booleans()): diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 1f3c3a26..b923f016 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1883,3 +1883,20 @@ def test_binary_with_scalars_bitwise(func_data, x1x2): refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 ) _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2) + +@pytest.mark.min_version("2024.12") +@pytest.mark.parametrize('func_data', + # func_name, refimpl, kwargs, expected_dtype + [ + ("bitwise_left_shift", operator.lshift, {}, None), + ("bitwise_right_shift", operator.rshift, {}, None), + ], + ids=lambda func_data: func_data[0] # use names for test IDs +) +@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3))) +def test_binary_with_scalars_bitwise_shifts(func_data, x1x2): + func_name, refimpl, kwargs, expected = func_data + # repack the refimpl + refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 ) + _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2) + From e4d2337595ecb1e1c6a9c6a4487c73d6fc37cafb Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 May 2025 23:00:22 +0200 Subject: [PATCH 3/3] Update array_api_tests/test_operators_and_elementwise_functions.py --- array_api_tests/test_operators_and_elementwise_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index b923f016..aeedc3d7 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1826,7 +1826,7 @@ def _filter_zero(x): ("less_equal", operator.le, {}, xp.bool), ("greater", operator.gt, {}, xp.bool), ("greater_equal", operator.ge, {}, xp.bool), - ("pow", operator.pow, {'check_values': False}, None) # too finicky for pow + ("pow", operator.pow, {'check_values': False}, None) # value tests are too finicky for pow ], ids=lambda func_data: func_data[0] # use names for test IDs )