@@ -91,7 +91,7 @@ def create_mix_model(size, axis):
91
91
with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
92
92
factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv , X_rv : x_vv })
93
93
94
- with pytest .raises (NotImplementedError ):
94
+ with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
95
95
axis_at = at .lscalar ("axis" )
96
96
axis_at .tag .test_value = 0
97
97
env = create_mix_model ((2 ,), axis_at )
@@ -139,17 +139,19 @@ def test_compute_test_value(op_constructor):
139
139
140
140
141
141
@pytest .mark .parametrize (
142
- "p_val, size" ,
142
+ "p_val, size, supported " ,
143
143
[
144
- (np .array (0.0 , dtype = pytensor .config .floatX ), ()),
145
- (np .array (1.0 , dtype = pytensor .config .floatX ), ()),
146
- (np .array (0.0 , dtype = pytensor .config .floatX ), (2 ,)),
147
- (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 1 )),
148
- (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 3 )),
149
- (np .array ([0.1 , 0.9 ], dtype = pytensor .config .floatX ), (2 , 3 )),
144
+ (np .array (0.0 , dtype = pytensor .config .floatX ), (), True ),
145
+ (np .array (1.0 , dtype = pytensor .config .floatX ), (), True ),
146
+ (np .array ([0.1 , 0.9 ], dtype = pytensor .config .floatX ), (), True ),
147
+ # The cases belowe are not supported because they may pick repeated values via AdvancedIndexing
148
+ (np .array (0.0 , dtype = pytensor .config .floatX ), (2 ,), False ),
149
+ (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 1 ), False ),
150
+ (np .array (1.0 , dtype = pytensor .config .floatX ), (2 , 3 ), False ),
151
+ (np .array ([0.1 , 0.9 ], dtype = pytensor .config .floatX ), (2 , 3 ), False ),
150
152
],
151
153
)
152
- def test_hetero_mixture_binomial (p_val , size ):
154
+ def test_hetero_mixture_binomial (p_val , size , supported ):
153
155
srng = at .random .RandomStream (29833 )
154
156
155
157
X_rv = srng .normal (0 , 1 , size = size , name = "X" )
@@ -175,7 +177,12 @@ def test_hetero_mixture_binomial(p_val, size):
175
177
m_vv = M_rv .clone ()
176
178
m_vv .name = "m"
177
179
178
- M_logp = joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
180
+ if supported :
181
+ M_logp = joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
182
+ else :
183
+ with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
184
+ joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
185
+ return
179
186
180
187
M_logp_fn = pytensor .function ([p_at , m_vv , i_vv ], M_logp )
181
188
@@ -204,9 +211,9 @@ def test_hetero_mixture_binomial(p_val, size):
204
211
205
212
206
213
@pytest .mark .parametrize (
207
- "X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis" ,
214
+ "X_args, Y_args, Z_args, p_val, comp_size, idx_size, extra_indices, join_axis, supported " ,
208
215
[
209
- # Scalar mixture components, scalar index
216
+ # Scalar components, scalar index
210
217
(
211
218
(
212
219
np .array (0 , dtype = pytensor .config .floatX ),
@@ -225,6 +232,7 @@ def test_hetero_mixture_binomial(p_val, size):
225
232
(),
226
233
(),
227
234
0 ,
235
+ True ,
228
236
),
229
237
# Degenerate vector mixture components, scalar index along join axis
230
238
(
@@ -245,6 +253,7 @@ def test_hetero_mixture_binomial(p_val, size):
245
253
(),
246
254
(),
247
255
0 ,
256
+ True ,
248
257
),
249
258
# Degenerate vector mixture components, scalar index along join axis (axis=1)
250
259
(
@@ -265,6 +274,7 @@ def test_hetero_mixture_binomial(p_val, size):
265
274
(),
266
275
(slice (None ),),
267
276
1 ,
277
+ True ,
268
278
),
269
279
# Vector mixture components, scalar index along the join axis
270
280
(
@@ -285,6 +295,7 @@ def test_hetero_mixture_binomial(p_val, size):
285
295
(),
286
296
(),
287
297
0 ,
298
+ True ,
288
299
),
289
300
# Vector mixture components, scalar index along the join axis (axis=1)
290
301
(
@@ -305,6 +316,7 @@ def test_hetero_mixture_binomial(p_val, size):
305
316
(),
306
317
(slice (None ),),
307
318
1 ,
319
+ True ,
308
320
),
309
321
# Vector mixture components, scalar index that mixes across components
310
322
pytest .param (
@@ -325,6 +337,7 @@ def test_hetero_mixture_binomial(p_val, size):
325
337
(),
326
338
(),
327
339
1 ,
340
+ True ,
328
341
marks = pytest .mark .xfail (
329
342
AssertionError ,
330
343
match = "Arrays are not almost equal to 6 decimals" , # This is ignored, but that's where it should fail!
@@ -350,7 +363,10 @@ def test_hetero_mixture_binomial(p_val, size):
350
363
(),
351
364
(),
352
365
0 ,
366
+ True ,
353
367
),
368
+ # All the tests below rely on AdvancedIndexing, which is not supported at the moment
369
+ # See https://github.com/pymc-devs/pymc/issues/6398
354
370
# Scalar mixture components, vector index along first axis
355
371
(
356
372
(
@@ -370,6 +386,7 @@ def test_hetero_mixture_binomial(p_val, size):
370
386
(6 ,),
371
387
(),
372
388
0 ,
389
+ False ,
373
390
),
374
391
# Vector mixture components, vector index along first axis
375
392
(
@@ -390,9 +407,10 @@ def test_hetero_mixture_binomial(p_val, size):
390
407
(2 ,),
391
408
(slice (None ),),
392
409
0 ,
410
+ False ,
393
411
),
394
412
# Vector mixture components, vector index along last axis
395
- pytest . param (
413
+ (
396
414
(
397
415
np .array (0 , dtype = pytensor .config .floatX ),
398
416
np .array (1 , dtype = pytensor .config .floatX ),
@@ -410,7 +428,7 @@ def test_hetero_mixture_binomial(p_val, size):
410
428
(4 ,),
411
429
(slice (None ),),
412
430
1 ,
413
- marks = pytest . mark . xfail ( IndexError , reason = "Bug in AdvancedIndex Mixture logprob" ) ,
431
+ False ,
414
432
),
415
433
# Vector mixture components (with degenerate vector parameters), vector index along first axis
416
434
(
@@ -431,6 +449,7 @@ def test_hetero_mixture_binomial(p_val, size):
431
449
(2 ,),
432
450
(),
433
451
0 ,
452
+ False ,
434
453
),
435
454
# Vector mixture components (with vector parameters), vector index along first axis
436
455
(
@@ -451,6 +470,7 @@ def test_hetero_mixture_binomial(p_val, size):
451
470
(2 ,),
452
471
(),
453
472
0 ,
473
+ False ,
454
474
),
455
475
# Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
456
476
(
@@ -471,6 +491,7 @@ def test_hetero_mixture_binomial(p_val, size):
471
491
None ,
472
492
(),
473
493
0 ,
494
+ False ,
474
495
),
475
496
# Matrix mixture components, matrix index
476
497
(
@@ -491,6 +512,7 @@ def test_hetero_mixture_binomial(p_val, size):
491
512
(2 , 3 ),
492
513
(),
493
514
0 ,
515
+ False ,
494
516
),
495
517
# Vector components, matrix indexing (constant along first dimension, then random)
496
518
(
@@ -511,6 +533,7 @@ def test_hetero_mixture_binomial(p_val, size):
511
533
(5 ,),
512
534
(np .arange (5 ),),
513
535
0 ,
536
+ False ,
514
537
),
515
538
# Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
516
539
(
@@ -531,11 +554,12 @@ def test_hetero_mixture_binomial(p_val, size):
531
554
(5 ,),
532
555
(np .arange (5 ), None ),
533
556
0 ,
557
+ False ,
534
558
),
535
559
],
536
560
)
537
561
def test_hetero_mixture_categorical (
538
- X_args , Y_args , Z_args , p_val , comp_size , idx_size , extra_indices , join_axis
562
+ X_args , Y_args , Z_args , p_val , comp_size , idx_size , extra_indices , join_axis , supported
539
563
):
540
564
srng = at .random .RandomStream (29833 )
541
565
@@ -561,7 +585,12 @@ def test_hetero_mixture_categorical(
561
585
m_vv = M_rv .clone ()
562
586
m_vv .name = "m"
563
587
564
- logp_parts = factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
588
+ if supported :
589
+ logp_parts = factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
590
+ else :
591
+ with pytest .raises (RuntimeError , match = "could not be derived: {m}" ):
592
+ factorized_joint_logprob ({M_rv : m_vv , I_rv : i_vv }, sum = False )
593
+ return
565
594
566
595
I_logp_fn = pytensor .function ([p_at , i_vv ], logp_parts [i_vv ])
567
596
M_logp_fn = pytensor .function ([m_vv , i_vv ], logp_parts [m_vv ])
@@ -854,7 +883,7 @@ def test_mixture_with_DiracDelta():
854
883
Y_rv = dirac_delta (0.0 )
855
884
Y_rv .name = "Y"
856
885
857
- I_rv = srng .categorical ([0.5 , 0.5 ], size = 4 )
886
+ I_rv = srng .categorical ([0.5 , 0.5 ], size = 1 )
858
887
859
888
i_vv = I_rv .clone ()
860
889
i_vv .name = "i"
0 commit comments