 import pytensor
 import pytensor.tensor as at

-from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op, compute_test_value
 from pytensor.graph.rewriting.basic import (
@@ -246,41 +246,29 @@ def get_stack_mixture_vars(
     node: Apply,
 ) -> Tuple[Optional[List[TensorVariable]], Optional[int]]:
     r"""Extract the mixture terms from a `*Subtensor*` applied to stacked `MeasurableVariable`\s."""
-    if not isinstance(node.op, subtensor_ops):
-        return None, None  # pragma: no cover

-    join_axis = NoneConst
+    assert isinstance(node.op, subtensor_ops)
+
     joined_rvs = node.inputs[0]

     # First, make sure that it's some sort of concatenation
     if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
-        # Node is not a compatible join `Op`
-        return None, join_axis  # pragma: no cover
+        return None, None

     if isinstance(joined_rvs.owner.op, MakeVector):
+        join_axis = NoneConst
         mixture_rvs = joined_rvs.owner.inputs

     elif isinstance(joined_rvs.owner.op, Join):
-        mixture_rvs = joined_rvs.owner.inputs[1:]
-        join_axis = joined_rvs.owner.inputs[0]
-        try:
-            # TODO: Find better solution to avoid this circular dependency
-            from pymc.pytensorf import constant_fold
-
-            join_axis = int(constant_fold((join_axis,))[0])
-        except ValueError:
-            # TODO: Support symbolic join axes
-            raise NotImplementedError("Symbolic `Join` axes are not supported in mixtures")
+        # TODO: Find better solution to avoid this circular dependency
+        from pymc.pytensorf import constant_fold

-        join_axis = at.as_tensor(join_axis)
+        join_axis = joined_rvs.owner.inputs[0]
+        # TODO: Support symbolic join axes. This will raise ValueError if it's not a constant
+        (join_axis,) = constant_fold((join_axis,), raise_not_constant=False)
+        join_axis = at.as_tensor(join_axis, dtype="int64")

-    if not all(rv.owner and isinstance(rv.owner.op, MeasurableVariable) for rv in mixture_rvs):
-        # Currently, all mixture components must be `MeasurableVariable` outputs
-        # TODO: Allow constants and make them Dirac-deltas
-        # raise NotImplementedError(
-        #     "All mixture components must be `MeasurableVariable` outputs"
-        # )
-        return None, join_axis
+        mixture_rvs = joined_rvs.owner.inputs[1:]

     return mixture_rvs, join_axis

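For context, a minimal sketch of the kind of graph `get_stack_mixture_vars` is meant to match: a `*Subtensor*` applied to a `MakeVector`/`Join` of measurable variables. The distributions and variable names below are illustrative only (not part of the diff), assuming pymc and pytensor are installed:

import pytensor.tensor as at
import pymc as pm

# Two measurable components; stacking scalars produces a MakeVector node,
# while joining higher-dimensional inputs produces a Join node.
comp_a = pm.Normal.dist(0.0, 1.0)
comp_b = pm.Normal.dist(10.0, 1.0)

idx = at.lscalar("idx")
stacked = at.stack([comp_a, comp_b])  # MakeVector over MeasurableVariable outputs
mix = stacked[idx]                    # the *Subtensor* the rewrite looks for

# mix.owner.op is a Subtensor and mix.owner.inputs[0].owner.op is a MakeVector,
# so get_stack_mixture_vars(mix.owner) should return ([comp_a, comp_b], NoneConst).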
@@ -302,33 +290,45 @@ def mixture_replace(fgraph, node):

     old_mixture_rv = node.default_output()

-    mixture_res, join_axis = get_stack_mixture_vars(node)
+    mixture_rvs, join_axis = get_stack_mixture_vars(node)
+
+    # We don't support symbolic join axis
+    if mixture_rvs is None or not isinstance(join_axis, (NoneTypeT, Constant)):
+        return None

-    if mixture_res is None or any(rv in rv_map_feature.rv_values for rv in mixture_res):
+    # Check that all components are MeasurableVariables and none is already conditioned on
+    if not all(
+        (
+            rv.owner is not None
+            and isinstance(rv.owner.op, MeasurableVariable)
+            and rv not in rv_map_feature.rv_values
+        )
+        for rv in mixture_rvs
+    ):
         return None  # pragma: no cover

     mixing_indices = node.inputs[1:]

     # We loop through mixture components and collect all the array elements
     # that belong to each one (by way of their indices).
-    mixture_rvs = []
-    for i, component_rv in enumerate(mixture_res):
+    new_mixture_rvs = []
+    for i, component_rv in enumerate(mixture_rvs):

         # We create custom types for the mixture components and assign them
         # null `get_measurable_outputs` dispatches so that they aren't
         # erroneously encountered in places like `factorized_joint_logprob`.
         new_node = assign_custom_measurable_outputs(component_rv.owner)
         out_idx = component_rv.owner.outputs.index(component_rv)
         new_comp_rv = new_node.outputs[out_idx]
-        mixture_rvs.append(new_comp_rv)
+        new_mixture_rvs.append(new_comp_rv)

     # Replace this sub-graph with a `MixtureRV`
     mix_op = MixtureRV(
         1 + len(mixing_indices),
         old_mixture_rv.dtype,
         old_mixture_rv.broadcastable,
     )
-    new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))
+    new_node = mix_op.make_node(*([join_axis] + mixing_indices + new_mixture_rvs))

     new_mixture_rv = new_node.default_output()

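A rough sketch of why the rewrite can now bail out on symbolic join axes instead of raising: with `raise_not_constant=False`, `constant_fold` returns the still-symbolic variable rather than erroring, and `mixture_replace` then checks whether the resulting axis is a `Constant` (as `NoneConst` also is). The variable names below are illustrative and the exact return values are an assumption, not part of the diff:

import pytensor.tensor as at
from pytensor.graph.basic import Constant
from pymc.pytensorf import constant_fold

const_axis = at.as_tensor(0)   # a literal axis folds down to a constant
sym_axis = at.lscalar("axis")  # a user-supplied symbolic axis cannot be folded

(folded,) = constant_fold((const_axis,), raise_not_constant=False)
(unfolded,) = constant_fold((sym_axis,), raise_not_constant=False)

# After at.as_tensor(..., dtype="int64") the folded axis should be a TensorConstant,
# so the `isinstance(join_axis, (NoneTypeT, Constant))` guard passes; the symbolic
# axis stays a plain TensorVariable and mixture_replace returns None, leaving the
# graph untouched instead of raising NotImplementedError as before.
print(isinstance(at.as_tensor(folded, dtype="int64"), Constant))  # expected: True
print(isinstance(unfolded, Constant))                             # expected: False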