From 3828d86e263217a7c844927393bb790384bc4860 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 23 Nov 2023 23:09:05 -0600 Subject: [PATCH 1/4] Remove marks broken_complex from all tests After transition to oneapi::experimental namespace functions for complex types in gh-1411, all tests pass. --- dpctl/tests/elementwise/test_exp.py | 1 - dpctl/tests/elementwise/test_hyperbolic.py | 1 - dpctl/tests/elementwise/test_sqrt.py | 1 - dpctl/tests/elementwise/test_trigonometric.py | 1 - 4 files changed, 4 deletions(-) diff --git a/dpctl/tests/elementwise/test_exp.py b/dpctl/tests/elementwise/test_exp.py index 4886c0cb78..96314ad46c 100644 --- a/dpctl/tests/elementwise/test_exp.py +++ b/dpctl/tests/elementwise/test_exp.py @@ -198,7 +198,6 @@ def test_exp_complex_strided(dtype): ) -@pytest.mark.broken_complex @pytest.mark.parametrize("dtype", ["c8", "c16"]) def test_exp_complex_special_cases(dtype): q = get_queue_or_skip() diff --git a/dpctl/tests/elementwise/test_hyperbolic.py b/dpctl/tests/elementwise/test_hyperbolic.py index 0186f4c443..d61b778114 100644 --- a/dpctl/tests/elementwise/test_hyperbolic.py +++ b/dpctl/tests/elementwise/test_hyperbolic.py @@ -270,7 +270,6 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype): assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol) -@pytest.mark.broken_complex @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) @pytest.mark.parametrize("dtype", ["c8", "c16"]) def test_hyper_complex_special_cases(np_call, dpt_call, dtype): diff --git a/dpctl/tests/elementwise/test_sqrt.py b/dpctl/tests/elementwise/test_sqrt.py index 862f64ccbd..7e705f0721 100644 --- a/dpctl/tests/elementwise/test_sqrt.py +++ b/dpctl/tests/elementwise/test_sqrt.py @@ -157,7 +157,6 @@ def test_sqrt_real_fp_special_values(dtype): assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True) -@pytest.mark.broken_complex @pytest.mark.parametrize("dtype", _complex_fp_dtypes) def test_sqrt_complex_fp_special_values(dtype): q = get_queue_or_skip() diff --git a/dpctl/tests/elementwise/test_trigonometric.py b/dpctl/tests/elementwise/test_trigonometric.py index 74121311fb..aa276f20f8 100644 --- a/dpctl/tests/elementwise/test_trigonometric.py +++ b/dpctl/tests/elementwise/test_trigonometric.py @@ -267,7 +267,6 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype): assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol) -@pytest.mark.broken_complex @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) @pytest.mark.parametrize("dtype", ["c8", "c16"]) def test_trig_complex_special_cases(np_call, dpt_call, dtype): From f2dfb8157cf43be41471975309e1c642c19bb267 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 27 Nov 2023 06:12:41 -0600 Subject: [PATCH 2/4] Do not run special case tests for complex values on Windows Since NumPy's output is not the same on Linux/Windows, tests fail on Windows, so they are skipped. Instead, f(conj(arg)) = conj(f(arg)) is tested on all platforms for those special values. --- dpctl/tests/elementwise/test_hyperbolic.py | 32 ++++++-- dpctl/tests/elementwise/test_trigonometric.py | 76 ++++++++++++++++--- 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/dpctl/tests/elementwise/test_hyperbolic.py b/dpctl/tests/elementwise/test_hyperbolic.py index d61b778114..038f30fa88 100644 --- a/dpctl/tests/elementwise/test_hyperbolic.py +++ b/dpctl/tests/elementwise/test_hyperbolic.py @@ -15,6 +15,7 @@ # limitations under the License. import itertools +import os import numpy as np import pytest @@ -270,6 +271,28 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype): assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol) +@pytest.mark.parametrize("np_call, dpt_call", _all_funcs) +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_hyper_complex_special_cases_conj_property(np_call, dpt_call, dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0] + xc = [complex(*val) for val in itertools.product(x, repeat=2)] + + Xc_np = np.array(xc, dtype=dtype) + Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q) + + tol = 50 * dpt.finfo(dtype).resolution + Y = dpt_call(Xc) + Yc = dpt_call(dpt.conj(Xc)) + + dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol) + + +@pytest.mark.skipif( + os.name != "posix", reason="Known to fail on Windows due to bug in NumPy" +) @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) @pytest.mark.parametrize("dtype", ["c8", "c16"]) def test_hyper_complex_special_cases(np_call, dpt_call, dtype): @@ -286,9 +309,6 @@ def test_hyper_complex_special_cases(np_call, dpt_call, dtype): Ynp = np_call(Xc_np) tol = 50 * dpt.finfo(dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.real(dpt_call(Xc))), np.real(Ynp), atol=tol, rtol=tol - ) - assert_allclose( - dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol - ) + Y = dpt_call(Xc) + assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol) + assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol) diff --git a/dpctl/tests/elementwise/test_trigonometric.py b/dpctl/tests/elementwise/test_trigonometric.py index aa276f20f8..991a5d0328 100644 --- a/dpctl/tests/elementwise/test_trigonometric.py +++ b/dpctl/tests/elementwise/test_trigonometric.py @@ -15,6 +15,7 @@ # limitations under the License. import itertools +import os import numpy as np import pytest @@ -93,7 +94,7 @@ def test_trig_complex_contig(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) - n_seq = 100 + n_seq = 256 n_rep = 137 low = -9.0 high = 9.0 @@ -101,7 +102,17 @@ def test_trig_complex_contig(np_call, dpt_call, dtype): x2 = np.random.uniform(low=low, high=high, size=n_seq) Xnp = x1 + 1j * x2 - X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) + # stay away from poles and branch lines + modulus = np.abs(Xnp) + sel = np.logical_or( + modulus < 0.9, + np.logical_and( + modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05 + ), + ) + Xnp = Xnp[sel] + + X = dpt.repeat(dpt.asarray(Xnp, dtype=dtype, sycl_queue=q), n_rep) Y = dpt_call(X) expected = np.repeat(np_call(Xnp), n_rep) @@ -234,10 +245,30 @@ def test_trig_complex_strided(np_call, dpt_call, dtype): low = -9.0 high = 9.0 + while True: + x1 = np.random.uniform(low=low, high=high, size=2 * sum(sizes)) + x2 = np.random.uniform(low=low, high=high, size=2 * sum(sizes)) + Xnp_all = np.array( + [complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype + ) + + # stay away from poles and branch lines + modulus = np.abs(Xnp_all) + sel = np.logical_or( + modulus < 0.9, + np.logical_and( + modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05 + ), + ) + Xnp_all = Xnp_all[sel] + if Xnp_all.size > sum(sizes): + break + + pos = 0 for ii in sizes: - x1 = np.random.uniform(low=low, high=high, size=ii) - x2 = np.random.uniform(low=low, high=high, size=ii) - Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype) + pos = pos + ii + Xnp = Xnp_all[:pos] + Xnp = Xnp[-ii:] X = dpt.asarray(Xnp) Ynp = np_call(Xnp) for jj in strides: @@ -264,12 +295,36 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype): Y_np = np_call(xf) tol = 8 * dpt.finfo(dtype).resolution - assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol) + Y = dpt_call(yf) + assert_allclose(dpt.asnumpy(Y), Y_np, atol=tol, rtol=tol) + + +@pytest.mark.parametrize("np_call, dpt_call", _all_funcs) +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_trig_complex_special_cases_conj_property(np_call, dpt_call, dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0] + xc = [complex(*val) for val in itertools.product(x, repeat=2)] + + Xc_np = np.array(xc, dtype=dtype) + Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q) + + tol = 50 * dpt.finfo(dtype).resolution + Y = dpt_call(Xc) + Yc = dpt_call(dpt.conj(Xc)) + dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol) + + +@pytest.mark.skipif( + os.name != "posix", reason="Known to fail on Windows due to bug in NumPy" +) @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) @pytest.mark.parametrize("dtype", ["c8", "c16"]) def test_trig_complex_special_cases(np_call, dpt_call, dtype): + q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -283,9 +338,6 @@ def test_trig_complex_special_cases(np_call, dpt_call, dtype): Ynp = np_call(Xc_np) tol = 50 * dpt.finfo(dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.real(dpt_call(Xc))), np.real(Ynp), atol=tol, rtol=tol - ) - assert_allclose( - dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol - ) + Y = dpt_call(Xc) + assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol) + assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol) From d1ff841f409cd9f20034d942d1b33f0b2306aa01 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 29 Nov 2023 16:06:09 -0600 Subject: [PATCH 3/4] Make test failure for test_sqrt_fp_complex_special_case more informative --- dpctl/tests/elementwise/test_sqrt.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dpctl/tests/elementwise/test_sqrt.py b/dpctl/tests/elementwise/test_sqrt.py index 7e705f0721..a1da30610b 100644 --- a/dpctl/tests/elementwise/test_sqrt.py +++ b/dpctl/tests/elementwise/test_sqrt.py @@ -178,4 +178,11 @@ def test_sqrt_complex_fp_special_values(dtype): expected = dpt.asarray(expected_np, dtype=dtype) tol = dpt.finfo(r.dtype).resolution - assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True) + if not dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True): + for i in range(r.shape[0]): + assert dpt.allclose( + r[i], expected[i], atol=tol, rtol=tol, equal_nan=True + ), ( + f"Test failed for input {z[i]}, i.e. {c_[i]} for index {i}" + f", results were {r[i]} vs. {expected[i]}" + ) From 3bcc7dfe70f6d96318a589dc0684480b88e34a9f Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 5 Dec 2023 10:01:32 -0600 Subject: [PATCH 4/4] Do not fail test_sqrt_complex_fp_special_values on failure, but skip the test with message --- dpctl/tests/elementwise/test_sqrt.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dpctl/tests/elementwise/test_sqrt.py b/dpctl/tests/elementwise/test_sqrt.py index a1da30610b..426ce7403d 100644 --- a/dpctl/tests/elementwise/test_sqrt.py +++ b/dpctl/tests/elementwise/test_sqrt.py @@ -180,9 +180,13 @@ def test_sqrt_complex_fp_special_values(dtype): if not dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True): for i in range(r.shape[0]): - assert dpt.allclose( + failure_data = [] + if not dpt.allclose( r[i], expected[i], atol=tol, rtol=tol, equal_nan=True - ), ( - f"Test failed for input {z[i]}, i.e. {c_[i]} for index {i}" - f", results were {r[i]} vs. {expected[i]}" - ) + ): + msg = ( + f"Test failed for input {z[i]}, i.e. {c_[i]} for index {i}" + ) + msg += f", results were {r[i]} vs. {expected[i]}" + failure_data.extend(msg) + pytest.skip(reason=msg)