We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 60a04ff commit 0ca9f68Copy full SHA for 0ca9f68
pymc_experimental/model/transforms/autoreparam.py
@@ -176,12 +176,8 @@ def vip_reparam_node(
176
) -> Tuple[ModelDeterministic, ModelNamed]:
177
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
178
raise TypeError("Op should be RandomVariable type")
179
- _, size, *_ = node.inputs
180
- eval_size = size.eval(mode="FAST_COMPILE")
181
- if eval_size is not None:
182
- rv_shape = tuple(eval_size)
183
- else:
184
- rv_shape = ()
+ rv = node.default_output()
+ rv_shape = rv.shape.eval(mode="FAST_COMPILE")
185
lam_name = f"{name}::lam_logit__"
186
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
187
logit_lam_ = pytensor.shared(
0 commit comments