Skip to content

Commit

Permalink
Merge pull request #908 from calad0i/quartus_multi_out_with_stream_fix
Browse files Browse the repository at this point in the history
Quartus multi out with stream fix
  • Loading branch information
jmitrevs authored Dec 15, 2023
2 parents 9278520 + 4660232 commit 9be5cba
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 19 deletions.
4 changes: 2 additions & 2 deletions hls4ml/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def make_node(self, kind, name, attributes, inputs, outputs=None):
node = layer_cls(self, name, attributes, inputs, outputs)
for o in node.outputs:
out_var = node.get_output_variable(output_name=o)
if o in self.outputs:
if len(self.outputs) == 1 and o in self.outputs:
out_var.type.name = 'result_t'
self.output_vars[o] = out_var
return node
Expand Down Expand Up @@ -608,7 +608,7 @@ def get_input_variables(self):
return variables

def register_output_variable(self, out_name, variable):
if out_name in self.outputs:
if len(self.outputs) == 1 and out_name in self.outputs:
variable.type.name = 'result_t'
self.output_vars[out_name] = variable

Expand Down
14 changes: 9 additions & 5 deletions hls4ml/writer/quartus_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def write_project_cpp(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'
if model_brams:
newline += ',\n' + brams_str
newline += '\n) {\n'
Expand All @@ -191,7 +192,8 @@ def write_project_cpp(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'\
if model_brams:
newline += ',\n' + brams_str
newline += '\n) {\n'
Expand Down Expand Up @@ -277,7 +279,7 @@ def write_project_cpp(self, model):
newline += indent + f' {out.type.name} tmp = {out.name}.read();\n'
newline += indent + f' {out.name}_stream.write(tmp);\n'
newline += indent + '}\n'
newline += '}\n'
newline += '}\n'
else:
newline = line
newline += indent + 'return outputs;\n'
Expand Down Expand Up @@ -330,7 +332,8 @@ def write_project_header(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'
if model_brams:
newline += ',\n' + brams_str
newline += '\n);\n'
Expand All @@ -350,7 +353,8 @@ def write_project_header(self, model):
for inp in model_inputs:
newline += indent + f'stream_in<{inp.type.name}> &{inp.name}_stream,\n'
for out in model_outputs:
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream'
newline += indent + f'stream_out<{out.type.name}> &{out.name}_stream,\n'
newline = newline[:-2] # Remove the tailing ',\n'
if model_brams:
newline += ',\n' + brams_str
newline += '\n);\n'
Expand Down
52 changes: 52 additions & 0 deletions test/pytest/test_multiout_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pathlib import Path

import numpy as np
import pytest
from keras.layers import Dense
from tensorflow import keras

from hls4ml.converters import convert_from_keras_model

test_root_path = Path(__file__).parent


@pytest.fixture(scope='module')
def model():
inp = keras.Input(shape=(10,))
x = Dense(10, name='dense1')(inp)
y = Dense(10, name='dense2')(inp)
model = keras.Model(inp, [x, y])
return model


@pytest.fixture(scope='module')
def data():
X = np.random.normal(0, 1, (1000, 10))
X = np.clip(X, -16, 15)
return X


@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_multi_clone(model, data, backend: str, io_type: str):
output_dir = str(test_root_path / f'hls4mlprj_multiout_network_{backend}_{io_type}')
hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}}
layer_config = {
'dense1': {'Precision': {'result': 'fixed<35,5>'}},
'dense2': {'Precision': {'result': 'fixed<40,5>'}},
'dense1_linear': {'Precision': {'result': 'fixed<35,5>'}},
'dense2_linear': {'Precision': {'result': 'fixed<40,5>'}},
}
hls_config['LayerName'] = layer_config
model_hls = convert_from_keras_model(
model, backend=backend, output_dir=output_dir, hls_config=hls_config, io_type=io_type
)

assert model_hls.graph['dense1'].attributes['result_t'] != model_hls.graph['dense2'].attributes['result_t']

model_hls.compile()
r_hls = model_hls.predict(data)
r_keras = [x.numpy() for x in model(data)]

assert np.allclose(r_hls[0], r_keras[0], atol=1e-5, rtol=0)
assert np.allclose(r_hls[1], r_keras[1], atol=1e-5, rtol=0)
13 changes: 1 addition & 12 deletions test/pytest/test_stream_multi_clone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
import random
from pathlib import Path

import numpy as np
import pytest
import tensorflow as tf
from keras.layers import Add, Dense
from tensorflow import keras

Expand All @@ -15,13 +12,6 @@

@pytest.fixture(scope='module')
def model():
seed = 42
os.environ['RANDOM_SEED'] = f'{seed}'
np.random.seed(seed)
tf.random.set_seed(seed)
tf.get_logger().setLevel('ERROR')
random.seed(seed)

inp = keras.Input(shape=(10,))
x = Dense(10)(inp)
y = Dense(10)(inp)
Expand All @@ -35,8 +25,7 @@ def model():

@pytest.fixture(scope='module')
def data():
rng = np.random.RandomState(42)
X = rng.normal(0, 1, (1000, 10))
X = np.random.normal(0, 1, (1000, 10))
X = np.clip(X, -16, 15)
return X

Expand Down

0 comments on commit 9be5cba

Please sign in to comment.