Skip to content

Commit 6e2efaa

Browse files
committed
Modified the code structure based on suggestions
1 parent d889182 commit 6e2efaa

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

pytensor/graph/basic.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -597,27 +597,22 @@ def eval(self, inputs_to_values=None):
597597
if inputs_to_values is None:
598598
inputs_to_values = {}
599599

600-
def convert_string_keys_to_variables():
601-
process_input_to_values = {}
602-
for i in inputs_to_values:
603-
if isinstance(i, str):
604-
nodes_with_matching_names = get_var_by_name([self], i)
605-
length_of_nodes_with_matching_names = len(nodes_with_matching_names)
606-
if length_of_nodes_with_matching_names == 0:
607-
raise Exception(f"{i} not found in graph")
608-
else:
609-
if length_of_nodes_with_matching_names > 1:
600+
def convert_string_keys_to_variables(input_to_values):
601+
new_input_to_values = {}
602+
for key, value in inputs_to_values.items():
603+
if isinstance(key, str):
604+
matching_vars = get_var_by_name([self], key)
605+
if not matching_vars:
606+
raise Exception(f"{key} not found in graph")
607+
elif len(matching_vars) > 1:
610608
raise Exception(
611-
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}"
609+
f"Found multiple variables with name {key}"
612610
)
613-
process_input_to_values[
614-
nodes_with_matching_names[0]
615-
] = inputs_to_values[i]
611+
new_input_to_values[matching_vars[0]] = value
616612
else:
617-
process_input_to_values[i] = inputs_to_values[i]
618-
return process_input_to_values
619-
620-
inputs_to_values = convert_string_keys_to_variables()
613+
new_input_to_values[key] = value
614+
return new_input_to_values
615+
inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
621616

622617
if not hasattr(self, "_fn_cache"):
623618
self._fn_cache = dict()

tests/graph/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def test_eval_errors_having_mulitple_variables_same_name(self):
310310
e = scalars("e")
311311
t = e + 1
312312
t.name = "e"
313-
with pytest.raises(Exception, match="Found 2 pytensor variables with name e"):
313+
with pytest.raises(Exception, match="Found multiple variables with name e"):
314314
t.eval({"e": 1})
315315

316316
def test_eval_errors_with_no_name_exists(self):

0 commit comments

Comments
 (0)