diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index ca83b8350b..49f25aef6a 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -37,6 +37,7 @@ pybind11_add_module(${python_module_name} MODULE ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index d76f33af94..cb54556ed2 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -15,7 +15,8 @@ # limitations under the License. -from itertools import chain, product, repeat +import operator +from itertools import chain, repeat import numpy as np from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple @@ -426,10 +427,11 @@ def roll(X, shift, axis=None): if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") if axis is None: + shift = operator.index(shift) res = dpt.empty( X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue ) - hev, _ = ti._copy_usm_ndarray_for_reshape( + hev, _ = ti._copy_usm_ndarray_for_roll_1d( src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue ) hev.wait() @@ -438,31 +440,20 @@ def roll(X, shift, axis=None): broadcasted = np.broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError("'shift' and 'axis' should be scalars or 1D sequences") - shifts = {ax: 0 for ax in range(X.ndim)} + shifts = [ + 0, + ] * X.ndim for sh, ax in broadcasted: shifts[ax] += sh - rolls = [((np.s_[:], np.s_[:]),)] * X.ndim - for ax, offset in shifts.items(): - offset %= X.shape[ax] or 1 - if offset: - # (original, result), (original, result) - rolls[ax] = ( - (np.s_[:-offset], np.s_[offset:]), - (np.s_[-offset:], np.s_[:offset]), - ) + exec_q = X.sycl_queue res = dpt.empty( - X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue + X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q ) - hev_list = [] - for indices in product(*rolls): - arr_index, res_index = zip(*indices) - hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( - src=X[arr_index], dst=res[res_index], sycl_queue=X.sycl_queue - ) - hev_list.append(hev) - - dpctl.SyclEvent.wait_for(hev_list) + ht_e, _ = ti._copy_usm_ndarray_for_roll_nd( + src=X, dst=res, shifts=shifts, sycl_queue=exec_q + ) + ht_e.wait() return res @@ -550,7 +541,6 @@ def _concat_axis_None(arrays): hev, _ = ti._copy_usm_ndarray_for_reshape( src=src_, dst=res[fill_start:fill_end], - shift=0, sycl_queue=exec_q, ) fill_start = fill_end diff --git a/dpctl/tensor/_reshape.py b/dpctl/tensor/_reshape.py index ac4a04cac4..b363c063de 100644 --- a/dpctl/tensor/_reshape.py +++ b/dpctl/tensor/_reshape.py @@ -165,7 +165,7 @@ def reshape(X, shape, order="C", copy=None): ) if order == "C": hev, _ = _copy_usm_ndarray_for_reshape( - src=X, dst=flat_res, shift=0, sycl_queue=X.sycl_queue + src=X, dst=flat_res, sycl_queue=X.sycl_queue ) hev.wait() else: diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp index 
99c356aeb9..e5aaa34903 100644 --- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp @@ -56,9 +56,6 @@ class copy_cast_contig_kernel; template class copy_cast_from_host_kernel; -template -class copy_for_reshape_generic_kernel; - template class Caster { public: @@ -630,27 +627,24 @@ struct CopyAndCastFromHostFactory // =============== Copying for reshape ================== // +template +class copy_for_reshape_generic_kernel; + template class GenericCopyForReshapeFunctor { private: - py::ssize_t offset = 0; - py::ssize_t size = 1; - // USM array of size 2*(src_nd + dst_nd) - // [ src_shape; src_strides; dst_shape; dst_strides ] - Ty *src_p = nullptr; + const Ty *src_p = nullptr; Ty *dst_p = nullptr; SrcIndexerT src_indexer_; DstIndexerT dst_indexer_; public: - GenericCopyForReshapeFunctor(py::ssize_t shift, - py::ssize_t nelems, - char *src_ptr, + GenericCopyForReshapeFunctor(const char *src_ptr, char *dst_ptr, SrcIndexerT src_indexer, DstIndexerT dst_indexer) - : offset(shift), size(nelems), src_p(reinterpret_cast(src_ptr)), + : src_p(reinterpret_cast(src_ptr)), dst_p(reinterpret_cast(dst_ptr)), src_indexer_(src_indexer), dst_indexer_(dst_indexer) { @@ -658,40 +652,31 @@ class GenericCopyForReshapeFunctor void operator()(sycl::id<1> wiid) const { - py::ssize_t this_src_offset = src_indexer_(wiid.get(0)); - const Ty *in = src_p + this_src_offset; - - py::ssize_t shifted_wiid = - (static_cast(wiid.get(0)) + offset) % size; - shifted_wiid = (shifted_wiid >= 0) ? shifted_wiid : shifted_wiid + size; + const py::ssize_t src_offset = src_indexer_(wiid.get(0)); + const py::ssize_t dst_offset = dst_indexer_(wiid.get(0)); - py::ssize_t this_dst_offset = dst_indexer_(shifted_wiid); - - Ty *out = dst_p + this_dst_offset; - *out = *in; + dst_p[dst_offset] = src_p[src_offset]; } }; // define function type typedef sycl::event (*copy_for_reshape_fn_ptr_t)( sycl::queue, - py::ssize_t, // shift - size_t, // num_elements - int, - int, // src_nd, dst_nd + size_t, // num_elements + int, // src_nd + int, // dst_nd py::ssize_t *, // packed shapes and strides - char *, // src_data_ptr + const char *, // src_data_ptr char *, // dst_data_ptr const std::vector &); /*! * @brief Function to copy content of array while reshaping. * - * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems , + * Submits a kernel to perform a copy `dst[unravel_index(i, * dst.shape)] = src[unravel_undex(i, src.shape)]`. * * @param q The execution queue where kernel is submitted. - * @param shift The shift in flat indexing. 
* @param nelems The number of elements to copy * @param src_nd Array dimension of the source array * @param dst_nd Array dimension of the destination array @@ -709,31 +694,40 @@ typedef sycl::event (*copy_for_reshape_fn_ptr_t)( template sycl::event copy_for_reshape_generic_impl(sycl::queue q, - py::ssize_t shift, size_t nelems, int src_nd, int dst_nd, py::ssize_t *packed_shapes_and_strides, - char *src_p, + const char *src_p, char *dst_p, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); sycl::event copy_for_reshape_ev = q.submit([&](sycl::handler &cgh) { - StridedIndexer src_indexer{ - src_nd, 0, - const_cast(packed_shapes_and_strides)}; - StridedIndexer dst_indexer{ - dst_nd, 0, - const_cast(packed_shapes_and_strides + - (2 * src_nd))}; cgh.depends_on(depends); - cgh.parallel_for>( + + // packed_shapes_and_strides: + // USM array of size 2*(src_nd + dst_nd) + // [ src_shape; src_strides; dst_shape; dst_strides ] + + const py::ssize_t *src_shape_and_strides = + const_cast(packed_shapes_and_strides); + + const py::ssize_t *dst_shape_and_strides = + const_cast(packed_shapes_and_strides + + (2 * src_nd)); + + StridedIndexer src_indexer{src_nd, 0, src_shape_and_strides}; + StridedIndexer dst_indexer{dst_nd, 0, dst_shape_and_strides}; + + using KernelName = + copy_for_reshape_generic_kernel; + + cgh.parallel_for( sycl::range<1>(nelems), GenericCopyForReshapeFunctor( - shift, nelems, src_p, dst_p, src_indexer, dst_indexer)); + src_p, dst_p, src_indexer, dst_indexer)); }); return copy_for_reshape_ev; @@ -753,6 +747,387 @@ template struct CopyForReshapeGenericFactory } }; +// ================== Copying for roll ================== // + +/*! @brief Functor to cyclically roll global_id to the left */ +struct LeftRolled1DTransformer +{ + LeftRolled1DTransformer(size_t offset, size_t size) + : offset_(offset), size_(size) + { + } + + size_t operator()(size_t gid) const + { + const size_t shifted_gid = + ((gid < offset_) ? gid + size_ - offset_ : gid - offset_); + return shifted_gid; + } + +private: + size_t offset_ = 0; + size_t size_ = 1; +}; + +/*! @brief Indexer functor to compose indexer and transformer */ +template struct CompositionIndexer +{ + CompositionIndexer(IndexerT f, TransformerT t) : f_(f), t_(t) {} + + auto operator()(size_t gid) const + { + return f_(t_(gid)); + } + +private: + IndexerT f_; + TransformerT t_; +}; + +/*! 
@brief Indexer functor to find offset for nd-shifted indices lifted from + * iteration id */ +struct RolledNDIndexer +{ + RolledNDIndexer(int nd, + const py::ssize_t *shape, + const py::ssize_t *strides, + const py::ssize_t *ndshifts, + py::ssize_t starting_offset) + : nd_(nd), shape_(shape), strides_(strides), ndshifts_(ndshifts), + starting_offset_(starting_offset) + { + } + + py::ssize_t operator()(size_t gid) const + { + return compute_offset(gid); + } + +private: + int nd_ = -1; + const py::ssize_t *shape_ = nullptr; + const py::ssize_t *strides_ = nullptr; + const py::ssize_t *ndshifts_ = nullptr; + py::ssize_t starting_offset_ = 0; + + py::ssize_t compute_offset(py::ssize_t gid) const + { + using dpctl::tensor::strides::CIndexer_vector; + + CIndexer_vector _ind(nd_); + py::ssize_t relative_offset_(0); + _ind.get_left_rolled_displacement( + gid, + shape_, // shape ptr + strides_, // strides ptr + ndshifts_, // shifts ptr + relative_offset_); + return starting_offset_ + relative_offset_; + } +}; + +template +class copy_for_roll_strided_kernel; + +template +class StridedCopyForRollFunctor +{ +private: + const Ty *src_p = nullptr; + Ty *dst_p = nullptr; + SrcIndexerT src_indexer_; + DstIndexerT dst_indexer_; + +public: + StridedCopyForRollFunctor(const Ty *src_ptr, + Ty *dst_ptr, + SrcIndexerT src_indexer, + DstIndexerT dst_indexer) + : src_p(src_ptr), dst_p(dst_ptr), src_indexer_(src_indexer), + dst_indexer_(dst_indexer) + { + } + + void operator()(sycl::id<1> wiid) const + { + const size_t gid = wiid.get(0); + + const py::ssize_t src_offset = src_indexer_(gid); + const py::ssize_t dst_offset = dst_indexer_(gid); + + dst_p[dst_offset] = src_p[src_offset]; + } +}; + +// define function type +typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)( + sycl::queue, + size_t, // shift + size_t, // num_elements + int, // common_nd + const py::ssize_t *, // packed shapes and strides + const char *, // src_data_ptr + py::ssize_t, // src_offset + char *, // dst_data_ptr + py::ssize_t, // dst_offset + const std::vector &); + +/*! + * @brief Function to copy content of array with a shift. + * + * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems , + * dst.shape)] = src[unravel_undex(i, src.shape)]`. + * + * @param q The execution queue where kernel is submitted. + * @param shift The shift in flat indexing, must be non-negative. + * @param nelems The number of elements to copy + * @param nd Array dimensionality of the destination and source arrays + * @param packed_shapes_and_strides Kernel accessible USM array + * of size `3*nd` with content `[common_shape, src_strides, dst_strides]`. + * @param src_p Typeless USM pointer to the buffer of the source array + * @param src_offset Displacement of first element of src relative src_p in + * elements + * @param dst_p Typeless USM pointer to the buffer of the destination array + * @param dst_offset Displacement of first element of dst relative dst_p in + * elements + * @param depends List of events to wait for before starting computations, if + * any. + * + * @return Event to wait on to ensure that computation completes. 
+ * @ingroup CopyAndCastKernels + */ +template +sycl::event +copy_for_roll_strided_impl(sycl::queue q, + size_t shift, + size_t nelems, + int nd, + const py::ssize_t *packed_shapes_and_strides, + const char *src_p, + py::ssize_t src_offset, + char *dst_p, + py::ssize_t dst_offset, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(q); + + sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + // packed_shapes_and_strides: + // USM array of size 3 * nd + // [ common_shape; src_strides; dst_strides ] + + StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides}; + LeftRolled1DTransformer left_roll_transformer{shift, nelems}; + + using CompositeIndexerT = + CompositionIndexer; + + CompositeIndexerT rolled_src_indexer(src_indexer, + left_roll_transformer); + + UnpackedStridedIndexer dst_indexer{nd, dst_offset, + packed_shapes_and_strides, + packed_shapes_and_strides + 2 * nd}; + + using KernelName = copy_for_roll_strided_kernel; + + const Ty *src_tp = reinterpret_cast(src_p); + Ty *dst_tp = reinterpret_cast(dst_p); + + cgh.parallel_for( + sycl::range<1>(nelems), + StridedCopyForRollFunctor( + src_tp, dst_tp, rolled_src_indexer, dst_indexer)); + }); + + return copy_for_roll_ev; +} + +// define function type +typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)( + sycl::queue, + size_t, // shift + size_t, // num_elements + const char *, // src_data_ptr + py::ssize_t, // src_offset + char *, // dst_data_ptr + py::ssize_t, // dst_offset + const std::vector &); + +template class copy_for_roll_contig_kernel; + +/*! + * @brief Function to copy content of array with a shift. + * + * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems , + * dst.shape)] = src[unravel_undex(i, src.shape)]`. + * + * @param q The execution queue where kernel is submitted. + * @param shift The shift in flat indexing, must be non-negative. + * @param nelems The number of elements to copy + * @param src_p Typeless USM pointer to the buffer of the source array + * @param src_offset Displacement of the start of array src relative src_p in + * elements + * @param dst_p Typeless USM pointer to the buffer of the destination array + * @param dst_offset Displacement of the start of array dst relative dst_p in + * elements + * @param depends List of events to wait for before starting computations, if + * any. + * + * @return Event to wait on to ensure that computation completes. + * @ingroup CopyAndCastKernels + */ +template +sycl::event copy_for_roll_contig_impl(sycl::queue q, + size_t shift, + size_t nelems, + const char *src_p, + py::ssize_t src_offset, + char *dst_p, + py::ssize_t dst_offset, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(q); + + sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + NoOpIndexer src_indexer{}; + LeftRolled1DTransformer roller{shift, nelems}; + + CompositionIndexer + left_rolled_src_indexer{src_indexer, roller}; + NoOpIndexer dst_indexer{}; + + using KernelName = copy_for_roll_contig_kernel; + + const Ty *src_tp = reinterpret_cast(src_p) + src_offset; + Ty *dst_tp = reinterpret_cast(dst_p) + dst_offset; + + cgh.parallel_for( + sycl::range<1>(nelems), + StridedCopyForRollFunctor< + Ty, CompositionIndexer, + NoOpIndexer>(src_tp, dst_tp, left_rolled_src_indexer, + dst_indexer)); + }); + + return copy_for_roll_ev; +} + +/*! 
+ * @brief Factory to get function pointer of type `fnT` for given array data + * type `Ty`. + * @ingroup CopyAndCastKernels + */ +template struct CopyForRollStridedFactory +{ + fnT get() + { + fnT f = copy_for_roll_strided_impl; + return f; + } +}; + +/*! + * @brief Factory to get function pointer of type `fnT` for given array data + * type `Ty`. + * @ingroup CopyAndCastKernels + */ +template struct CopyForRollContigFactory +{ + fnT get() + { + fnT f = copy_for_roll_contig_impl; + return f; + } +}; + +template +class copy_for_roll_ndshift_strided_kernel; + +// define function type +typedef sycl::event (*copy_for_roll_ndshift_strided_fn_ptr_t)( + sycl::queue, + size_t, // num_elements + int, // common_nd + const py::ssize_t *, // packed shape, strides, shifts + const char *, // src_data_ptr + py::ssize_t, // src_offset + char *, // dst_data_ptr + py::ssize_t, // dst_offset + const std::vector &); + +template +sycl::event copy_for_roll_ndshift_strided_impl( + sycl::queue q, + size_t nelems, + int nd, + const py::ssize_t *packed_shapes_and_strides_and_shifts, + const char *src_p, + py::ssize_t src_offset, + char *dst_p, + py::ssize_t dst_offset, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(q); + + sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + // packed_shapes_and_strides_and_shifts: + // USM array of size 4 * nd + // [ common_shape; src_strides; dst_strides; shifts ] + + const py::ssize_t *shape_ptr = packed_shapes_and_strides_and_shifts; + const py::ssize_t *src_strides_ptr = + packed_shapes_and_strides_and_shifts + nd; + const py::ssize_t *dst_strides_ptr = + packed_shapes_and_strides_and_shifts + 2 * nd; + const py::ssize_t *shifts_ptr = + packed_shapes_and_strides_and_shifts + 3 * nd; + + RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr, shifts_ptr, + src_offset}; + + UnpackedStridedIndexer dst_indexer{nd, dst_offset, shape_ptr, + dst_strides_ptr}; + + using KernelName = copy_for_roll_strided_kernel; + + const Ty *src_tp = reinterpret_cast(src_p); + Ty *dst_tp = reinterpret_cast(dst_p); + + cgh.parallel_for( + sycl::range<1>(nelems), + StridedCopyForRollFunctor( + src_tp, dst_tp, src_indexer, dst_indexer)); + }); + + return copy_for_roll_ev; +} + +/*! + * @brief Factory to get function pointer of type `fnT` for given array data + * type `Ty`. 
+ * @ingroup CopyAndCastKernels + */ +template struct CopyForRollNDShiftFactory +{ + fnT get() + { + fnT f = copy_for_roll_ndshift_strided_impl; + return f; + } +}; + } // namespace copy_and_cast } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp index 99e56b850d..19bcf9d0a8 100644 --- a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp @@ -144,12 +144,12 @@ struct StridedIndexer { } - size_t operator()(py::ssize_t gid) const + py::ssize_t operator()(py::ssize_t gid) const { return compute_offset(gid); } - size_t operator()(size_t gid) const + py::ssize_t operator()(size_t gid) const { return compute_offset(static_cast(gid)); } @@ -159,7 +159,7 @@ struct StridedIndexer py::ssize_t starting_offset; py::ssize_t const *shape_strides; - size_t compute_offset(py::ssize_t gid) const + py::ssize_t compute_offset(py::ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; @@ -185,12 +185,12 @@ struct UnpackedStridedIndexer { } - size_t operator()(py::ssize_t gid) const + py::ssize_t operator()(py::ssize_t gid) const { return compute_offset(gid); } - size_t operator()(size_t gid) const + py::ssize_t operator()(size_t gid) const { return compute_offset(static_cast(gid)); } @@ -201,7 +201,7 @@ struct UnpackedStridedIndexer py::ssize_t const *shape; py::ssize_t const *strides; - size_t compute_offset(py::ssize_t gid) const + py::ssize_t compute_offset(py::ssize_t gid) const { using dpctl::tensor::strides::CIndexer_vector; @@ -223,11 +223,10 @@ struct Strided1DIndexer { } - size_t operator()(size_t gid) const + py::ssize_t operator()(size_t gid) const { // ensure 0 <= gid < size - return static_cast(offset + - std::min(gid, size - 1) * step); + return offset + std::min(gid, size - 1) * step; } private: @@ -245,9 +244,9 @@ struct Strided1DCyclicIndexer { } - size_t operator()(size_t gid) const + py::ssize_t operator()(size_t gid) const { - return static_cast(offset + (gid % size) * step); + return offset + (gid % size) * step; } private: diff --git a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp index bd174e3f90..7cca7c7b5d 100644 --- a/dpctl/tensor/libtensor/include/utils/strided_iters.hpp +++ b/dpctl/tensor/libtensor/include/utils/strided_iters.hpp @@ -238,6 +238,30 @@ template class CIndexer_vector } return; } + + template + void get_left_rolled_displacement(indT i, + ShapeTy shape, + StridesTy stride, + StridesTy shifts, + indT &disp) const + { + indT i_ = i; + indT d = 0; + for (int dim = nd; --dim > 0;) { + const indT si = shape[dim]; + const indT q = i_ / si; + const indT r = (i_ - q * si); + // assumes si > shifts[dim] >= 0 + const indT shifted_r = + (r < shifts[dim] ? r + si - shifts[dim] : r - shifts[dim]); + d += shifted_r * stride[dim]; + i_ = q; + } + const indT shifted_r = + (i_ < shifts[0] ? 
i_ + shape[0] - shifts[0] : i_ - shifts[0]); + disp = d + shifted_r * stride[0]; + } }; /* diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp index 8edf982b16..7114d87c47 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp @@ -60,7 +60,6 @@ static copy_for_reshape_fn_ptr_t std::pair copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, - py::ssize_t shift, sycl::queue exec_q, const std::vector &depends) { @@ -109,7 +108,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src, if (src_nelems == 1) { // handle special case of 1-element array int src_elemsize = src.get_elemsize(); - char *src_data = src.get_data(); + const char *src_data = src.get_data(); char *dst_data = dst.get_data(); sycl::event copy_ev = exec_q.copy(src_data, dst_data, src_elemsize); @@ -146,7 +145,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src, } sycl::event copy_shape_ev = std::get<2>(ptr_size_event_tuple); - char *src_data = src.get_data(); + const char *src_data = src.get_data(); char *dst_data = dst.get_data(); std::vector all_deps(depends.size() + 1); @@ -154,7 +153,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src, all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); sycl::event copy_for_reshape_event = - fn(exec_q, shift, src_nelems, src_nd, dst_nd, shape_strides, src_data, + fn(exec_q, src_nelems, src_nd, dst_nd, shape_strides, src_data, dst_data, all_deps); auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.hpp b/dpctl/tensor/libtensor/source/copy_for_reshape.hpp index 09caddf824..32d41fc159 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.hpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.hpp @@ -40,7 +40,6 @@ namespace py_internal extern std::pair copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, - py::ssize_t shift, sycl::queue exec_q, const std::vector &depends = {}); diff --git a/dpctl/tensor/libtensor/source/copy_for_roll.cpp b/dpctl/tensor/libtensor/source/copy_for_roll.cpp new file mode 100644 index 0000000000..eee129932f --- /dev/null +++ b/dpctl/tensor/libtensor/source/copy_for_roll.cpp @@ -0,0 +1,419 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include "copy_for_roll.hpp" +#include "dpctl4pybind11.hpp" +#include "kernels/copy_and_cast.hpp" +#include "utils/type_dispatch.hpp" +#include + +#include "simplify_iteration_space.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::kernels::copy_and_cast::copy_for_roll_contig_fn_ptr_t; +using dpctl::tensor::kernels::copy_and_cast:: + copy_for_roll_ndshift_strided_fn_ptr_t; +using dpctl::tensor::kernels::copy_and_cast::copy_for_roll_strided_fn_ptr_t; +using dpctl::utils::keep_args_alive; + +// define static vector +static copy_for_roll_strided_fn_ptr_t + copy_for_roll_strided_dispatch_vector[td_ns::num_types]; + +static copy_for_roll_contig_fn_ptr_t + copy_for_roll_contig_dispatch_vector[td_ns::num_types]; + +static copy_for_roll_ndshift_strided_fn_ptr_t + copy_for_roll_ndshift_dispatch_vector[td_ns::num_types]; + +/* + * Copies src into dst (same data type) of different shapes by using flat + * iterations. + * + * Equivalent to the following loop: + * + * for i for range(src.size): + * dst[np.multi_index(i, dst.shape)] = src[np.multi_index(i, src.shape)] + */ +std::pair +copy_usm_ndarray_for_roll_1d(dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + py::ssize_t shift, + sycl::queue exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + + // Must have the same number of dimensions + if (src_nd != dst_nd) { + throw py::value_error( + "copy_usm_ndarray_for_roll_1d requires src and dst to " + "have the same number of dimensions."); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + if (!std::equal(src_shape_ptr, src_shape_ptr + src_nd, dst_shape_ptr)) { + throw py::value_error( + "copy_usm_ndarray_for_roll_1d requires src and dst to " + "have the same shape."); + } + + py::ssize_t src_nelems = src.get_size(); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + // typenames must be the same + if (src_typenum != dst_typenum) { + throw py::value_error( + "copy_usm_ndarray_for_roll_1d requires src and dst to " + "have the same type."); + } + + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + py::ssize_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < src_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + // check same contexts + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + if (src_nelems == 1) { + // handle special case of 1-element array + int src_elemsize = src.get_elemsize(); + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + sycl::event copy_ev = + exec_q.copy(src_data, dst_data, src_elemsize); + return std::make_pair(keep_args_alive(exec_q, {src, dst}, {copy_ev}), + copy_ev); + } + + auto array_types = td_ns::usm_ndarray_types(); + int 
type_id = array_types.typenum_to_lookup_id(src_typenum); + + const bool is_src_c_contig = src.is_c_contiguous(); + const bool is_src_f_contig = src.is_f_contiguous(); + + const bool is_dst_c_contig = dst.is_c_contiguous(); + const bool is_dst_f_contig = dst.is_f_contiguous(); + + const bool both_c_contig = is_src_c_contig && is_dst_c_contig; + const bool both_f_contig = is_src_f_contig && is_dst_f_contig; + + // normalize shift parameter to be 0 <= offset < src_nelems + size_t offset = + (shift > 0) ? (shift % src_nelems) : src_nelems + (shift % src_nelems); + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + if (both_c_contig || both_f_contig) { + auto fn = copy_for_roll_contig_dispatch_vector[type_id]; + + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + + sycl::event copy_for_roll_ev = + fn(exec_q, offset, src_nelems, src_data, zero_offset, dst_data, + zero_offset, depends); + + sycl::event ht_ev = + keep_args_alive(exec_q, {src, dst}, {copy_for_roll_ev}); + + return std::make_pair(ht_ev, copy_for_roll_ev); + } + } + + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t dst_offset(0); + + int nd = src_nd; + const py::ssize_t *shape = src_shape_ptr; + + // nd, simplified_* and *_offset are modified by reference + dpctl::tensor::py_internal::simplify_iteration_space( + nd, shape, src_strides, dst_strides, + // output + simplified_shape, simplified_src_strides, simplified_dst_strides, + src_offset, dst_offset); + + if (nd == 1 && simplified_src_strides[0] == 1 && + simplified_dst_strides[0] == 1) { + auto fn = copy_for_roll_contig_dispatch_vector[type_id]; + + if (fn != nullptr) { + + sycl::event copy_for_roll_ev = + fn(exec_q, offset, src_nelems, src_data, src_offset, dst_data, + dst_offset, depends); + + sycl::event ht_ev = + keep_args_alive(exec_q, {src, dst}, {copy_for_roll_ev}); + + return std::make_pair(ht_ev, copy_for_roll_ev); + } + } + + auto fn = copy_for_roll_strided_dispatch_vector[type_id]; + + std::vector host_task_events; + host_task_events.reserve(2); + + // shape_strides = [src_shape, src_strides, dst_strides] + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, simplified_shape, simplified_src_strides, + simplified_dst_strides); + + py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple); + if (shape_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shape_ev = std::get<2>(ptr_size_event_tuple); + + std::vector all_deps(depends.size() + 1); + all_deps.push_back(copy_shape_ev); + all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); + + sycl::event copy_for_roll_event = + fn(exec_q, offset, src_nelems, src_nd, shape_strides, src_data, + src_offset, dst_data, dst_offset, all_deps); + + auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(copy_for_roll_event); + auto ctx = exec_q.get_context(); + cgh.host_task( + [shape_strides, ctx]() { sycl::free(shape_strides, ctx); }); + }); + + host_task_events.push_back(temporaries_cleanup_ev); + + return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events), + copy_for_roll_event); +} + +std::pair 
+copy_usm_ndarray_for_roll_nd(dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + const std::vector &shifts, + sycl::queue exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + + // Must have the same number of dimensions + if (src_nd != dst_nd) { + throw py::value_error( + "copy_usm_ndarray_for_roll_nd requires src and dst to " + "have the same number of dimensions."); + } + + if (static_cast(src_nd) != shifts.size()) { + throw py::value_error( + "copy_usm_ndarray_for_roll_nd requires shifts to " + "contain an integral shift for each array dimension."); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + if (!std::equal(src_shape_ptr, src_shape_ptr + src_nd, dst_shape_ptr)) { + throw py::value_error( + "copy_usm_ndarray_for_roll_nd requires src and dst to " + "have the same shape."); + } + + py::ssize_t src_nelems = src.get_size(); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + // typenames must be the same + if (src_typenum != dst_typenum) { + throw py::value_error( + "copy_usm_ndarray_for_reshape requires src and dst to " + "have the same type."); + } + + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + // destination must be ample enough to accommodate all elements + { + auto dst_offsets = dst.get_minmax_offsets(); + py::ssize_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < src_nelems) { + throw py::value_error( + "Destination array can not accommodate all the " + "elements of source array."); + } + } + + // check for compatible queues + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + if (src_nelems == 1) { + // handle special case of 1-element array + int src_elemsize = src.get_elemsize(); + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + sycl::event copy_ev = + exec_q.copy(src_data, dst_data, src_elemsize); + return std::make_pair(keep_args_alive(exec_q, {src, dst}, {copy_ev}), + copy_ev); + } + + auto array_types = td_ns::usm_ndarray_types(); + int type_id = array_types.typenum_to_lookup_id(src_typenum); + + std::vector normalized_shifts{}; + normalized_shifts.reserve(src_nd); + + for (int i = 0; i < src_nd; ++i) { + // normalize shift parameter to be 0 <= offset < dim + py::ssize_t dim = src_shape_ptr[i]; + size_t offset = + (shifts[i] > 0) ? 
(shifts[i] % dim) : dim + (shifts[i] % dim); + + normalized_shifts.push_back(offset); + } + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + auto const &common_shape = src.get_shape_vector(); + + constexpr py::ssize_t src_offset = 0; + constexpr py::ssize_t dst_offset = 0; + + auto fn = copy_for_roll_ndshift_dispatch_vector[type_id]; + + std::vector host_task_events; + host_task_events.reserve(2); + + // shape_strides = [src_shape, src_strides, dst_strides] + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, common_shape, src_strides, dst_strides, + normalized_shifts); + + py::ssize_t *shape_strides_shifts = std::get<0>(ptr_size_event_tuple); + if (shape_strides_shifts == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shape_ev = std::get<2>(ptr_size_event_tuple); + + std::vector all_deps(depends.size() + 1); + all_deps.push_back(copy_shape_ev); + all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); + + sycl::event copy_for_roll_event = + fn(exec_q, src_nelems, src_nd, shape_strides_shifts, src_data, + src_offset, dst_data, dst_offset, all_deps); + + auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(copy_for_roll_event); + auto ctx = exec_q.get_context(); + cgh.host_task([shape_strides_shifts, ctx]() { + sycl::free(shape_strides_shifts, ctx); + }); + }); + + host_task_events.push_back(temporaries_cleanup_ev); + + return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events), + copy_for_roll_event); +} + +void init_copy_for_roll_dispatch_vectors(void) +{ + using namespace td_ns; + using dpctl::tensor::kernels::copy_and_cast::CopyForRollStridedFactory; + + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(copy_for_roll_strided_dispatch_vector); + + using dpctl::tensor::kernels::copy_and_cast::CopyForRollContigFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(copy_for_roll_contig_dispatch_vector); + + using dpctl::tensor::kernels::copy_and_cast::CopyForRollNDShiftFactory; + DispatchVectorBuilder + dvb3; + dvb3.populate_dispatch_vector(copy_for_roll_ndshift_dispatch_vector); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/copy_for_roll.hpp b/dpctl/tensor/libtensor/source/copy_for_roll.hpp new file mode 100644 index 0000000000..0c00710e11 --- /dev/null +++ b/dpctl/tensor/libtensor/source/copy_for_roll.hpp @@ -0,0 +1,58 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include + +#include "dpctl4pybind11.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern std::pair +copy_usm_ndarray_for_roll_1d(dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + py::ssize_t shift, + sycl::queue exec_q, + const std::vector &depends = {}); + +extern std::pair +copy_usm_ndarray_for_roll_nd(dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + const std::vector &shifts, + sycl::queue exec_q, + const std::vector &depends = {}); + +extern void init_copy_for_roll_dispatch_vectors(); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 662ca464c4..691a56b11f 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -37,6 +37,7 @@ #include "boolean_reductions.hpp" #include "copy_and_cast_usm_to_usm.hpp" #include "copy_for_reshape.hpp" +#include "copy_for_roll.hpp" #include "copy_numpy_ndarray_into_usm_ndarray.hpp" #include "device_support_queries.hpp" #include "elementwise_functions.hpp" @@ -68,6 +69,11 @@ using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray; using dpctl::tensor::py_internal::copy_usm_ndarray_for_reshape; +/* =========================== Copy for roll ============================= */ + +using dpctl::tensor::py_internal::copy_usm_ndarray_for_roll_1d; +using dpctl::tensor::py_internal::copy_usm_ndarray_for_roll_nd; + /* ============= Copy from numpy.ndarray to usm_ndarray ==================== */ using dpctl::tensor::py_internal::copy_numpy_ndarray_into_usm_ndarray; @@ -120,6 +126,7 @@ void init_dispatch_vectors(void) using namespace dpctl::tensor::py_internal; init_copy_for_reshape_dispatch_vectors(); + init_copy_for_roll_dispatch_vectors(); init_linear_sequences_dispatch_vectors(); init_full_ctor_dispatch_vectors(); init_eye_ctor_dispatch_vectors(); @@ -221,11 +228,27 @@ PYBIND11_MODULE(_tensor_impl, m) m.def("_copy_usm_ndarray_for_reshape", ©_usm_ndarray_for_reshape, "Copies from usm_ndarray `src` into usm_ndarray `dst` with the same " "number of elements using underlying 'C'-contiguous order for flat " + "traversal. " + "Returns a tuple of events: (ht_event, comp_event)", + py::arg("src"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + m.def("_copy_usm_ndarray_for_roll_1d", ©_usm_ndarray_for_roll_1d, + "Copies from usm_ndarray `src` into usm_ndarray `dst` with the same " + "shapes using underlying 'C'-contiguous order for flat " "traversal with shift. " "Returns a tuple of events: (ht_event, comp_event)", py::arg("src"), py::arg("dst"), py::arg("shift"), py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_copy_usm_ndarray_for_roll_nd", ©_usm_ndarray_for_roll_nd, + "Copies from usm_ndarray `src` into usm_ndarray `dst` with the same " + "shapes using underlying 'C'-contiguous order for " + "traversal with shifts along each axis. 
" + "Returns a tuple of events: (ht_event, comp_event)", + py::arg("src"), py::arg("dst"), py::arg("shifts"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_linspace_step", &usm_ndarray_linear_sequence_step, "Fills input 1D contiguous usm_ndarray `dst` with linear sequence " "specified by " diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 1cee5e6c8f..6152a15aae 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -648,6 +648,19 @@ def test_roll_2d(data): assert_array_equal(Ynp, dpt.asnumpy(Y)) +def test_roll_validation(): + get_queue_or_skip() + + X = dict() + with pytest.raises(TypeError): + dpt.roll(X) + + X = dpt.empty((1, 2, 3)) + shift = ((2, 3, 1), (1, 2, 3)) + with pytest.raises(ValueError): + dpt.roll(X, shift=shift, axis=(0, 1, 2)) + + def test_concat_incorrect_type(): Xnp = np.ones((2, 2)) pytest.raises(TypeError, dpt.concat)