Skip to content

Commit a6cb5db

Browse files
Merge pull request #1380 from IntelPython/dedicated-roll-kernel
Add roll kernels
2 parents a9064ee + ac331bb commit a6cb5db

12 files changed

+981
-81
lines changed

dpctl/tensor/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pybind11_add_module(${python_module_name} MODULE
3737
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
3939
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
40+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp

dpctl/tensor/_manipulation_functions.py

+13-23
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# limitations under the License.
1616

1717

18-
from itertools import chain, product, repeat
18+
import operator
19+
from itertools import chain, repeat
1920

2021
import numpy as np
2122
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
@@ -426,10 +427,11 @@ def roll(X, shift, axis=None):
426427
if not isinstance(X, dpt.usm_ndarray):
427428
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
428429
if axis is None:
430+
shift = operator.index(shift)
429431
res = dpt.empty(
430432
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
431433
)
432-
hev, _ = ti._copy_usm_ndarray_for_reshape(
434+
hev, _ = ti._copy_usm_ndarray_for_roll_1d(
433435
src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue
434436
)
435437
hev.wait()
@@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
438440
broadcasted = np.broadcast(shift, axis)
439441
if broadcasted.ndim > 1:
440442
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
441-
shifts = {ax: 0 for ax in range(X.ndim)}
443+
shifts = [
444+
0,
445+
] * X.ndim
442446
for sh, ax in broadcasted:
443447
shifts[ax] += sh
444-
rolls = [((np.s_[:], np.s_[:]),)] * X.ndim
445-
for ax, offset in shifts.items():
446-
offset %= X.shape[ax] or 1
447-
if offset:
448-
# (original, result), (original, result)
449-
rolls[ax] = (
450-
(np.s_[:-offset], np.s_[offset:]),
451-
(np.s_[-offset:], np.s_[:offset]),
452-
)
453448

449+
exec_q = X.sycl_queue
454450
res = dpt.empty(
455-
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
451+
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
456452
)
457-
hev_list = []
458-
for indices in product(*rolls):
459-
arr_index, res_index = zip(*indices)
460-
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
461-
src=X[arr_index], dst=res[res_index], sycl_queue=X.sycl_queue
462-
)
463-
hev_list.append(hev)
464-
465-
dpctl.SyclEvent.wait_for(hev_list)
453+
ht_e, _ = ti._copy_usm_ndarray_for_roll_nd(
454+
src=X, dst=res, shifts=shifts, sycl_queue=exec_q
455+
)
456+
ht_e.wait()
466457
return res
467458

468459

@@ -550,7 +541,6 @@ def _concat_axis_None(arrays):
550541
hev, _ = ti._copy_usm_ndarray_for_reshape(
551542
src=src_,
552543
dst=res[fill_start:fill_end],
553-
shift=0,
554544
sycl_queue=exec_q,
555545
)
556546
fill_start = fill_end

dpctl/tensor/_reshape.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def reshape(X, shape, order="C", copy=None):
165165
)
166166
if order == "C":
167167
hev, _ = _copy_usm_ndarray_for_reshape(
168-
src=X, dst=flat_res, shift=0, sycl_queue=X.sycl_queue
168+
src=X, dst=flat_res, sycl_queue=X.sycl_queue
169169
)
170170
hev.wait()
171171
else:

0 commit comments

Comments
 (0)