|
12 | 12 | from pytensor.compile.mode import get_default_mode, get_mode
|
13 | 13 | from pytensor.compile.ops import DeepCopyOp, deep_copy_op
|
14 | 14 | from pytensor.configdefaults import config
|
15 |
| -from pytensor.graph.basic import equal_computations |
| 15 | +from pytensor.graph.basic import equal_computations, vars_between |
16 | 16 | from pytensor.graph.fg import FunctionGraph
|
17 | 17 | from pytensor.graph.rewriting.basic import check_stack_trace, out2in
|
18 | 18 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery
|
19 | 19 | from pytensor.graph.rewriting.utils import rewrite_graph
|
20 | 20 | from pytensor.printing import debugprint, pprint
|
21 | 21 | from pytensor.raise_op import Assert, CheckAndRaise
|
22 |
| -from pytensor.scalar.basic import Add |
23 | 22 | from pytensor.tensor.basic import (
|
24 | 23 | Alloc,
|
25 | 24 | Join,
|
|
32 | 31 | )
|
33 | 32 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
34 | 33 | from pytensor.tensor.math import (
|
| 34 | + Sum, |
35 | 35 | add,
|
36 | 36 | bitwise_and,
|
37 | 37 | bitwise_or,
|
|
103 | 103 | values_eq_approx_remove_nan,
|
104 | 104 | vector,
|
105 | 105 | )
|
106 |
| -from pytensor.tensor.var import TensorVariable |
107 | 106 | from tests import unittest_tools as utt
|
108 | 107 |
|
109 | 108 |
|
@@ -1307,13 +1306,20 @@ def test_local_sum_make_vector():
|
1307 | 1306 | mv = MakeVector(config.floatX)
|
1308 | 1307 | output = mv(a, b, c).sum()
|
1309 | 1308 |
|
1310 |
| - func = function([a, b, c], output) |
| 1309 | + output = rewrite_graph(output) |
| 1310 | + between = vars_between([a, b, c], [output]) |
| 1311 | + for var in between: |
| 1312 | + assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector)) |
1311 | 1313 |
|
1312 |
| - elemwise = func.maker.fgraph.outputs[0].owner |
1313 |
| - # The MakeVector op should be optimized away, so we just |
1314 |
| - # take the sum of the scalars. |
1315 |
| - assert elemwise.inputs[0].name == "a" |
1316 |
| - assert isinstance(elemwise.inputs[0], TensorVariable) |
| 1314 | + # Check for empty sum |
| 1315 | + a, b, c = scalars("abc") |
| 1316 | + mv = MakeVector(config.floatX) |
| 1317 | + output = mv(a, b, c).sum(axis=[]) |
| 1318 | + |
| 1319 | + output = rewrite_graph(output) |
| 1320 | + between = vars_between([a, b, c], [output]) |
| 1321 | + for var in between: |
| 1322 | + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
1317 | 1323 |
|
1318 | 1324 |
|
1319 | 1325 | @pytest.mark.parametrize(
|
|
0 commit comments