```diff
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from functools import singledispatch
 from typing import Dict, List, Optional, Sequence, Tuple, Union
@@ -8,7 +9,6 @@
 import pytensor.tensor as pt
 import scipy.special
 from pymc.distributions import SymbolicRandomVariable
-from pymc.exceptions import NotConstantValueError
 from pymc.logprob.transforms import Transform
 from pymc.model.fgraph import (
     ModelDeterministic,
@@ -19,10 +19,12 @@
     model_from_fgraph,
     model_named,
 )
-from pymc.pytensorf import constant_fold, toposort_replace
+from pymc.pytensorf import toposort_replace
 from pytensor.graph.basic import Apply, Variable
 from pytensor.tensor.random.op import RandomVariable

+_log = logging.getLogger("pmx")
+

 @dataclass
 class VIP:
```
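The module now logs through a logger named `pmx`, so the `_log.debug(...)` call added in the next hunk is silent under the default logging configuration. A minimal sketch of enabling it from user code, purely illustrative and not part of this commit:

```python
import logging

# Route log records to stderr and let the "pmx" logger emit DEBUG messages.
# Any standard logging configuration works; this is just one way to do it.
logging.basicConfig(level=logging.INFO)
logging.getLogger("pmx").setLevel(logging.DEBUG)
```

Using a named logger rather than `print` leaves it to downstream users whether the reparametrization details are surfaced at all.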
|
```diff
@@ -174,15 +176,14 @@ def vip_reparam_node(
 ) -> Tuple[ModelDeterministic, ModelNamed]:
     if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
         raise TypeError("Op should be RandomVariable type")
-    rv = node.default_output()
-    try:
-        [rv_shape] = constant_fold([rv.shape])
-    except NotConstantValueError:
-        raise ValueError("Size should be static for autoreparametrization.")
+    _, size, *_ = node.inputs
+    rv_shape = tuple(size.eval())
+    lam_name = f"{name}::lam_logit__"
+    _log.debug(f"Creating {lam_name} with shape of {rv_shape}")
     logit_lam_ = pytensor.shared(
         np.zeros(rv_shape),
         shape=rv_shape,
-        name=f"{name}::lam_logit__",
+        name=lam_name,
     )
     logit_lam = model_named(logit_lam_, *dims)
     lam = pt.sigmoid(logit_lam)
```