15
15
# limitations under the License.
16
16
17
17
18
- from itertools import chain , product , repeat
18
+ import operator
19
+ from itertools import chain , repeat
19
20
20
21
import numpy as np
21
22
from numpy .core .numeric import normalize_axis_index , normalize_axis_tuple
@@ -426,10 +427,11 @@ def roll(X, shift, axis=None):
426
427
if not isinstance (X , dpt .usm_ndarray ):
427
428
raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
428
429
if axis is None :
430
+ shift = operator .index (shift )
429
431
res = dpt .empty (
430
432
X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = X .sycl_queue
431
433
)
432
- hev , _ = ti ._copy_usm_ndarray_for_reshape (
434
+ hev , _ = ti ._copy_usm_ndarray_for_roll_1d (
433
435
src = X , dst = res , shift = shift , sycl_queue = X .sycl_queue
434
436
)
435
437
hev .wait ()
@@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
438
440
broadcasted = np .broadcast (shift , axis )
439
441
if broadcasted .ndim > 1 :
440
442
raise ValueError ("'shift' and 'axis' should be scalars or 1D sequences" )
441
- shifts = {ax : 0 for ax in range (X .ndim )}
443
+ shifts = [
444
+ 0 ,
445
+ ] * X .ndim
442
446
for sh , ax in broadcasted :
443
447
shifts [ax ] += sh
444
- rolls = [((np .s_ [:], np .s_ [:]),)] * X .ndim
445
- for ax , offset in shifts .items ():
446
- offset %= X .shape [ax ] or 1
447
- if offset :
448
- # (original, result), (original, result)
449
- rolls [ax ] = (
450
- (np .s_ [:- offset ], np .s_ [offset :]),
451
- (np .s_ [- offset :], np .s_ [:offset ]),
452
- )
453
448
449
+ exec_q = X .sycl_queue
454
450
res = dpt .empty (
455
- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = X . sycl_queue
451
+ X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
456
452
)
457
- hev_list = []
458
- for indices in product (* rolls ):
459
- arr_index , res_index = zip (* indices )
460
- hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
461
- src = X [arr_index ], dst = res [res_index ], sycl_queue = X .sycl_queue
462
- )
463
- hev_list .append (hev )
464
-
465
- dpctl .SyclEvent .wait_for (hev_list )
453
+ ht_e , _ = ti ._copy_usm_ndarray_for_roll_nd (
454
+ src = X , dst = res , shifts = shifts , sycl_queue = exec_q
455
+ )
456
+ ht_e .wait ()
466
457
return res
467
458
468
459
@@ -550,7 +541,6 @@ def _concat_axis_None(arrays):
550
541
hev , _ = ti ._copy_usm_ndarray_for_reshape (
551
542
src = src_ ,
552
543
dst = res [fill_start :fill_end ],
553
- shift = 0 ,
554
544
sycl_queue = exec_q ,
555
545
)
556
546
fill_start = fill_end
0 commit comments