Skip to content

Commit 4b47472

Browse files
committed
Improve test_local_sum_make_vector rewrite
1 parent 780f431 commit 4b47472

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -990,14 +990,18 @@ def local_sum_make_vector(fgraph, node):
990990
if not isinstance(array.owner.op, MakeVector):
991991
return
992992

993-
if node.op.axis not in [None, 0, -1]:
994-
return
993+
if node.op.axis == ():
994+
return [array]
995+
996+
# If this is not the case the sum is invalid
997+
assert node.op.axis is None or node.op.axis == (0,)
995998

996999
elements = array.owner.inputs
997-
dtype = node.op.acc_dtype
998-
element_sum = add(*[cast(value, dtype) for value in elements])
1000+
acc_dtype = node.op.acc_dtype
1001+
out_dtype = node.op.dtype
1002+
element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)
9991003

1000-
return [as_tensor_variable(element_sum)]
1004+
return [element_sum]
10011005

10021006

10031007
@register_useless("local_remove_switch_const_cond")

tests/tensor/rewriting/test_basic.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
from pytensor.compile.mode import get_default_mode, get_mode
1313
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
1414
from pytensor.configdefaults import config
15-
from pytensor.graph.basic import equal_computations
15+
from pytensor.graph.basic import equal_computations, vars_between
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1919
from pytensor.graph.rewriting.utils import rewrite_graph
2020
from pytensor.printing import debugprint, pprint
2121
from pytensor.raise_op import Assert, CheckAndRaise
22-
from pytensor.scalar.basic import Add
2322
from pytensor.tensor.basic import (
2423
Alloc,
2524
Join,
@@ -32,6 +31,7 @@
3231
)
3332
from pytensor.tensor.elemwise import DimShuffle, Elemwise
3433
from pytensor.tensor.math import (
34+
Sum,
3535
add,
3636
bitwise_and,
3737
bitwise_or,
@@ -103,7 +103,6 @@
103103
values_eq_approx_remove_nan,
104104
vector,
105105
)
106-
from pytensor.tensor.var import TensorVariable
107106
from tests import unittest_tools as utt
108107

109108

@@ -1307,13 +1306,20 @@ def test_local_sum_make_vector():
13071306
mv = MakeVector(config.floatX)
13081307
output = mv(a, b, c).sum()
13091308

1310-
func = function([a, b, c], output)
1309+
output = rewrite_graph(output)
1310+
between = vars_between([a, b, c], [output])
1311+
for var in between:
1312+
assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
13111313

1312-
elemwise = func.maker.fgraph.outputs[0].owner
1313-
# The MakeVector op should be optimized away, so we just
1314-
# take the sum of the scalars.
1315-
assert elemwise.inputs[0].name == "a"
1316-
assert isinstance(elemwise.inputs[0], TensorVariable)
1314+
# Check for empty sum
1315+
a, b, c = scalars("abc")
1316+
mv = MakeVector(config.floatX)
1317+
output = mv(a, b, c).sum(axis=[])
1318+
1319+
output = rewrite_graph(output)
1320+
between = vars_between([a, b, c], [output])
1321+
for var in between:
1322+
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
13171323

13181324

13191325
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)