@@ -597,27 +597,22 @@ def eval(self, inputs_to_values=None):
597
597
if inputs_to_values is None :
598
598
inputs_to_values = {}
599
599
600
- def convert_string_keys_to_variables ():
601
- process_input_to_values = {}
602
- for i in inputs_to_values :
603
- if isinstance (i , str ):
604
- nodes_with_matching_names = get_var_by_name ([self ], i )
605
- length_of_nodes_with_matching_names = len (nodes_with_matching_names )
606
- if length_of_nodes_with_matching_names == 0 :
607
- raise Exception (f"{ i } not found in graph" )
608
- else :
609
- if length_of_nodes_with_matching_names > 1 :
600
+ def convert_string_keys_to_variables (input_to_values ):
601
+ new_input_to_values = {}
602
+ for key , value in inputs_to_values .items ():
603
+ if isinstance (key , str ):
604
+ matching_vars = get_var_by_name ([self ], key )
605
+ if not matching_vars :
606
+ raise Exception (f"{ key } not found in graph" )
607
+ elif len (matching_vars ) > 1 :
610
608
raise Exception (
611
- f"Found { length_of_nodes_with_matching_names } pytensor variables with name { i } "
609
+ f"Found multiple variables with name { key } "
612
610
)
613
- process_input_to_values [
614
- nodes_with_matching_names [0 ]
615
- ] = inputs_to_values [i ]
611
+ new_input_to_values [matching_vars [0 ]] = value
616
612
else :
617
- process_input_to_values [i ] = inputs_to_values [i ]
618
- return process_input_to_values
619
-
620
- inputs_to_values = convert_string_keys_to_variables ()
613
+ new_input_to_values [key ] = value
614
+ return new_input_to_values
615
+ inputs_to_values = convert_string_keys_to_variables (inputs_to_values )
621
616
622
617
if not hasattr (self , "_fn_cache" ):
623
618
self ._fn_cache = dict ()
0 commit comments