diff --git a/tests/test_trig.py b/tests/test_trig.py new file mode 100644 index 0000000..ae62f7b --- /dev/null +++ b/tests/test_trig.py @@ -0,0 +1,157 @@ +import random + +import pytest + +import arrayfire_wrapper.dtypes as dtype +import arrayfire_wrapper.lib as wrapper + +from . import utility_functions as util + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_all_types()) +def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test inverse sine operation across all supported data types.""" + util.check_type_supported(dtype_name) + values = wrapper.randu(shape, dtype_name) + result = wrapper.asin(values) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_all_types()) +def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test inverse cosine operation across all supported data types.""" + util.check_type_supported(dtype_name) + values = wrapper.randu(shape, dtype_name) + result = wrapper.acos(values) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_all_types()) +def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test inverse tan operation across all supported data types.""" + util.check_type_supported(dtype_name) + values = wrapper.randu(shape, dtype_name) + result = wrapper.atan(values) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_float_types()) +def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test inverse tan operation across all supported data types.""" + util.check_type_supported(dtype_name) + if dtype_name == dtype.f16: + pytest.skip() + lhs = wrapper.randu(shape, dtype_name) + rhs = wrapper.randu(shape, dtype_name) + result = wrapper.atan2(lhs, rhs) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa + + +@pytest.mark.parametrize( + "invdtypes", + [ + dtype.int16, + dtype.bool, + ], +) +def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None: + """Test inverse tan operation for unsupported data types.""" + with pytest.raises(RuntimeError): + wrapper.atan2(wrapper.randu((10, 10), invdtypes), wrapper.randu((10, 10), invdtypes)) + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_all_types()) +def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test cosine operation across all supported data types.""" + util.check_type_supported(dtype_name) + values = wrapper.randu(shape, dtype_name) + result = wrapper.cos(values) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_all_types()) +def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test sin operation across all supported data types.""" + util.check_type_supported(dtype_name) + values = wrapper.randu(shape, dtype_name) + result = wrapper.sin(values) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10),), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +@pytest.mark.parametrize("dtype_name", util.get_all_types()) +def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: + """Test tan operation across all supported data types.""" + util.check_type_supported(dtype_name) + values = wrapper.randu(shape, dtype_name) + result = wrapper.tan(values) + assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa diff --git a/tests/utility_functions.py b/tests/utility_functions.py index c89bab1..a8b054d 100644 --- a/tests/utility_functions.py +++ b/tests/utility_functions.py @@ -1,14 +1,13 @@ import pytest import arrayfire_wrapper.lib as wrapper -from arrayfire_wrapper.dtypes import Dtype, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64 +from arrayfire_wrapper.dtypes import Dtype, b8, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64 def check_type_supported(dtype: Dtype) -> None: """Checks to see if the specified type is supported by the current system""" if dtype in [f64, c64] and not wrapper.get_dbl_support(): pytest.skip("Device does not support double types") - if dtype == f16 and not wrapper.get_half_support(): pytest.skip("Device does not support half types.") @@ -25,4 +24,9 @@ def get_real_types() -> list: def get_all_types() -> list: """Returns all types""" - return [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] + return [b8, s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] + + +def get_float_types() -> list: + """Returns all types""" + return [f16, f32, f64]