Skip to content

Commit 7367e8d

Browse files
committed
Fix bug in Dimshuffles created by Elemwise
1 parent a86efe5 commit 7367e8d

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

pytensor/tensor/elemwise.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def __init__(self, input_broadcastable, new_order):
130130
super().__init__([self.c_func_file], self.c_func_name)
131131

132132
self.input_broadcastable = tuple(input_broadcastable)
133+
if not all(isinstance(bs, (bool, np.bool_)) for bs in self.input_broadcastable):
134+
raise ValueError(
135+
f"input_broadcastable must be boolean, {self.input_broadcastable}"
136+
)
133137
self.new_order = tuple(new_order)
134138

135139
self.inplace = True
@@ -411,10 +415,9 @@ def get_output_info(self, dim_shuffle, *inputs):
411415
if not difference:
412416
args.append(input)
413417
else:
414-
# TODO: use LComplete instead
415418
args.append(
416419
dim_shuffle(
417-
tuple(1 if s == 1 else None for s in input.type.shape),
420+
input.type.broadcastable,
418421
["x"] * difference + list(range(length)),
419422
)(input)
420423
)

tests/tensor/test_elemwise.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ def test_static_shape(self):
188188
y = x.dimshuffle([0, 1, "x"])
189189
assert y.type.shape == (1, 2, 1)
190190

191+
def test_valid_input_broadcastable(self):
192+
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)
193+
194+
with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
195+
DimShuffle([None, None], (1, 0))
196+
191197

192198
class TestBroadcast:
193199
# this is to allow other types to reuse this class to test their ops

0 commit comments

Comments
 (0)