Skip to content

Commit bfd7257

Browse files
authored
Fix non-contiguous reshapes in numba backend (#255)
1 parent 0c9eb9c commit bfd7257

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def dimshuffle_inner(x, shuffle):
841841

842842
@numba_basic.numba_njit
843843
def dimshuffle_inner(x, shuffle):
844-
return np.reshape(x, ())
844+
return np.reshape(np.ascontiguousarray(x), ())
845845

846846
# Without the following wrapper function we would see this error:
847847
# E No implementation of function Function(<built-in function getitem>) found for signature:

tests/link/numba/test_elemwise.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,17 @@ def test_Dimshuffle_returns_array():
218218
assert out.ndim == 0
219219

220220

221+
def test_Dimshuffle_non_contiguous():
222+
"""The numba impl of reshape doesn't work with
223+
non-contiguous arrays, make sure we work around that."""
224+
x = at.dvector()
225+
idx = at.vector(dtype="int64")
226+
op = pytensor.tensor.elemwise.DimShuffle([True], [])
227+
out = op(at.specify_shape(x[idx][::2], (1,)))
228+
func = pytensor.function([x, idx], out, mode="NUMBA")
229+
assert func(np.zeros(3), np.array([1])).ndim == 0
230+
231+
221232
@pytest.mark.parametrize(
222233
"careduce_fn, axis, v",
223234
[

0 commit comments

Comments
 (0)