Skip to content

Add roll kernels #1380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 8, 2023
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pybind11_add_module(${python_module_name} MODULE
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
Expand Down
36 changes: 13 additions & 23 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.


from itertools import chain, product, repeat
import operator
from itertools import chain, repeat

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
Expand Down Expand Up @@ -426,10 +427,11 @@ def roll(X, shift, axis=None):
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if axis is None:
shift = operator.index(shift)
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
)
hev, _ = ti._copy_usm_ndarray_for_reshape(
hev, _ = ti._copy_usm_ndarray_for_roll_1d(
src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue
)
hev.wait()
Expand All @@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
broadcasted = np.broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(X.ndim)}
shifts = [
0,
] * X.ndim
for sh, ax in broadcasted:
shifts[ax] += sh
rolls = [((np.s_[:], np.s_[:]),)] * X.ndim
for ax, offset in shifts.items():
offset %= X.shape[ax] or 1
if offset:
# (original, result), (original, result)
rolls[ax] = (
(np.s_[:-offset], np.s_[offset:]),
(np.s_[-offset:], np.s_[:offset]),
)

exec_q = X.sycl_queue
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
)
hev_list = []
for indices in product(*rolls):
arr_index, res_index = zip(*indices)
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=X[arr_index], dst=res[res_index], sycl_queue=X.sycl_queue
)
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
ht_e, _ = ti._copy_usm_ndarray_for_roll_nd(
src=X, dst=res, shifts=shifts, sycl_queue=exec_q
)
ht_e.wait()
return res


Expand Down Expand Up @@ -550,7 +541,6 @@ def _concat_axis_None(arrays):
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=src_,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
)
fill_start = fill_end
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def reshape(X, shape, order="C", copy=None):
)
if order == "C":
hev, _ = _copy_usm_ndarray_for_reshape(
src=X, dst=flat_res, shift=0, sycl_queue=X.sycl_queue
src=X, dst=flat_res, sycl_queue=X.sycl_queue
)
hev.wait()
else:
Expand Down
Loading