Skip to content

Commit 0ca9f68

Browse files
committed
Evaluate the rv.shape directly
1 parent 60a04ff commit 0ca9f68

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,8 @@ def vip_reparam_node(
176176
) -> Tuple[ModelDeterministic, ModelNamed]:
177177
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
178178
raise TypeError("Op should be RandomVariable type")
179-
_, size, *_ = node.inputs
180-
eval_size = size.eval(mode="FAST_COMPILE")
181-
if eval_size is not None:
182-
rv_shape = tuple(eval_size)
183-
else:
184-
rv_shape = ()
179+
rv = node.default_output()
180+
rv_shape = rv.shape.eval(mode="FAST_COMPILE")
185181
lam_name = f"{name}::lam_logit__"
186182
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
187183
logit_lam_ = pytensor.shared(

0 commit comments

Comments
 (0)