Skip to content

Commit e86ba87

Browse files
Merge pull request #1294 from IntelPython/fix-gh-1293-sum-keepdims-true
Fix gh-1293 for sum over zero-size array with keepdims set
2 parents af302a5 + a4e357f commit e86ba87

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

dpctl/tensor/_reduction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
123123

124124
res_usm_type = arr.usm_type
125125
if arr.size == 0:
126+
if keepdims:
127+
res_shape = res_shape + (1,) * red_nd
128+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
129+
res_shape = tuple(res_shape[i] for i in inv_perm)
126130
return dpt.zeros(
127131
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
128132
)

dpctl/tests/test_tensor_sum.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,26 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
133133
assert isinstance(r, dpt.usm_ndarray)
134134
assert r.dtype == dpt.dtype(out_dtype)
135135
assert dpt.asnumpy(r) == 1
136+
137+
138+
def test_sum_keepdims_zero_size():
139+
"""See gh-1293"""
140+
get_queue_or_skip()
141+
n = 10
142+
a = dpt.ones((n, 0, n))
143+
144+
s1 = dpt.sum(a, keepdims=True)
145+
assert s1.shape == (1, 1, 1)
146+
147+
s2 = dpt.sum(a, axis=(0, 1), keepdims=True)
148+
assert s2.shape == (1, 1, n)
149+
150+
s3 = dpt.sum(a, axis=(1, 2), keepdims=True)
151+
assert s3.shape == (n, 1, 1)
152+
153+
s4 = dpt.sum(a, axis=(0, 2), keepdims=True)
154+
assert s4.shape == (1, 0, 1)
155+
156+
a0 = a[0]
157+
s5 = dpt.sum(a0, keepdims=True)
158+
assert s5.shape == (1, 1)

0 commit comments

Comments
 (0)