Skip to content

Commit

Permalink
Update graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte authored Dec 14, 2023
1 parent 9f1b0f1 commit c80c8f2
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions hls4ml/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(self, config, layer_list, inputs=None, outputs=None):
input_layers = inputs if inputs is not None else [layer_list[0]['name']]
output_layers = outputs if outputs is not None else [layer_list[-1]['name']]
self.inputs = self._find_output_variable_names(layer_list, input_layers)
if sorted(self.inputs) != sorted(input_layers):
if self.inputs != input_layers:
raise RuntimeError(
"Currently only support the case when input variables and input layer names match\n"
+ f"Input layers = {input_layers}, input_vars = {self.inputs}"
Expand All @@ -362,9 +362,12 @@ def __init__(self, config, layer_list, inputs=None, outputs=None):
self.apply_flow(flow)

def _find_output_variable_names(self, layer_list, layer_names):
"""Given a list of all layers, and a list input/output names, find the names of the their outputs that will be used
"""Given a list of all layers, and a list input/output names, find the names of their outputs that will be used
as the name of the output variables."""
inout_nodes = [node for node in layer_list if node['name'] in layer_names]
inout_nodes = []
for layer_name in layer_names:
for node in layer_list:
if node['name'] == layer_name: inout_nodes.append(node)
all_node_output_names = [node['outputs'] if 'outputs' in node else [node['name']] for node in inout_nodes]
return [output for node_output_names in all_node_output_names for output in node_output_names] # to flatten

Expand Down

0 comments on commit c80c8f2

Please sign in to comment.