Skip to content

Commit a9064ee

Browse files
Merge pull request #1392 from IntelPython/fix-gh-1391
Fix for incorrect result in reduction over axis=0
2 parents a91fc55 + 8d8ef0b commit a9064ee

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

dpctl/tensor/libtensor/source/sum_reductions.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
218 218
return std::make_pair(keep_args_event, sum_over_axis_contig_ev);
219 219
}
220 220
}
221-
else if (is_src_f_contig & is_dst_c_contig) {
221+
else if (is_src_f_contig &&
222+
((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous()))
223+
{
222 224
auto fn = sum_over_axis0_contig_atomic_dispatch_table[src_typeid]
223 225
[dst_typeid];
224 226
if (fn != nullptr) {

dpctl/tests/test_tensor_sum.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,18 @@ def test_largish_reduction(arg_dtype, n):
172 172

173 173
assert dpt.all(dpt.equal(y1, y2))
174 174
assert dpt.all(dpt.equal(y1, n * m))
175+
176+
177+
def test_axis0_bug():
178+
"gh-1391"
179+
get_queue_or_skip()
180+
181+
sh = (1, 2, 3)
182+
a = dpt.arange(sh[0] * sh[1] * sh[2], dtype="i4")
183+
a = dpt.reshape(a, sh)
184+
aT = dpt.permute_dims(a, (2, 1, 0))
185+
186+
s = dpt.sum(aT, axis=2)
187+
expected = dpt.asarray([[0, 3], [1, 4], [2, 5]])
188+
189+
assert dpt.all(s == expected)

0 commit comments

Comments
 (0)