Skip to content

Commit a9204fb

Browse files
committed
Replace str "output" by a dummy Op in the clients of the FunctionGraph
1 parent 2143d85 commit a9204fb

File tree

18 files changed

+162
-151
lines changed

18 files changed

+162
-151
lines changed

pytensor/compile/debugmode.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.basic import Variable, io_toposort
3131
from pytensor.graph.destroyhandler import DestroyHandler
3232
from pytensor.graph.features import AlreadyThere, BadOptimization
33+
from pytensor.graph.fg import Output
3334
from pytensor.graph.op import HasInnerGraph, Op
3435
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3536
from pytensor.link.basic import Container, LocalLinker
@@ -628,7 +629,11 @@ def _is_used_in_graph(fgraph, var):
628629
True if `var` is used by another node in the graph.
629630
630631
"""
631-
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
632+
return any(
633+
client
634+
for client, _ in fgraph.clients[var]
635+
if not isinstance(client.owner.op, Output)
636+
)
632637

633638

634639
def _check_strides_match(a, b, warn_err, op):
@@ -978,7 +983,7 @@ def _check_preallocated_output(
978983
# disable memory checks in that mode, since they were already run.
979984
try:
980985
changed_inner_mode = False
981-
if isinstance(getattr(node, "op", None), HasInnerGraph):
986+
if isinstance(node.op, HasInnerGraph):
982987
fn = node.op.fn
983988
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
984989
_logger.warning(f"Expected pytensor function not found in {node.op}.fn")
@@ -1133,18 +1138,14 @@ class _FunctionGraphEvent:
11331138

11341139
def __init__(self, kind, node, idx=None, reason=None):
11351140
self.kind = kind
1136-
if node == "output":
1137-
self.node = "output"
1138-
self.op = "output"
1139-
else:
1140-
self.node = node
1141-
self.op = node.op
1141+
self.node = node
1142+
self.op = node.op
11421143
self.idx = idx
11431144
self.reason = str(reason)
11441145

11451146
def __str__(self):
11461147
if self.kind == "change":
1147-
if self.op != "output":
1148+
if not isinstance(self.op, Output):
11481149
msg = str(len(self.node.inputs))
11491150
else:
11501151
msg = ""

pytensor/compile/function/types.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from pytensor.graph.destroyhandler import DestroyHandler
2626
from pytensor.graph.features import AlreadyThere, Feature, PreserveVariableAttributes
27-
from pytensor.graph.fg import FunctionGraph
27+
from pytensor.graph.fg import FunctionGraph, Output
2828
from pytensor.graph.op import HasInnerGraph
2929
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
3030
from pytensor.link.basic import Container
@@ -77,8 +77,6 @@ def view_tree_set(fgraph, v, treeset):
7777
"""
7878
treeset.add(v)
7979
for cl, v_input_pos_to_cl in fgraph.clients[v]:
80-
if cl == "output":
81-
continue
8280
vmap = cl.op.view_map
8381
dmap = cl.op.destroy_map
8482
for opos, iposlist in chain(vmap.items(), dmap.items()):
@@ -1199,8 +1197,15 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
11991197
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
12001198

12011199
for i in range(len(fgraph.outputs)):
1200+
original_out = fgraph.outputs[i]
1201+
[output_client] = [
1202+
cl
1203+
for cl, _ in fgraph.clients[original_out]
1204+
if isinstance(cl.op, Output) and cl.op.idx == i
1205+
]
1206+
12021207
views_of_output_i = set()
1203-
view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i)
1208+
view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
12041209
copied = False
12051210
# do not allow outputs to be aliased
12061211
for j in range(i + 1, len(fgraph.outputs)):
@@ -1209,16 +1214,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12091214
if fgraph.outputs[j] in views_of_output_i:
12101215
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
12111216
fgraph.change_node_input(
1212-
"output", i, view_op(fgraph.outputs[i]), reason=reason
1217+
output_client, 0, view_op(original_out), reason=reason
12131218
)
12141219
else:
12151220
fgraph.change_node_input(
1216-
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
1221+
output_client, 0, deep_copy_op(original_out), reason=reason
12171222
)
12181223
copied = True
12191224
break
12201225

1221-
if not copied:
1226+
if not copied: # no-break
12221227
for input_j in all_graph_inputs:
12231228
# do not allow outputs to be aliased to an inputs (j), unless
12241229
# a) that j'th input has been 'destroyed' by
@@ -1236,33 +1241,33 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12361241
j = fgraph.inputs.index(input_j)
12371242
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
12381243
fgraph.change_node_input(
1239-
"output",
1240-
i,
1241-
view_op(fgraph.outputs[i]),
1244+
output_client,
1245+
0,
1246+
view_op(original_out),
12421247
reason=reason,
12431248
)
12441249
break
12451250
else:
12461251
fgraph.change_node_input(
1247-
"output",
1248-
i,
1249-
deep_copy_op(fgraph.outputs[i]),
1252+
output_client,
1253+
0,
1254+
deep_copy_op(original_out),
12501255
reason=reason,
12511256
)
12521257
break
12531258
elif wrapped_outputs[i].borrow:
12541259
fgraph.change_node_input(
1255-
"output",
1256-
i,
1257-
view_op(fgraph.outputs[i]),
1260+
output_client,
1261+
0,
1262+
view_op(original_out),
12581263
reason=reason,
12591264
)
12601265
break
12611266
else:
12621267
fgraph.change_node_input(
1263-
"output",
1264-
i,
1265-
deep_copy_op(fgraph.outputs[i]),
1268+
output_client,
1269+
0,
1270+
deep_copy_op(original_out),
12661271
reason=reason,
12671272
)
12681273
break

pytensor/compile/profiling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
if TYPE_CHECKING:
30-
from pytensor.graph.fg import FunctionGraph
30+
from pytensor.graph.fg import FunctionGraph, Output
3131

3232

3333
@contextmanager
@@ -1055,7 +1055,7 @@ def count_minimum_peak(node_list, fgraph, nodes_mem):
10551055
executable_nodes = set()
10561056
for var in fgraph.inputs:
10571057
for c, _ in fgraph.clients[var]:
1058-
if c != "output":
1058+
if not isinstance(c.op, Output):
10591059
deps = c.inputs + destroy_dependencies[c]
10601060
if all(compute_map[v][0] for v in deps):
10611061
executable_nodes.add(c)
@@ -1183,7 +1183,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
11831183

11841184
for var in node.outputs:
11851185
for c, _ in fgraph.clients[var]:
1186-
if c != "output":
1186+
if not isinstance(c.op, Output):
11871187
deps = c.inputs + destroy_dependencies[c]
11881188
if all(compute_map[v][0] for v in deps):
11891189
new_exec_nodes.add(c)

pytensor/graph/destroyhandler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Constant
1313
from pytensor.graph.features import AlreadyThere, Bookkeeper
14+
from pytensor.graph.fg import Output
1415
from pytensor.graph.utils import InconsistencyError
1516
from pytensor.misc.ordered_set import OrderedSet
1617

@@ -401,8 +402,6 @@ def has_destroyers(protected_list):
401402
def recursive_destroys_finder(protected_var):
402403
# protected_var is the idx'th input of app.
403404
for app, idx in fgraph.clients[protected_var]:
404-
if app == "output":
405-
continue
406405
destroy_maps = app.op.destroy_map.values()
407406
# If True means that the apply node, destroys the protected_var.
408407
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
@@ -575,10 +574,10 @@ def on_prune(self, fgraph, app, reason):
575574

576575
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
577576
"""
578-
app.inputs[i] changed from old_r to new_r.
577+
node.inputs[i] changed from old_r to new_r.
579578
580579
"""
581-
if app == "output":
580+
if isinstance(app.op, Output):
582581
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
583582
# considered 'outputs' of the graph.
584583
pass

0 commit comments

Comments
 (0)