Skip to content

Commit 3ab8992

Browse files
committed
Replace str "output" by a dummy Op in the clients of the FunctionGraph
1 parent 2143d85 commit 3ab8992

File tree

18 files changed

+160
-155
lines changed

18 files changed

+160
-155
lines changed

pytensor/compile/debugmode.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.basic import Variable, io_toposort
3131
from pytensor.graph.destroyhandler import DestroyHandler
3232
from pytensor.graph.features import AlreadyThere, BadOptimization
33+
from pytensor.graph.fg import Output
3334
from pytensor.graph.op import HasInnerGraph, Op
3435
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3536
from pytensor.link.basic import Container, LocalLinker
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
628629
True if `var` is used by another node in the graph.
629630
630631
"""
631-
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
632+
return any(
633+
client for client, _ in fgraph.clients[var] if not isinstance(client.op, Output)
634+
)
632635

633636

634637
def _check_strides_match(a, b, warn_err, op):
@@ -978,7 +981,7 @@ def _check_preallocated_output(
978981
# disable memory checks in that mode, since they were already run.
979982
try:
980983
changed_inner_mode = False
981-
if isinstance(getattr(node, "op", None), HasInnerGraph):
984+
if isinstance(node.op, HasInnerGraph):
982985
fn = node.op.fn
983986
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
984987
_logger.warning(f"Expected pytensor function not found in {node.op}.fn")
@@ -1133,18 +1136,14 @@ class _FunctionGraphEvent:
11331136

11341137
def __init__(self, kind, node, idx=None, reason=None):
11351138
self.kind = kind
1136-
if node == "output":
1137-
self.node = "output"
1138-
self.op = "output"
1139-
else:
1140-
self.node = node
1141-
self.op = node.op
1139+
self.node = node
1140+
self.op = node.op
11421141
self.idx = idx
11431142
self.reason = str(reason)
11441143

11451144
def __str__(self):
11461145
if self.kind == "change":
1147-
if self.op != "output":
1146+
if not isinstance(self.op, Output):
11481147
msg = str(len(self.node.inputs))
11491148
else:
11501149
msg = ""

pytensor/compile/function/types.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from pytensor.graph.destroyhandler import DestroyHandler
2626
from pytensor.graph.features import AlreadyThere, Feature, PreserveVariableAttributes
27-
from pytensor.graph.fg import FunctionGraph
27+
from pytensor.graph.fg import FunctionGraph, Output
2828
from pytensor.graph.op import HasInnerGraph
2929
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
3030
from pytensor.link.basic import Container
@@ -77,8 +77,6 @@ def view_tree_set(fgraph, v, treeset):
7777
"""
7878
treeset.add(v)
7979
for cl, v_input_pos_to_cl in fgraph.clients[v]:
80-
if cl == "output":
81-
continue
8280
vmap = cl.op.view_map
8381
dmap = cl.op.destroy_map
8482
for opos, iposlist in chain(vmap.items(), dmap.items()):
@@ -1199,8 +1197,15 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
11991197
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
12001198

12011199
for i in range(len(fgraph.outputs)):
1200+
original_out = fgraph.outputs[i]
1201+
[output_client] = [
1202+
cl
1203+
for cl, _ in fgraph.clients[original_out]
1204+
if isinstance(cl.op, Output) and cl.op.idx == i
1205+
]
1206+
12021207
views_of_output_i = set()
1203-
view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i)
1208+
view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
12041209
copied = False
12051210
# do not allow outputs to be aliased
12061211
for j in range(i + 1, len(fgraph.outputs)):
@@ -1209,16 +1214,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12091214
if fgraph.outputs[j] in views_of_output_i:
12101215
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
12111216
fgraph.change_node_input(
1212-
"output", i, view_op(fgraph.outputs[i]), reason=reason
1217+
output_client, 0, view_op(original_out), reason=reason
12131218
)
12141219
else:
12151220
fgraph.change_node_input(
1216-
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
1221+
output_client, 0, deep_copy_op(original_out), reason=reason
12171222
)
12181223
copied = True
12191224
break
12201225

1221-
if not copied:
1226+
if not copied: # no-break
12221227
for input_j in all_graph_inputs:
12231228
# do not allow outputs to be aliased to an inputs (j), unless
12241229
# a) that j'th input has been 'destroyed' by
@@ -1236,33 +1241,33 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12361241
j = fgraph.inputs.index(input_j)
12371242
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
12381243
fgraph.change_node_input(
1239-
"output",
1240-
i,
1241-
view_op(fgraph.outputs[i]),
1244+
output_client,
1245+
0,
1246+
view_op(original_out),
12421247
reason=reason,
12431248
)
12441249
break
12451250
else:
12461251
fgraph.change_node_input(
1247-
"output",
1248-
i,
1249-
deep_copy_op(fgraph.outputs[i]),
1252+
output_client,
1253+
0,
1254+
deep_copy_op(original_out),
12501255
reason=reason,
12511256
)
12521257
break
12531258
elif wrapped_outputs[i].borrow:
12541259
fgraph.change_node_input(
1255-
"output",
1256-
i,
1257-
view_op(fgraph.outputs[i]),
1260+
output_client,
1261+
0,
1262+
view_op(original_out),
12581263
reason=reason,
12591264
)
12601265
break
12611266
else:
12621267
fgraph.change_node_input(
1263-
"output",
1264-
i,
1265-
deep_copy_op(fgraph.outputs[i]),
1268+
output_client,
1269+
0,
1270+
deep_copy_op(original_out),
12661271
reason=reason,
12671272
)
12681273
break

pytensor/compile/profiling.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,17 @@
1616
import time
1717
from collections import defaultdict
1818
from contextlib import contextmanager
19-
from typing import TYPE_CHECKING, Any, Union
19+
from typing import Any, Union
2020

2121
import numpy as np
2222

2323
import pytensor
2424
from pytensor.configdefaults import config
2525
from pytensor.graph.basic import Apply, Constant, Variable
26+
from pytensor.graph.fg import FunctionGraph, Output
2627
from pytensor.link.utils import get_destroy_dependencies
2728

2829

29-
if TYPE_CHECKING:
30-
from pytensor.graph.fg import FunctionGraph
31-
32-
3330
@contextmanager
3431
def extended_open(filename, mode="r"):
3532
if filename == "<stdout>":
@@ -1055,7 +1052,7 @@ def count_minimum_peak(node_list, fgraph, nodes_mem):
10551052
executable_nodes = set()
10561053
for var in fgraph.inputs:
10571054
for c, _ in fgraph.clients[var]:
1058-
if c != "output":
1055+
if not isinstance(c.op, Output):
10591056
deps = c.inputs + destroy_dependencies[c]
10601057
if all(compute_map[v][0] for v in deps):
10611058
executable_nodes.add(c)
@@ -1183,7 +1180,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
11831180

11841181
for var in node.outputs:
11851182
for c, _ in fgraph.clients[var]:
1186-
if c != "output":
1183+
if not isinstance(c.op, Output):
11871184
deps = c.inputs + destroy_dependencies[c]
11881185
if all(compute_map[v][0] for v in deps):
11891186
new_exec_nodes.add(c)

pytensor/graph/destroyhandler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Constant
1313
from pytensor.graph.features import AlreadyThere, Bookkeeper
14+
from pytensor.graph.fg import Output
1415
from pytensor.graph.utils import InconsistencyError
1516
from pytensor.misc.ordered_set import OrderedSet
1617

@@ -401,8 +402,6 @@ def has_destroyers(protected_list):
401402
def recursive_destroys_finder(protected_var):
402403
# protected_var is the idx'th input of app.
403404
for app, idx in fgraph.clients[protected_var]:
404-
if app == "output":
405-
continue
406405
destroy_maps = app.op.destroy_map.values()
407406
# If True means that the apply node, destroys the protected_var.
408407
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
@@ -575,10 +574,10 @@ def on_prune(self, fgraph, app, reason):
575574

576575
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
577576
"""
578-
app.inputs[i] changed from old_r to new_r.
577+
node.inputs[i] changed from old_r to new_r.
579578
580579
"""
581-
if app == "output":
580+
if isinstance(app.op, Output):
582581
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
583582
# considered 'outputs' of the graph.
584583
pass

0 commit comments

Comments
 (0)