Skip to content

Commit 762de98

Browse files
ricardoV94twiecki
authored andcommitted
Add failing test when IfElse Mixture logprob is used with indexing that mixes across components
1 parent 17dca13 commit 762de98

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

pymc/logprob/mixture.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,10 @@ def logprob_MixtureRV(
445445
logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m)
446446

447447
else:
448+
# FIXME: This logprob implementation does not support mixing across distinct components,
449+
# but we sometimes use it, because MixtureRV does not keep information about at which
450+
# dimension scalar indexing actually starts
451+
448452
# If the stacking operation expands the component RVs, we have
449453
# to expand the value and later squeeze the logprob for everything
450454
# to work correctly

pymc/tests/logprob/test_mixture.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,31 @@ def test_hetero_mixture_binomial(p_val, size):
306306
(slice(None),),
307307
1,
308308
),
309+
# Vector mixture components, scalar index that mixes across components
310+
pytest.param(
311+
(
312+
np.array(0, dtype=pytensor.config.floatX),
313+
np.array(1, dtype=pytensor.config.floatX),
314+
),
315+
(
316+
np.array(0.5, dtype=pytensor.config.floatX),
317+
np.array(0.5, dtype=pytensor.config.floatX),
318+
),
319+
(
320+
np.array(100, dtype=pytensor.config.floatX),
321+
np.array(1, dtype=pytensor.config.floatX),
322+
),
323+
np.array([0.1, 0.5, 0.1, 0.3], dtype=pytensor.config.floatX),
324+
(4,),
325+
(),
326+
(),
327+
1,
328+
marks=pytest.mark.xfail(
329+
AssertionError,
330+
match="Arrays are not almost equal to 6 decimals", # This is ignored, but that's where it should fail!
331+
reason="IfElse Mixture logprob fails when indexing mixes across components",
332+
),
333+
),
309334
# Matrix components, scalar index along first axis
310335
(
311336
(

0 commit comments

Comments
 (0)