15
15
# limitations under the License.
16
16
17
17
import itertools
18
+ import os
18
19
19
20
import numpy as np
20
21
import pytest
@@ -93,15 +94,25 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
93
94
q = get_queue_or_skip ()
94
95
skip_if_dtype_not_supported (dtype , q )
95
96
96
- n_seq = 100
97
+ n_seq = 256
97
98
n_rep = 137
98
99
low = - 9.0
99
100
high = 9.0
100
101
x1 = np .random .uniform (low = low , high = high , size = n_seq )
101
102
x2 = np .random .uniform (low = low , high = high , size = n_seq )
102
103
Xnp = x1 + 1j * x2
103
104
104
- X = dpt .asarray (np .repeat (Xnp , n_rep ), dtype = dtype , sycl_queue = q )
105
+ # stay away from poles and branch lines
106
+ modulus = np .abs (Xnp )
107
+ sel = np .logical_or (
108
+ modulus < 0.9 ,
109
+ np .logical_and (
110
+ modulus > 1.2 , np .minimum (np .abs (x2 ), np .abs (x1 )) > 0.05
111
+ ),
112
+ )
113
+ Xnp = Xnp [sel ]
114
+
115
+ X = dpt .repeat (dpt .asarray (Xnp , dtype = dtype , sycl_queue = q ), n_rep )
105
116
Y = dpt_call (X )
106
117
107
118
expected = np .repeat (np_call (Xnp ), n_rep )
@@ -234,10 +245,30 @@ def test_trig_complex_strided(np_call, dpt_call, dtype):
234
245
235
246
low = - 9.0
236
247
high = 9.0
248
+ while True :
249
+ x1 = np .random .uniform (low = low , high = high , size = 2 * sum (sizes ))
250
+ x2 = np .random .uniform (low = low , high = high , size = 2 * sum (sizes ))
251
+ Xnp_all = np .array (
252
+ [complex (v1 , v2 ) for v1 , v2 in zip (x1 , x2 )], dtype = dtype
253
+ )
254
+
255
+ # stay away from poles and branch lines
256
+ modulus = np .abs (Xnp_all )
257
+ sel = np .logical_or (
258
+ modulus < 0.9 ,
259
+ np .logical_and (
260
+ modulus > 1.2 , np .minimum (np .abs (x2 ), np .abs (x1 )) > 0.05
261
+ ),
262
+ )
263
+ Xnp_all = Xnp_all [sel ]
264
+ if Xnp_all .size > sum (sizes ):
265
+ break
266
+
267
+ pos = 0
237
268
for ii in sizes :
238
- x1 = np . random . uniform ( low = low , high = high , size = ii )
239
- x2 = np . random . uniform ( low = low , high = high , size = ii )
240
- Xnp = np . array ([ complex ( v1 , v2 ) for v1 , v2 in zip ( x1 , x2 )], dtype = dtype )
269
+ pos = pos + ii
270
+ Xnp = Xnp_all [: pos ]
271
+ Xnp = Xnp [ - ii :]
241
272
X = dpt .asarray (Xnp )
242
273
Ynp = np_call (Xnp )
243
274
for jj in strides :
@@ -264,13 +295,36 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
264
295
Y_np = np_call (xf )
265
296
266
297
tol = 8 * dpt .finfo (dtype ).resolution
267
- assert_allclose (dpt .asnumpy (dpt_call (yf )), Y_np , atol = tol , rtol = tol )
298
+ Y = dpt_call (yf )
299
+ assert_allclose (dpt .asnumpy (Y ), Y_np , atol = tol , rtol = tol )
300
+
301
+
302
+ @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
303
+ @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
304
+ def test_trig_complex_special_cases_conj_property (np_call , dpt_call , dtype ):
305
+ q = get_queue_or_skip ()
306
+ skip_if_dtype_not_supported (dtype , q )
268
307
308
+ x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
309
+ xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
310
+
311
+ Xc_np = np .array (xc , dtype = dtype )
312
+ Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
313
+
314
+ tol = 50 * dpt .finfo (dtype ).resolution
315
+ Y = dpt_call (Xc )
316
+ Yc = dpt_call (dpt .conj (Xc ))
269
317
270
- @pytest .mark .broken_complex
318
+ dpt .allclose (Y , dpt .conj (Yc ), atol = tol , rtol = tol )
319
+
320
+
321
+ @pytest .mark .skipif (
322
+ os .name != "posix" , reason = "Known to fail on Windows due to bug in NumPy"
323
+ )
271
324
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
272
325
@pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
273
326
def test_trig_complex_special_cases (np_call , dpt_call , dtype ):
327
+
274
328
q = get_queue_or_skip ()
275
329
skip_if_dtype_not_supported (dtype , q )
276
330
@@ -284,9 +338,6 @@ def test_trig_complex_special_cases(np_call, dpt_call, dtype):
284
338
Ynp = np_call (Xc_np )
285
339
286
340
tol = 50 * dpt .finfo (dtype ).resolution
287
- assert_allclose (
288
- dpt .asnumpy (dpt .real (dpt_call (Xc ))), np .real (Ynp ), atol = tol , rtol = tol
289
- )
290
- assert_allclose (
291
- dpt .asnumpy (dpt .imag (dpt_call (Xc ))), np .imag (Ynp ), atol = tol , rtol = tol
292
- )
341
+ Y = dpt_call (Xc )
342
+ assert_allclose (dpt .asnumpy (dpt .real (Y )), np .real (Ynp ), atol = tol , rtol = tol )
343
+ assert_allclose (dpt .asnumpy (dpt .imag (Y )), np .imag (Ynp ), atol = tol , rtol = tol )
0 commit comments