Commit f49889c

Fixes boolean indexing for strided masks (#1370)
* Corrected boolean indexing cumsum
  - The cumulative sum was being calculated incorrectly: the offset produced by stride simplification was unused, so the result was wrong for some masks with non-C-contiguous strides.
  - To fix this, a new function ``compact_iteration_space`` and a complementary function ``compact_iteration`` have been implemented.
* Added a test for nonzero where dimension compacting occurs
* Added tests for the corrected behavior of boolean indexing and nonzero
* Removed dead branch in py_mask_positions
  Compacting strides can reduce the dimensionality of the array, but it cannot turn an input that is not already C-contiguous into a C-contiguous one. Hence the branch checking whether the input became C-contiguous after compacting is effectively dead.
* Added docstring for compact_iteration

---------

Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 526c46c commit f49889c
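For context (not part of the commit), the failure mode is easiest to see from Python: indexing with a reversed, and therefore non-C-contiguous, boolean mask should select the same elements NumPy does. A minimal sketch along the lines of the new test, assuming a working dpctl installation:

    import numpy as np
    import dpctl.tensor as dpt

    x = dpt.arange(15, dtype="i4")
    sel_np = np.zeros(15, dtype="?")
    sel_np[::2] = True
    sel = dpt.asarray(sel_np)

    # sel[::-1] has a negative stride, so the strided cumulative-sum kernel
    # (rather than the contiguous one) is used to build the selection indices
    res = x[sel[::-1]]
    expected = dpt.asnumpy(x)[sel_np[::-1]]
    assert (dpt.asnumpy(res) == expected).all()

Before this commit the strided kernel could produce a wrong cumulative sum for such masks; with the fix the assertion holds.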

6 files changed, +147 -19 lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 1 addition & 3 deletions
@@ -424,7 +424,6 @@ typedef size_t (*mask_positions_strided_impl_fn_ptr_t)(
     size_t,
     const char *,
     int,
-    py::ssize_t,
     const py::ssize_t *,
     char *,
     std::vector<sycl::event> const &);
@@ -434,7 +433,6 @@ size_t mask_positions_strided_impl(sycl::queue q,
     size_t n_elems,
     const char *mask,
     int nd,
-    py::ssize_t input_offset,
     const py::ssize_t *shape_strides,
     char *cumsum,
     std::vector<sycl::event> const &depends = {})
@@ -444,7 +442,7 @@ size_t mask_positions_strided_impl(sycl::queue q,
     cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
     size_t wg_size = 128;

-    StridedIndexer strided_indexer{nd, input_offset, shape_strides};
+    StridedIndexer strided_indexer{nd, 0, shape_strides};
     NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};

     sycl::event comp_ev =

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 49 additions & 0 deletions
@@ -909,6 +909,55 @@ contract_iter4(vecT shape,
                    out_strides3, disp3, out_strides4, disp4);
 }

+/*
+   For purposes of iterating over elements of an array with `shape` and
+   strides `strides` given as pointers `compact_iteration(nd, shape, strides)`
+   may modify memory and returns the new length of the array.
+
+   The new shape and new strides `(new_shape, new_strides)` are such that
+   iterating over them will traverse the same elements in the same order,
+   possibly with reduced dimensionality.
+ */
+template <class ShapeTy, class StridesTy>
+int compact_iteration(const int nd, ShapeTy *shape, StridesTy *strides)
+{
+    if (nd < 2)
+        return nd;
+
+    bool contractable = true;
+    for (int i = 0; i < nd; ++i) {
+        if (strides[i] < 0) {
+            contractable = false;
+        }
+    }
+
+    int nd_ = nd;
+    while (contractable) {
+        bool changed = false;
+        for (int i = 0; i + 1 < nd_; ++i) {
+            StridesTy str = strides[i + 1];
+            StridesTy jump = strides[i] - (shape[i + 1] - 1) * str;
+
+            if (jump == str) {
+                changed = true;
+                shape[i] *= shape[i + 1];
+                for (int j = i; j < nd_; ++j) {
+                    strides[j] = strides[j + 1];
+                }
+                for (int j = i + 1; j + 1 < nd_; ++j) {
+                    shape[j] = shape[j + 1];
+                }
+                --nd_;
+                break;
+            }
+        }
+        if (!changed)
+            break;
+    }
+
+    return nd_;
+}
+
 } // namespace strides
 } // namespace tensor
 } // namespace dpctl
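To illustrate what compact_iteration achieves (a NumPy sketch of the same idea, not code from the commit): whenever adjacent dimensions satisfy strides[i] == shape[i + 1] * strides[i + 1], they can be merged without changing the order in which elements are visited, so a 3-d iteration space over every other element of a buffer collapses to a 1-d one with stride 2.

    import numpy as np

    base = np.arange(120, dtype="i4")
    # a (3, 4, 5) view with element strides (40, 10, 2); as_strided takes
    # byte strides, hence (160, 40, 8) for 4-byte int32 elements
    v3d = np.lib.stride_tricks.as_strided(base, shape=(3, 4, 5), strides=(160, 40, 8))
    # strides[0] == 4 * strides[1] and strides[1] == 5 * strides[2], so the
    # whole iteration space compacts to shape (60,) with stride 2, i.e. base[::2]
    v1d = base[::2]
    assert (v3d.ravel() == v1d).all()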

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 6 additions & 16 deletions
@@ -205,24 +205,14 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
     auto const &strides_vector = mask.get_strides_vector();

     using shT = std::vector<py::ssize_t>;
-    shT simplified_shape;
-    shT simplified_strides;
-    py::ssize_t offset(0);
+    shT compact_shape;
+    shT compact_strides;

     int mask_nd = mask.get_ndim();
     int nd = mask_nd;

-    dpctl::tensor::py_internal::simplify_iteration_space_1(
-        nd, shape, strides_vector, simplified_shape, simplified_strides,
-        offset);
-
-    if (nd == 1 && simplified_strides[0] == 1) {
-        auto fn = (use_i32)
-                      ? mask_positions_contig_i32_dispatch_vector[mask_typeid]
-                      : mask_positions_contig_i64_dispatch_vector[mask_typeid];
-
-        return fn(exec_q, mask_size, mask_data, cumsum_data, depends);
-    }
+    dpctl::tensor::py_internal::compact_iteration_space(
+        nd, shape, strides_vector, compact_shape, compact_strides);

     // Strided implementation
     auto strided_fn =
@@ -232,7 +222,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,

     using dpctl::tensor::offset_utils::device_allocate_and_pack;
     const auto &ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
-        exec_q, host_task_events, simplified_shape, simplified_strides);
+        exec_q, host_task_events, compact_shape, compact_strides);
     py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
     if (shape_strides == nullptr) {
         sycl::event::wait(host_task_events);
@@ -253,7 +243,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
     dependent_events.insert(dependent_events.end(), depends.begin(),
                             depends.end());

-    size_t total_set = strided_fn(exec_q, mask_size, mask_data, nd, offset,
+    size_t total_set = strided_fn(exec_q, mask_size, mask_data, nd,
                                   shape_strides, cumsum_data, dependent_events);

     sycl::event::wait(host_task_events);

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 39 additions & 0 deletions
@@ -369,6 +369,45 @@ void simplify_iteration_space_4(
     }
 }

+void compact_iteration_space(int &nd,
+                             const py::ssize_t *const &shape,
+                             std::vector<py::ssize_t> const &strides,
+                             // output
+                             std::vector<py::ssize_t> &compact_shape,
+                             std::vector<py::ssize_t> &compact_strides)
+{
+    using dpctl::tensor::strides::compact_iteration;
+    if (nd > 1) {
+        // Compact iteration space to reduce dimensionality
+        // and improve access pattern
+        compact_shape.reserve(nd);
+        compact_shape.insert(std::begin(compact_shape), shape, shape + nd);
+        assert(compact_shape.size() == static_cast<size_t>(nd));
+
+        compact_strides.reserve(nd);
+        compact_strides.insert(std::end(compact_strides), std::begin(strides),
+                               std::end(strides));
+        assert(compact_strides.size() == static_cast<size_t>(nd));
+
+        int contracted_nd =
+            compact_iteration(nd, compact_shape.data(), compact_strides.data());
+        compact_shape.resize(contracted_nd);
+        compact_strides.resize(contracted_nd);
+
+        nd = contracted_nd;
+    }
+    else if (nd == 1) {
+        // Populate vectors
+        compact_shape.reserve(nd);
+        compact_shape.push_back(shape[0]);
+        assert(compact_shape.size() == static_cast<size_t>(nd));
+
+        compact_strides.reserve(nd);
+        compact_strides.push_back(strides[0]);
+        assert(compact_strides.size() == static_cast<size_t>(nd));
+    }
+}
+
 py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &mi,
                                  std::vector<py::ssize_t> const &shape)
 {

dpctl/tensor/libtensor/source/simplify_iteration_space.hpp

Lines changed: 7 additions & 0 deletions
@@ -90,6 +90,13 @@ void simplify_iteration_space_4(int &,
                                 py::ssize_t &,
                                 py::ssize_t &);

+void compact_iteration_space(int &,
+                             const py::ssize_t *const &,
+                             std::vector<py::ssize_t> const &,
+                             // output
+                             std::vector<py::ssize_t> &,
+                             std::vector<py::ssize_t> &);
+
 py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &,
                                  std::vector<py::ssize_t> const &);
 py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &,

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 45 additions & 0 deletions
@@ -1044,6 +1044,19 @@ def test_extract_all_1d():
     res2 = dpt.extract(sel, x)
     assert (dpt.asnumpy(res2) == expected_res).all()

+    # test strided case
+    x = dpt.arange(15, dtype="i4")
+    sel_np = np.zeros(15, dtype="?")
+    np.put(sel_np, np.random.choice(sel_np.size, size=7), True)
+    sel = dpt.asarray(sel_np)
+
+    res = x[sel[::-1]]
+    expected_res = dpt.asnumpy(x)[sel_np[::-1]]
+    assert (dpt.asnumpy(res) == expected_res).all()
+
+    res2 = dpt.extract(sel[::-1], x)
+    assert (dpt.asnumpy(res2) == expected_res).all()
+

 def test_extract_all_2d():
     get_queue_or_skip()
@@ -1287,6 +1300,38 @@ def test_nonzero():
     assert (dpt.asnumpy(i) == np.array([3, 4, 5, 6])).all()


+def test_nonzero_f_contig():
+    "See gh-1370"
+    get_queue_or_skip
+
+    mask = dpt.zeros((5, 5), dtype="?", order="F")
+    mask[2, 3] = True
+
+    expected_res = (2, 3)
+    res = dpt.nonzero(mask)
+
+    assert expected_res == res
+    assert mask[res]
+
+
+def test_nonzero_compacting():
+    """See gh-1370.
+    Test with input where dimensionality
+    of iteration space is compacted from 3d to 2d
+    """
+    get_queue_or_skip
+
+    mask = dpt.zeros((5, 5, 5), dtype="?", order="F")
+    mask[3, 2, 1] = True
+    mask_view = mask[..., :3]
+
+    expected_res = (3, 2, 1)
+    res = dpt.nonzero(mask_view)
+
+    assert expected_res == res
+    assert mask_view[res]
+
+
 def test_assign_scalar():
     get_queue_or_skip()
     x = dpt.arange(-5, 5, dtype="i8")
