Skip to content

Commit ddc5977

Browse files
committed
Corrected boolean indexing cumsum
- The cumulative sum was being calculated incorrectly: the offset produced by stride simplification was unused, so the result was wrong for some cases with non-C-contiguous strides. - To fix this, a new function ``compact_iteration_space`` and a complementary function ``compact_iteration`` have been implemented.
1 parent cf4660d commit ddc5977

File tree

5 files changed

+106
-13
lines changed

5 files changed

+106
-13
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ typedef size_t (*mask_positions_strided_impl_fn_ptr_t)(
424424
size_t,
425425
const char *,
426426
int,
427-
py::ssize_t,
428427
const py::ssize_t *,
429428
char *,
430429
std::vector<sycl::event> const &);
@@ -434,7 +433,6 @@ size_t mask_positions_strided_impl(sycl::queue q,
434433
size_t n_elems,
435434
const char *mask,
436435
int nd,
437-
py::ssize_t input_offset,
438436
const py::ssize_t *shape_strides,
439437
char *cumsum,
440438
std::vector<sycl::event> const &depends = {})
@@ -444,7 +442,7 @@ size_t mask_positions_strided_impl(sycl::queue q,
444442
cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
445443
size_t wg_size = 128;
446444

447-
StridedIndexer strided_indexer{nd, input_offset, shape_strides};
445+
StridedIndexer strided_indexer{nd, 0, shape_strides};
448446
NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
449447

450448
sycl::event comp_ev =

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,58 @@ contract_iter4(vecT shape,
909909
out_strides3, disp3, out_strides4, disp4);
910910
}
911911

912+
/*
    For purposes of iterating over the elements of an array with
    `shape` and strides `strides` given as pointers,
    `compact_iteration(nd, shape_ptr, strides_ptr)` may modify both
    arrays in place and returns the new (possibly smaller) number of
    dimensions.

    Adjacent dimensions are fused whenever traversing the pair with
    the given strides is equivalent to a single strided traversal of
    their product, so iterating over the compacted shape/strides
    visits the same elements in the same order.  Unlike stride
    simplification no offset is produced, which keeps the result safe
    for order-sensitive consumers (e.g. cumulative sums).
*/
template <class ShapeTy, class StridesTy>
int compact_iteration(const int nd, ShapeTy *shape, StridesTy *strides)
{
    if (nd < 2)
        return nd;

    // Fusion logic below assumes non-negative strides; with any
    // negative stride the input is returned unchanged.
    bool contractable = true;
    for (int i = 0; i < nd; ++i) {
        if (strides[i] < 0) {
            contractable = false;
        }
    }

    int nd_ = nd;
    while (contractable) {
        bool changed = false;
        for (int i = 0; i + 1 < nd_; ++i) {
            StridesTy str = strides[i + 1];
            // Displacement when index i increments right after
            // index i+1 wraps around.
            StridesTy jump = strides[i] - (shape[i + 1] - 1) * str;

            if (jump == str) {
                // Dimensions i and i+1 traverse contiguously with
                // stride `str`: fuse them into one dimension.
                changed = true;
                shape[i] *= shape[i + 1];
                // Shift remaining entries left by one.
                // NB: bound is `j + 1 < nd_` so the read
                // `strides[j + 1]` stays in range; `j < nd_` would
                // read `strides[nd_]`, one past the end of the
                // caller's array on the first contraction.
                for (int j = i; j + 1 < nd_; ++j) {
                    strides[j] = strides[j + 1];
                }
                for (int j = i + 1; j + 1 < nd_; ++j) {
                    shape[j] = shape[j + 1];
                }
                --nd_;
                break;
            }
        }
        if (!changed)
            break;
    }

    return nd_;
}
963+
912964
} // namespace strides
913965
} // namespace tensor
914966
} // namespace dpctl

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,22 +205,19 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
205205
auto const &strides_vector = mask.get_strides_vector();
206206

207207
using shT = std::vector<py::ssize_t>;
208-
shT simplified_shape;
209-
shT simplified_strides;
210-
py::ssize_t offset(0);
208+
shT compact_shape;
209+
shT compact_strides;
211210

212211
int mask_nd = mask.get_ndim();
213212
int nd = mask_nd;
214213

215-
dpctl::tensor::py_internal::simplify_iteration_space_1(
216-
nd, shape, strides_vector, simplified_shape, simplified_strides,
217-
offset);
214+
dpctl::tensor::py_internal::compact_iteration_space(
215+
nd, shape, strides_vector, compact_shape, compact_strides);
218216

219-
if (nd == 1 && simplified_strides[0] == 1) {
217+
if (nd == 1 && compact_strides[0] == 1) {
220218
auto fn = (use_i32)
221219
? mask_positions_contig_i32_dispatch_vector[mask_typeid]
222220
: mask_positions_contig_i64_dispatch_vector[mask_typeid];
223-
224221
return fn(exec_q, mask_size, mask_data, cumsum_data, depends);
225222
}
226223

@@ -232,7 +229,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
232229

233230
using dpctl::tensor::offset_utils::device_allocate_and_pack;
234231
const auto &ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
235-
exec_q, host_task_events, simplified_shape, simplified_strides);
232+
exec_q, host_task_events, compact_shape, compact_strides);
236233
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
237234
if (shape_strides == nullptr) {
238235
sycl::event::wait(host_task_events);
@@ -253,7 +250,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
253250
dependent_events.insert(dependent_events.end(), depends.begin(),
254251
depends.end());
255252

256-
size_t total_set = strided_fn(exec_q, mask_size, mask_data, nd, offset,
253+
size_t total_set = strided_fn(exec_q, mask_size, mask_data, nd,
257254
shape_strides, cumsum_data, dependent_events);
258255

259256
sycl::event::wait(host_task_events);

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,45 @@ void simplify_iteration_space_4(
369369
}
370370
}
371371

372+
void compact_iteration_space(int &nd,
373+
const py::ssize_t *const &shape,
374+
std::vector<py::ssize_t> const &strides,
375+
// output
376+
std::vector<py::ssize_t> &compact_shape,
377+
std::vector<py::ssize_t> &compact_strides)
378+
{
379+
using dpctl::tensor::strides::compact_iteration;
380+
if (nd > 1) {
381+
// Compact iteration space to reduce dimensionality
382+
// and improve access pattern
383+
compact_shape.reserve(nd);
384+
compact_shape.insert(std::begin(compact_shape), shape, shape + nd);
385+
assert(compact_shape.size() == static_cast<size_t>(nd));
386+
387+
compact_strides.reserve(nd);
388+
compact_strides.insert(std::end(compact_strides), std::begin(strides),
389+
std::end(strides));
390+
assert(compact_strides.size() == static_cast<size_t>(nd));
391+
392+
int contracted_nd =
393+
compact_iteration(nd, compact_shape.data(), compact_strides.data());
394+
compact_shape.resize(contracted_nd);
395+
compact_strides.resize(contracted_nd);
396+
397+
nd = contracted_nd;
398+
}
399+
else if (nd == 1) {
400+
// Populate vectors
401+
compact_shape.reserve(nd);
402+
compact_shape.push_back(shape[0]);
403+
assert(compact_shape.size() == static_cast<size_t>(nd));
404+
405+
compact_strides.reserve(nd);
406+
compact_strides.push_back(strides[0]);
407+
assert(compact_strides.size() == static_cast<size_t>(nd));
408+
}
409+
}
410+
372411
py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &mi,
373412
std::vector<py::ssize_t> const &shape)
374413
{

dpctl/tensor/libtensor/source/simplify_iteration_space.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ void simplify_iteration_space_4(int &,
9090
py::ssize_t &,
9191
py::ssize_t &);
9292

93+
void compact_iteration_space(int &,
94+
const py::ssize_t *const &,
95+
std::vector<py::ssize_t> const &,
96+
// output
97+
std::vector<py::ssize_t> &,
98+
std::vector<py::ssize_t> &);
99+
93100
py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &,
94101
std::vector<py::ssize_t> const &);
95102
py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &,

0 commit comments

Comments
 (0)