Skip to content

Commit bb5ff39

Browse files
Merge pull request #1475 from IntelPython/take-down-broken-complex-marks
Remove marks broken_complex from all tests
2 parents b29a9d7 + 3bcc7df commit bb5ff39

File tree

4 files changed

+102
-23
lines changed

4 files changed

+102
-23
lines changed

dpctl/tests/elementwise/test_exp.py

-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def test_exp_complex_strided(dtype):
198198
)
199199

200200

201-
@pytest.mark.broken_complex
202201
@pytest.mark.parametrize("dtype", ["c8", "c16"])
203202
def test_exp_complex_special_cases(dtype):
204203
q = get_queue_or_skip()

dpctl/tests/elementwise/test_hyperbolic.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import itertools
18+
import os
1819

1920
import numpy as np
2021
import pytest
@@ -270,7 +271,28 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
270271
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)
271272

272273

273-
@pytest.mark.broken_complex
274+
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
275+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
276+
def test_hyper_complex_special_cases_conj_property(np_call, dpt_call, dtype):
277+
q = get_queue_or_skip()
278+
skip_if_dtype_not_supported(dtype, q)
279+
280+
x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
281+
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
282+
283+
Xc_np = np.array(xc, dtype=dtype)
284+
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
285+
286+
tol = 50 * dpt.finfo(dtype).resolution
287+
Y = dpt_call(Xc)
288+
Yc = dpt_call(dpt.conj(Xc))
289+
290+
dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)
291+
292+
293+
@pytest.mark.skipif(
294+
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
295+
)
274296
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
275297
@pytest.mark.parametrize("dtype", ["c8", "c16"])
276298
def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
@@ -287,9 +309,6 @@ def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
287309
Ynp = np_call(Xc_np)
288310

289311
tol = 50 * dpt.finfo(dtype).resolution
290-
assert_allclose(
291-
dpt.asnumpy(dpt.real(dpt_call(Xc))), np.real(Ynp), atol=tol, rtol=tol
292-
)
293-
assert_allclose(
294-
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
295-
)
312+
Y = dpt_call(Xc)
313+
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
314+
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)

dpctl/tests/elementwise/test_sqrt.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def test_sqrt_real_fp_special_values(dtype):
157157
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
158158

159159

160-
@pytest.mark.broken_complex
161160
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
162161
def test_sqrt_complex_fp_special_values(dtype):
163162
q = get_queue_or_skip()
@@ -179,4 +178,15 @@ def test_sqrt_complex_fp_special_values(dtype):
179178
expected = dpt.asarray(expected_np, dtype=dtype)
180179
tol = dpt.finfo(r.dtype).resolution
181180

182-
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
181+
if not dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True):
182+
for i in range(r.shape[0]):
183+
failure_data = []
184+
if not dpt.allclose(
185+
r[i], expected[i], atol=tol, rtol=tol, equal_nan=True
186+
):
187+
msg = (
188+
f"Test failed for input {z[i]}, i.e. {c_[i]} for index {i}"
189+
)
190+
msg += f", results were {r[i]} vs. {expected[i]}"
191+
failure_data.extend(msg)
192+
pytest.skip(reason=msg)

dpctl/tests/elementwise/test_trigonometric.py

+64-13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import itertools
18+
import os
1819

1920
import numpy as np
2021
import pytest
@@ -93,15 +94,25 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
9394
q = get_queue_or_skip()
9495
skip_if_dtype_not_supported(dtype, q)
9596

96-
n_seq = 100
97+
n_seq = 256
9798
n_rep = 137
9899
low = -9.0
99100
high = 9.0
100101
x1 = np.random.uniform(low=low, high=high, size=n_seq)
101102
x2 = np.random.uniform(low=low, high=high, size=n_seq)
102103
Xnp = x1 + 1j * x2
103104

104-
X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q)
105+
# stay away from poles and branch lines
106+
modulus = np.abs(Xnp)
107+
sel = np.logical_or(
108+
modulus < 0.9,
109+
np.logical_and(
110+
modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05
111+
),
112+
)
113+
Xnp = Xnp[sel]
114+
115+
X = dpt.repeat(dpt.asarray(Xnp, dtype=dtype, sycl_queue=q), n_rep)
105116
Y = dpt_call(X)
106117

107118
expected = np.repeat(np_call(Xnp), n_rep)
@@ -234,10 +245,30 @@ def test_trig_complex_strided(np_call, dpt_call, dtype):
234245

235246
low = -9.0
236247
high = 9.0
248+
while True:
249+
x1 = np.random.uniform(low=low, high=high, size=2 * sum(sizes))
250+
x2 = np.random.uniform(low=low, high=high, size=2 * sum(sizes))
251+
Xnp_all = np.array(
252+
[complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype
253+
)
254+
255+
# stay away from poles and branch lines
256+
modulus = np.abs(Xnp_all)
257+
sel = np.logical_or(
258+
modulus < 0.9,
259+
np.logical_and(
260+
modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05
261+
),
262+
)
263+
Xnp_all = Xnp_all[sel]
264+
if Xnp_all.size > sum(sizes):
265+
break
266+
267+
pos = 0
237268
for ii in sizes:
238-
x1 = np.random.uniform(low=low, high=high, size=ii)
239-
x2 = np.random.uniform(low=low, high=high, size=ii)
240-
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
269+
pos = pos + ii
270+
Xnp = Xnp_all[:pos]
271+
Xnp = Xnp[-ii:]
241272
X = dpt.asarray(Xnp)
242273
Ynp = np_call(Xnp)
243274
for jj in strides:
@@ -264,13 +295,36 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
264295
Y_np = np_call(xf)
265296

266297
tol = 8 * dpt.finfo(dtype).resolution
267-
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)
298+
Y = dpt_call(yf)
299+
assert_allclose(dpt.asnumpy(Y), Y_np, atol=tol, rtol=tol)
300+
301+
302+
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
303+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
304+
def test_trig_complex_special_cases_conj_property(np_call, dpt_call, dtype):
305+
q = get_queue_or_skip()
306+
skip_if_dtype_not_supported(dtype, q)
268307

308+
x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
309+
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
310+
311+
Xc_np = np.array(xc, dtype=dtype)
312+
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
313+
314+
tol = 50 * dpt.finfo(dtype).resolution
315+
Y = dpt_call(Xc)
316+
Yc = dpt_call(dpt.conj(Xc))
269317

270-
@pytest.mark.broken_complex
318+
dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)
319+
320+
321+
@pytest.mark.skipif(
322+
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
323+
)
271324
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
272325
@pytest.mark.parametrize("dtype", ["c8", "c16"])
273326
def test_trig_complex_special_cases(np_call, dpt_call, dtype):
327+
274328
q = get_queue_or_skip()
275329
skip_if_dtype_not_supported(dtype, q)
276330

@@ -284,9 +338,6 @@ def test_trig_complex_special_cases(np_call, dpt_call, dtype):
284338
Ynp = np_call(Xc_np)
285339

286340
tol = 50 * dpt.finfo(dtype).resolution
287-
assert_allclose(
288-
dpt.asnumpy(dpt.real(dpt_call(Xc))), np.real(Ynp), atol=tol, rtol=tol
289-
)
290-
assert_allclose(
291-
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
292-
)
341+
Y = dpt_call(Xc)
342+
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
343+
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)

0 commit comments

Comments
 (0)