Skip to content

Commit 90b6bec

Browse files
ricardoV94 and twiecki
authored and committed
Do not treat broadcasted variables as independent in logprob inference
This included two cases: ## Direct valuation of broadcasted RVs The `naive_bcast_lift` rewrite was included by default and allowed broadcasted RVs to be valued. This is invalid because it implies that `logp(broadcast_to(normal(0, 1), (3, 2)), value) == logp(normal(0, 1, size=(3, 2)), value)` which is not true. Broadcast replicates the same RV draws, so these values can't be considered independent when evaluating the logp. The rewrite is kept but not used anywhere ## Valuation of Mixtures with potential repeated components This can happen when AdvancedIndexing is used. As a precaution, Mixture replace now fails when Advanced integer indexing is detected, even though some cases may be valid at runtime (e.g., no repeated indexes)
1 parent 762de98 commit 90b6bec

File tree

4 files changed

+80
-22
lines changed

4 files changed

+80
-22
lines changed

pymc/logprob/mixture.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,15 @@
5858
)
5959
from pytensor.tensor.shape import shape_tuple
6060
from pytensor.tensor.subtensor import (
61+
AdvancedSubtensor,
62+
AdvancedSubtensor1,
6163
as_index_literal,
6264
as_nontensor_scalar,
6365
get_canonical_form_slice,
6466
is_basic_idx,
6567
)
6668
from pytensor.tensor.type import TensorType
67-
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
69+
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
6870
from pytensor.tensor.var import TensorVariable
6971

7072
from pymc.logprob.abstract import (
@@ -309,6 +311,17 @@ def mixture_replace(fgraph, node):
309311

310312
mixing_indices = node.inputs[1:]
311313

314+
# TODO: Add check / test case for Advanced Boolean indexing
315+
if isinstance(node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
316+
# We don't support (non-scalar) integer array indexing as it can pick repeated values,
317+
# but the Mixture logprob assumes all mixture values are independent
318+
if any(
319+
indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
320+
for indices in mixing_indices
321+
if not isinstance(indices, SliceConstant)
322+
):
323+
return None
324+
312325
# We loop through mixture components and collect all the array elements
313326
# that belong to each one (by way of their indices).
314327
new_mixture_rvs = []

pymc/logprob/tensor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,10 +357,6 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf
357357
"find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
358358
)
359359

360-
361-
measurable_ir_rewrites_db.register("broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor")
362-
363-
364360
measurable_ir_rewrites_db.register(
365361
"find_measurable_stacks",
366362
find_measurable_stacks,

pymc/tests/logprob/test_mixture.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def create_mix_model(size, axis):
9191
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
9292
factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv, X_rv: x_vv})
9393

94-
with pytest.raises(NotImplementedError):
94+
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
9595
axis_at = at.lscalar("axis")
9696
axis_at.tag.test_value = 0
9797
env = create_mix_model((2,), axis_at)
@@ -139,17 +139,19 @@ def test_compute_test_value(op_constructor):
139139

140140

141141
@pytest.mark.parametrize(
142-
"p_val, size",
142+
"p_val, size, supported",
143143
[
144-
(np.array(0.0, dtype=pytensor.config.floatX), ()),
145-
(np.array(1.0, dtype=pytensor.config.floatX), ()),
146-
(np.array(0.0, dtype=pytensor.config.floatX), (2,)),
147-
(np.array(1.0, dtype=pytensor.config.floatX), (2, 1)),
148-
(np.array(1.0, dtype=pytensor.config.floatX), (2, 3)),
149-
(np.array([0.1, 0.9], dtype=pytensor.config.floatX), (2, 3)),
144+
(np.array(0.0, dtype=pytensor.config.floatX), (), True),
145+
(np.array(1.0, dtype=pytensor.config.floatX), (), True),
146+
(np.array([0.1, 0.9], dtype=pytensor.config.floatX), (), True),
147+
# The cases below are not supported because they may pick repeated values via AdvancedIndexing
148+
(np.array(0.0, dtype=pytensor.config.floatX), (2,), False),
149+
(np.array(1.0, dtype=pytensor.config.floatX), (2, 1), False),
150+
(np.array(1.0, dtype=pytensor.config.floatX), (2, 3), False),
151+
(np.array([0.1, 0.9], dtype=pytensor.config.floatX), (2, 3), False),
150152
],
151153
)
152-
def test_hetero_mixture_binomial(p_val, size):
154+
def test_hetero_mixture_binomial(p_val, size, supported):
153155
srng = at.random.RandomStream(29833)
154156

155157
X_rv = srng.normal(0, 1, size=size, name="X")
@@ -175,7 +177,12 @@ def test_hetero_mixture_binomial(p_val, size):
175177
m_vv = M_rv.clone()
176178
m_vv.name = "m"
177179

178-
M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
180+
if supported:
181+
M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
182+
else:
183+
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
184+
joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
185+
return
179186

180187
M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp)
181188

@@ -204,9 +211,9 @@ def test_hetero_mixture_binomial(p_val, size):
204211

205212

206213
@pytest.mark.parametrize(
207-
"X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis",
214+
"X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported",
208215
[
209-
# Scalar mixture components, scalar index
216+
# Scalar components, scalar index
210217
(
211218
(
212219
np.array(0, dtype=pytensor.config.floatX),
@@ -225,6 +232,7 @@ def test_hetero_mixture_binomial(p_val, size):
225232
(),
226233
(),
227234
0,
235+
True,
228236
),
229237
# Degenerate vector mixture components, scalar index along join axis
230238
(
@@ -245,6 +253,7 @@ def test_hetero_mixture_binomial(p_val, size):
245253
(),
246254
(),
247255
0,
256+
True,
248257
),
249258
# Degenerate vector mixture components, scalar index along join axis (axis=1)
250259
(
@@ -265,6 +274,7 @@ def test_hetero_mixture_binomial(p_val, size):
265274
(),
266275
(slice(None),),
267276
1,
277+
True,
268278
),
269279
# Vector mixture components, scalar index along the join axis
270280
(
@@ -285,6 +295,7 @@ def test_hetero_mixture_binomial(p_val, size):
285295
(),
286296
(),
287297
0,
298+
True,
288299
),
289300
# Vector mixture components, scalar index along the join axis (axis=1)
290301
(
@@ -305,6 +316,7 @@ def test_hetero_mixture_binomial(p_val, size):
305316
(),
306317
(slice(None),),
307318
1,
319+
True,
308320
),
309321
# Vector mixture components, scalar index that mixes across components
310322
pytest.param(
@@ -325,6 +337,7 @@ def test_hetero_mixture_binomial(p_val, size):
325337
(),
326338
(),
327339
1,
340+
True,
328341
marks=pytest.mark.xfail(
329342
AssertionError,
330343
match="Arrays are not almost equal to 6 decimals", # This is ignored, but that's where it should fail!
@@ -350,7 +363,10 @@ def test_hetero_mixture_binomial(p_val, size):
350363
(),
351364
(),
352365
0,
366+
True,
353367
),
368+
# All the tests below rely on AdvancedIndexing, which is not supported at the moment
369+
# See https://github.com/pymc-devs/pymc/issues/6398
354370
# Scalar mixture components, vector index along first axis
355371
(
356372
(
@@ -370,6 +386,7 @@ def test_hetero_mixture_binomial(p_val, size):
370386
(6,),
371387
(),
372388
0,
389+
False,
373390
),
374391
# Vector mixture components, vector index along first axis
375392
(
@@ -390,9 +407,10 @@ def test_hetero_mixture_binomial(p_val, size):
390407
(2,),
391408
(slice(None),),
392409
0,
410+
False,
393411
),
394412
# Vector mixture components, vector index along last axis
395-
pytest.param(
413+
(
396414
(
397415
np.array(0, dtype=pytensor.config.floatX),
398416
np.array(1, dtype=pytensor.config.floatX),
@@ -410,7 +428,7 @@ def test_hetero_mixture_binomial(p_val, size):
410428
(4,),
411429
(slice(None),),
412430
1,
413-
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
431+
False,
414432
),
415433
# Vector mixture components (with degenerate vector parameters), vector index along first axis
416434
(
@@ -431,6 +449,7 @@ def test_hetero_mixture_binomial(p_val, size):
431449
(2,),
432450
(),
433451
0,
452+
False,
434453
),
435454
# Vector mixture components (with vector parameters), vector index along first axis
436455
(
@@ -451,6 +470,7 @@ def test_hetero_mixture_binomial(p_val, size):
451470
(2,),
452471
(),
453472
0,
473+
False,
454474
),
455475
# Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
456476
(
@@ -471,6 +491,7 @@ def test_hetero_mixture_binomial(p_val, size):
471491
None,
472492
(),
473493
0,
494+
False,
474495
),
475496
# Matrix mixture components, matrix index
476497
(
@@ -491,6 +512,7 @@ def test_hetero_mixture_binomial(p_val, size):
491512
(2, 3),
492513
(),
493514
0,
515+
False,
494516
),
495517
# Vector components, matrix indexing (constant along first dimension, then random)
496518
(
@@ -511,6 +533,7 @@ def test_hetero_mixture_binomial(p_val, size):
511533
(5,),
512534
(np.arange(5),),
513535
0,
536+
False,
514537
),
515538
# Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
516539
(
@@ -531,11 +554,12 @@ def test_hetero_mixture_binomial(p_val, size):
531554
(5,),
532555
(np.arange(5), None),
533556
0,
557+
False,
534558
),
535559
],
536560
)
537561
def test_hetero_mixture_categorical(
538-
X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis
562+
X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported
539563
):
540564
srng = at.random.RandomStream(29833)
541565

@@ -561,7 +585,12 @@ def test_hetero_mixture_categorical(
561585
m_vv = M_rv.clone()
562586
m_vv.name = "m"
563587

564-
logp_parts = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
588+
if supported:
589+
logp_parts = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
590+
else:
591+
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
592+
factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
593+
return
565594

566595
I_logp_fn = pytensor.function([p_at, i_vv], logp_parts[i_vv])
567596
M_logp_fn = pytensor.function([m_vv, i_vv], logp_parts[m_vv])
@@ -854,7 +883,7 @@ def test_mixture_with_DiracDelta():
854883
Y_rv = dirac_delta(0.0)
855884
Y_rv.name = "Y"
856885

857-
I_rv = srng.categorical([0.5, 0.5], size=4)
886+
I_rv = srng.categorical([0.5, 0.5], size=1)
858887

859888
i_vv = I_rv.clone()
860889
i_vv.name = "i"

pymc/tests/logprob/test_tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,26 @@ def test_naive_bcast_rv_lift_valued_var():
8181
assert np.allclose(logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]))
8282

8383

84+
@pytest.mark.xfail(RuntimeError, reason="logprob for broadcasted RVs not implemented")
85+
def test_bcast_rv_logp():
86+
"""Test that derived logp for broadcasted RV is correct"""
87+
88+
x_rv = at.random.normal(name="x")
89+
broadcasted_x_rv = at.broadcast_to(x_rv, (2,))
90+
broadcasted_x_rv.name = "broadcasted_x"
91+
broadcasted_x_vv = broadcasted_x_rv.clone()
92+
93+
logp = joint_logprob({broadcasted_x_rv: broadcasted_x_vv}, sum=False)
94+
valid_logp = logp.eval({broadcasted_x_vv: [0, 0]})
95+
assert valid_logp.shape == ()
96+
assert np.isclose(valid_logp, st.norm.logpdf(0))
97+
98+
# It's not possible for broadcasted dimensions to have different values
99+
# This should either raise or return -inf
100+
invalid_logp = logp.eval({broadcasted_x_vv: [0, 1]})
101+
assert invalid_logp == -np.inf
102+
103+
84104
def test_measurable_make_vector():
85105
base1_rv = at.random.normal(name="base1")
86106
base2_rv = at.random.halfnormal(name="base2")

0 commit comments

Comments
 (0)