Skip to content

Commit 62ef8b6

Browse files
ricardoV94twiecki
authored andcommitted
Do not raise from internal mixture rewrite
1 parent cedd595 commit 62ef8b6

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

pymc/logprob/mixture.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import pytensor
4040
import pytensor.tensor as at
4141

42-
from pytensor.graph.basic import Apply, Variable
42+
from pytensor.graph.basic import Apply, Constant, Variable
4343
from pytensor.graph.fg import FunctionGraph
4444
from pytensor.graph.op import Op, compute_test_value
4545
from pytensor.graph.rewriting.basic import (
@@ -246,41 +246,29 @@ def get_stack_mixture_vars(
246246
node: Apply,
247247
) -> Tuple[Optional[List[TensorVariable]], Optional[int]]:
248248
r"""Extract the mixture terms from a `*Subtensor*` applied to stacked `MeasurableVariable`\s."""
249-
if not isinstance(node.op, subtensor_ops):
250-
return None, None # pragma: no cover
251249

252-
join_axis = NoneConst
250+
assert isinstance(node.op, subtensor_ops)
251+
253252
joined_rvs = node.inputs[0]
254253

255254
# First, make sure that it's some sort of concatenation
256255
if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
257-
# Node is not a compatible join `Op`
258-
return None, join_axis # pragma: no cover
256+
return None, None
259257

260258
if isinstance(joined_rvs.owner.op, MakeVector):
259+
join_axis = NoneConst
261260
mixture_rvs = joined_rvs.owner.inputs
262261

263262
elif isinstance(joined_rvs.owner.op, Join):
264-
mixture_rvs = joined_rvs.owner.inputs[1:]
265-
join_axis = joined_rvs.owner.inputs[0]
266-
try:
267-
# TODO: Find better solution to avoid this circular dependency
268-
from pymc.pytensorf import constant_fold
269-
270-
join_axis = int(constant_fold((join_axis,))[0])
271-
except ValueError:
272-
# TODO: Support symbolic join axes
273-
raise NotImplementedError("Symbolic `Join` axes are not supported in mixtures")
263+
# TODO: Find better solution to avoid this circular dependency
264+
from pymc.pytensorf import constant_fold
274265

275-
join_axis = at.as_tensor(join_axis)
266+
join_axis = joined_rvs.owner.inputs[0]
267+
# TODO: Support symbolic join axes. This will raise ValueError if it's not a constant
268+
(join_axis,) = constant_fold((join_axis,), raise_not_constant=False)
269+
join_axis = at.as_tensor(join_axis, dtype="int64")
276270

277-
if not all(rv.owner and isinstance(rv.owner.op, MeasurableVariable) for rv in mixture_rvs):
278-
# Currently, all mixture components must be `MeasurableVariable` outputs
279-
# TODO: Allow constants and make them Dirac-deltas
280-
# raise NotImplementedError(
281-
# "All mixture components must be `MeasurableVariable` outputs"
282-
# )
283-
return None, join_axis
271+
mixture_rvs = joined_rvs.owner.inputs[1:]
284272

285273
return mixture_rvs, join_axis
286274

@@ -302,33 +290,45 @@ def mixture_replace(fgraph, node):
302290

303291
old_mixture_rv = node.default_output()
304292

305-
mixture_res, join_axis = get_stack_mixture_vars(node)
293+
mixture_rvs, join_axis = get_stack_mixture_vars(node)
294+
295+
# We don't support symbolic join axis
296+
if mixture_rvs is None or not isinstance(join_axis, (NoneTypeT, Constant)):
297+
return None
306298

307-
if mixture_res is None or any(rv in rv_map_feature.rv_values for rv in mixture_res):
299+
# Check that all components are MeasurableVariables and none is already conditioned on
300+
if not all(
301+
(
302+
rv.owner is not None
303+
and isinstance(rv.owner.op, MeasurableVariable)
304+
and rv not in rv_map_feature.rv_values
305+
)
306+
for rv in mixture_rvs
307+
):
308308
return None # pragma: no cover
309309

310310
mixing_indices = node.inputs[1:]
311311

312312
# We loop through mixture components and collect all the array elements
313313
# that belong to each one (by way of their indices).
314-
mixture_rvs = []
315-
for i, component_rv in enumerate(mixture_res):
314+
new_mixture_rvs = []
315+
for i, component_rv in enumerate(mixture_rvs):
316316

317317
# We create custom types for the mixture components and assign them
318318
# null `get_measurable_outputs` dispatches so that they aren't
319319
# erroneously encountered in places like `factorized_joint_logprob`.
320320
new_node = assign_custom_measurable_outputs(component_rv.owner)
321321
out_idx = component_rv.owner.outputs.index(component_rv)
322322
new_comp_rv = new_node.outputs[out_idx]
323-
mixture_rvs.append(new_comp_rv)
323+
new_mixture_rvs.append(new_comp_rv)
324324

325325
# Replace this sub-graph with a `MixtureRV`
326326
mix_op = MixtureRV(
327327
1 + len(mixing_indices),
328328
old_mixture_rv.dtype,
329329
old_mixture_rv.broadcastable,
330330
)
331-
new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))
331+
new_node = mix_op.make_node(*([join_axis] + mixing_indices + new_mixture_rvs))
332332

333333
new_mixture_rv = new_node.default_output()
334334

0 commit comments

Comments
 (0)