Skip to content

Commit 707ce87

Browse files
committed
Replace str "output" by a dummy Op in the clients of the FunctionGraph
1 parent 2143d85 commit 707ce87

File tree

18 files changed

+173
-180
lines changed

18 files changed

+173
-180
lines changed

pytensor/compile/debugmode.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.basic import Variable, io_toposort
3131
from pytensor.graph.destroyhandler import DestroyHandler
3232
from pytensor.graph.features import AlreadyThere, BadOptimization
33+
from pytensor.graph.fg import Output
3334
from pytensor.graph.op import HasInnerGraph, Op
3435
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
3536
from pytensor.link.basic import Container, LocalLinker
@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
628629
True if `var` is used by another node in the graph.
629630
630631
"""
631-
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
632+
return any(
633+
client for client, _ in fgraph.clients[var] if not isinstance(client.op, Output)
634+
)
632635

633636

634637
def _check_strides_match(a, b, warn_err, op):
@@ -978,7 +981,7 @@ def _check_preallocated_output(
978981
# disable memory checks in that mode, since they were already run.
979982
try:
980983
changed_inner_mode = False
981-
if isinstance(getattr(node, "op", None), HasInnerGraph):
984+
if isinstance(node.op, HasInnerGraph):
982985
fn = node.op.fn
983986
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
984987
_logger.warning(f"Expected pytensor function not found in {node.op}.fn")
@@ -1133,18 +1136,14 @@ class _FunctionGraphEvent:
11331136

11341137
def __init__(self, kind, node, idx=None, reason=None):
11351138
self.kind = kind
1136-
if node == "output":
1137-
self.node = "output"
1138-
self.op = "output"
1139-
else:
1140-
self.node = node
1141-
self.op = node.op
1139+
self.node = node
1140+
self.op = node.op
11421141
self.idx = idx
11431142
self.reason = str(reason)
11441143

11451144
def __str__(self):
11461145
if self.kind == "change":
1147-
if self.op != "output":
1146+
if not isinstance(self.op, Output):
11481147
msg = str(len(self.node.inputs))
11491148
else:
11501149
msg = ""

pytensor/compile/function/types.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ def view_tree_set(fgraph, v, treeset):
7777
"""
7878
treeset.add(v)
7979
for cl, v_input_pos_to_cl in fgraph.clients[v]:
80-
if cl == "output":
81-
continue
8280
vmap = cl.op.view_map
8381
dmap = cl.op.destroy_map
8482
for opos, iposlist in chain(vmap.items(), dmap.items()):
@@ -1199,8 +1197,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
11991197
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
12001198

12011199
for i in range(len(fgraph.outputs)):
1200+
original_out = fgraph.outputs[i]
1201+
output_client = fgraph.get_output_client(i)
1202+
12021203
views_of_output_i = set()
1203-
view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i)
1204+
view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
12041205
copied = False
12051206
# do not allow outputs to be aliased
12061207
for j in range(i + 1, len(fgraph.outputs)):
@@ -1209,16 +1210,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12091210
if fgraph.outputs[j] in views_of_output_i:
12101211
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
12111212
fgraph.change_node_input(
1212-
"output", i, view_op(fgraph.outputs[i]), reason=reason
1213+
output_client, 0, view_op(original_out), reason=reason
12131214
)
12141215
else:
12151216
fgraph.change_node_input(
1216-
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
1217+
output_client, 0, deep_copy_op(original_out), reason=reason
12171218
)
12181219
copied = True
12191220
break
12201221

1221-
if not copied:
1222+
if not copied: # no-break
12221223
for input_j in all_graph_inputs:
12231224
# do not allow outputs to be aliased to an inputs (j), unless
12241225
# a) that j'th input has been 'destroyed' by
@@ -1236,33 +1237,33 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
12361237
j = fgraph.inputs.index(input_j)
12371238
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
12381239
fgraph.change_node_input(
1239-
"output",
1240-
i,
1241-
view_op(fgraph.outputs[i]),
1240+
output_client,
1241+
0,
1242+
view_op(original_out),
12421243
reason=reason,
12431244
)
12441245
break
12451246
else:
12461247
fgraph.change_node_input(
1247-
"output",
1248-
i,
1249-
deep_copy_op(fgraph.outputs[i]),
1248+
output_client,
1249+
0,
1250+
deep_copy_op(original_out),
12501251
reason=reason,
12511252
)
12521253
break
12531254
elif wrapped_outputs[i].borrow:
12541255
fgraph.change_node_input(
1255-
"output",
1256-
i,
1257-
view_op(fgraph.outputs[i]),
1256+
output_client,
1257+
0,
1258+
view_op(original_out),
12581259
reason=reason,
12591260
)
12601261
break
12611262
else:
12621263
fgraph.change_node_input(
1263-
"output",
1264-
i,
1265-
deep_copy_op(fgraph.outputs[i]),
1264+
output_client,
1265+
0,
1266+
deep_copy_op(original_out),
12661267
reason=reason,
12671268
)
12681269
break

pytensor/compile/profiling.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,17 @@
1616
import time
1717
from collections import defaultdict
1818
from contextlib import contextmanager
19-
from typing import TYPE_CHECKING, Any, Union
19+
from typing import Any, Union
2020

2121
import numpy as np
2222

2323
import pytensor
2424
from pytensor.configdefaults import config
2525
from pytensor.graph.basic import Apply, Constant, Variable
26+
from pytensor.graph.fg import FunctionGraph, Output
2627
from pytensor.link.utils import get_destroy_dependencies
2728

2829

29-
if TYPE_CHECKING:
30-
from pytensor.graph.fg import FunctionGraph
31-
32-
3330
@contextmanager
3431
def extended_open(filename, mode="r"):
3532
if filename == "<stdout>":
@@ -1055,7 +1052,7 @@ def count_minimum_peak(node_list, fgraph, nodes_mem):
10551052
executable_nodes = set()
10561053
for var in fgraph.inputs:
10571054
for c, _ in fgraph.clients[var]:
1058-
if c != "output":
1055+
if not isinstance(c.op, Output):
10591056
deps = c.inputs + destroy_dependencies[c]
10601057
if all(compute_map[v][0] for v in deps):
10611058
executable_nodes.add(c)
@@ -1183,7 +1180,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
11831180

11841181
for var in node.outputs:
11851182
for c, _ in fgraph.clients[var]:
1186-
if c != "output":
1183+
if not isinstance(c.op, Output):
11871184
deps = c.inputs + destroy_dependencies[c]
11881185
if all(compute_map[v][0] for v in deps):
11891186
new_exec_nodes.add(c)

pytensor/graph/destroyhandler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Constant
1313
from pytensor.graph.features import AlreadyThere, Bookkeeper
14+
from pytensor.graph.fg import Output
1415
from pytensor.graph.utils import InconsistencyError
1516
from pytensor.misc.ordered_set import OrderedSet
1617

@@ -401,8 +402,6 @@ def has_destroyers(protected_list):
401402
def recursive_destroys_finder(protected_var):
402403
# protected_var is the idx'th input of app.
403404
for app, idx in fgraph.clients[protected_var]:
404-
if app == "output":
405-
continue
406405
destroy_maps = app.op.destroy_map.values()
407406
# If True means that the apply node, destroys the protected_var.
408407
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
@@ -578,7 +577,7 @@ def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
578577
app.inputs[i] changed from old_r to new_r.
579578
580579
"""
581-
if app == "output":
580+
if isinstance(app.op, Output):
582581
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
583582
# considered 'outputs' of the graph.
584583
pass

0 commit comments

Comments
 (0)