Skip to content

Commit 02614c8

Browse files
committed
fix autoreparam because dims are no longer static
1 parent 87d4aea commit 02614c8

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from dataclasses import dataclass
23
from functools import singledispatch
34
from typing import Dict, List, Optional, Sequence, Tuple, Union
@@ -8,7 +9,6 @@
89
import pytensor.tensor as pt
910
import scipy.special
1011
from pymc.distributions import SymbolicRandomVariable
11-
from pymc.exceptions import NotConstantValueError
1212
from pymc.logprob.transforms import Transform
1313
from pymc.model.fgraph import (
1414
ModelDeterministic,
@@ -19,10 +19,12 @@
1919
model_from_fgraph,
2020
model_named,
2121
)
22-
from pymc.pytensorf import constant_fold, toposort_replace
22+
from pymc.pytensorf import toposort_replace
2323
from pytensor.graph.basic import Apply, Variable
2424
from pytensor.tensor.random.op import RandomVariable
2525

26+
_log = logging.getLogger("pmx")
27+
2628

2729
@dataclass
2830
class VIP:
@@ -174,15 +176,14 @@ def vip_reparam_node(
174176
) -> Tuple[ModelDeterministic, ModelNamed]:
175177
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
176178
raise TypeError("Op should be RandomVariable type")
177-
rv = node.default_output()
178-
try:
179-
[rv_shape] = constant_fold([rv.shape])
180-
except NotConstantValueError:
181-
raise ValueError("Size should be static for autoreparametrization.")
179+
_, size, *_ = node.inputs
180+
rv_shape = tuple(size.eval())
181+
lam_name = f"{name}::lam_logit__"
182+
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
182183
logit_lam_ = pytensor.shared(
183184
np.zeros(rv_shape),
184185
shape=rv_shape,
185-
name=f"{name}::lam_logit__",
186+
name=lam_name,
186187
)
187188
logit_lam = model_named(logit_lam_, *dims)
188189
lam = pt.sigmoid(logit_lam)

0 commit comments

Comments (0)