
Leaky Quantized ReLU fix #961

Closed. Wants to merge 34 commits.

Commits (34)
4e483fc  First attempt in making a layer generator for hls4ml to make adding a… (DrWatt, Mar 29, 2023)
0e973e6  Added longnamelayer to associate hls layer to TF layer and added type… (DrWatt, Apr 3, 2023)
ebbac33  Merge branch 'fastmachinelearning:main' into master (DrWatt, Apr 3, 2023)
e1e9157  Merge branch 'fastmachinelearning:main' into master (DrWatt, Jan 18, 2024)
49a30fe  Fixed a missing casting in vivado_accelerator and added a skip for li… (DrWatt, Jan 18, 2024)
e73ddc9  Fixed missing _ap_ in typestring sometimes not supported by the compiler (DrWatt, Jan 18, 2024)
6db982a  Fixed issue with quantized_relu activation whose implementation did n… (DrWatt, Jan 24, 2024)
6ea8ab0  Removed a couple of prints used in debugging (DrWatt, Jan 24, 2024)
a13c978  Merge branch 'fastmachinelearning:main' into master (DrWatt, Jan 24, 2024)
6640169  testing test (DrWatt, Jan 25, 2024)
6028bb6  pre-commit ran manually on all files (DrWatt, Jan 25, 2024)
6c72dd9  Fixed issues arised due to an old version of hls4ml used when writing… (DrWatt, Jan 25, 2024)
61ae795  Fixed hls header name (DrWatt, Jan 25, 2024)
7b4047d  first hls compilation errors solved in new layer template (DrWatt, Jan 26, 2024)
235aa5a  Readme added for the New Layer tool (DrWatt, Jan 26, 2024)
99afa33  Added notebook for testing quickly a new layer (DrWatt, Jan 26, 2024)
7dd760d  Changed name of directory (DrWatt, Jan 26, 2024)
11371c5  Old directory deleted (DrWatt, Jan 26, 2024)
eed800d  Missing outputshape description in Readme (DrWatt, Jan 26, 2024)
fde9841  Removing stuff that do not concern the bug fix (DrWatt, Jan 29, 2024)
372d35d  Uncommented line in profiling (DrWatt, Jan 29, 2024)
928f383  Test suite added for quantized leaky relu (DrWatt, Jan 29, 2024)
3b561b4  Changed output_dir of tests to follow the same configuration as the o… (DrWatt, Jan 29, 2024)
da65388  Merge branch 'fastmachinelearning:main' into qrelu_fix (DrWatt, Jan 30, 2024)
f1d187d  Removed 'ap_' string from FixedPrecisionType and fixed compiling err… (DrWatt, Feb 8, 2024)
700163b  Added if clause in the quantized_relu config which makes it signed or… (DrWatt, Feb 8, 2024)
3f48509  Renamed QLeakyReLU to LeakyReLU and subsequent changes (DrWatt, Feb 9, 2024)
d209f5b  Reversed changes regarding vivado_accelerator backend and writer (DrWatt, Feb 9, 2024)
46a6e9b  Merge branch 'fastmachinelearning:main' into qrelu_fix (DrWatt, Feb 9, 2024)
26f32a1  changed description string of qrelu test (DrWatt, Feb 9, 2024)
d4642e7  Name of qleaky relu test changed (DrWatt, Feb 9, 2024)
b1a4472  Name of qleaky relu test changed (DrWatt, Feb 9, 2024)
ecec240  Merge branch 'fastmachinelearning:main' into qrelu_fix (DrWatt, Feb 19, 2024)
34c8070  Merge branch 'main' into qrelu_fix (DrWatt, Mar 27, 2024)
@@ -75,9 +75,9 @@ def __init__(self, config, model_inputs, model_outputs):
         else:
             self.input_bitwidth = config.backend.convert_precision_string(inp_axi_t).width

-        if out_axi_t == 'float':
+        if str(out_axi_t) == 'float':
             self.output_bitwidth = 32
-        elif out_axi_t == 'double':
+        elif str(out_axi_t) == 'double':
             self.output_bitwidth = 64
         else:
             self.output_bitwidth = config.backend.convert_precision_string(out_axi_t).width
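For context, a minimal standalone sketch of the pattern this fix targets; FakePrecision and output_bitwidth are illustrative names, not hls4ml code. The cast matters because out_axi_t is presumably not always a plain string, and only its string form compares equal to 'float' or 'double'.

# Illustrative sketch, not part of this PR: why comparing str(out_axi_t) is more robust.
class FakePrecision:
    def __init__(self, type_string):
        self.type_string = type_string

    def __str__(self):
        return self.type_string


def output_bitwidth(out_axi_t):
    if str(out_axi_t) == 'float':
        return 32
    elif str(out_axi_t) == 'double':
        return 64
    # in the real code this branch parses the precision string to obtain the width
    return None


print(output_bitwidth('float'))                 # 32
print(output_bitwidth(FakePrecision('float')))  # also 32, thanks to the str() cast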
5 changes: 4 additions & 1 deletion hls4ml/converters/keras/qkeras.py
@@ -117,7 +117,6 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader)
]

layer = parse_default_keras_layer(keras_layer, input_names)

activation_config = keras_layer['config']['activation']
quantizer_obj = get_quantizer(activation_config)
activation_config = {}
@@ -166,6 +165,10 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader)
         layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
         layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
         layer['activation'] = activation_config['class_name'].replace('quantized_', 'hard_')
+    elif activation_config['class_name'] == 'quantized_relu' and activation_config['config']['negative_slope'] != 0:
+        layer['class_name'] = 'LeakyReLU'
+        layer['activation'] = activation_config['class_name'].replace('quantized_', 'leaky_')
+        layer['activ_param'] = activation_config['config']['negative_slope']
     else:
         layer['class_name'] = 'Activation'
         layer['activation'] = activation_config['class_name'].replace('quantized_', '')
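As a quick illustration of why this mapping is needed (requires QKeras and TensorFlow; not part of the PR itself): a quantized_relu with a nonzero negative_slope behaves like a quantized leaky ReLU, so emitting an hls4ml LeakyReLU layer with activ_param set to the slope preserves the negative branch.

# Illustrative check of the quantizer behaviour the new parser branch handles.
import numpy as np
from qkeras.quantizers import quantized_relu

q = quantized_relu(8, 4, negative_slope=0.25)
x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
print(np.asarray(q(x)))  # negative inputs are scaled by the slope, then quantized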
8 changes: 6 additions & 2 deletions hls4ml/converters/keras_to_hls.py
@@ -208,6 +208,7 @@ def parse_keras_model(model_arch, reader):
]
# Recurrent layers
recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU']

# All supported layers
supported_layers = get_supported_keras_layers() + skip_layers

@@ -263,7 +264,6 @@ def parse_keras_model(model_arch, reader):
input_shapes = [output_shape]

keras_class = keras_layer['class_name']

if keras_class in skip_layers:
if 'inbound_nodes' in keras_layer:
name = keras_layer['config']['name']
@@ -293,7 +293,11 @@ def parse_keras_model(model_arch, reader):
                 )
             )
         layer_list.append(layer)
-        if 'activation' in layer and layer['class_name'] not in activation_layers + recurrent_layers: # + qkeras_layers:
+        if (
+            'activation' in layer
+            and layer['class_name'] not in activation_layers + recurrent_layers
+            and layer['activation'] != "linear"
+        ): # + qkeras_layers:
             act_layer = {}
             # Workaround for QKeras activations passed as an argument
             if isinstance(layer['activation'], dict):
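A hedged sketch of the new condition in isolation; the layer dict and the two lists below are simplified stand-ins for the parser's state, not the actual hls4ml structures.

# Sketch: a layer declared with activation='linear' no longer spawns a separate Activation layer.
activation_layers = ['Activation', 'LeakyReLU', 'ThresholdedReLU', 'ELU', 'PReLU', 'Softmax', 'ReLU']  # abbreviated
recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU']

layer = {'class_name': 'Dense', 'name': 'fc1', 'activation': 'linear'}

needs_extra_activation = (
    'activation' in layer
    and layer['class_name'] not in activation_layers + recurrent_layers
    and layer['activation'] != 'linear'
)
print(needs_extra_activation)  # False: the redundant 'linear' Activation layer is skipped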
3 changes: 2 additions & 1 deletion hls4ml/model/profiling.py
@@ -589,8 +589,9 @@ def get_ymodel_keras(keras_model, X):
         name = layer.name
         if (
             hasattr(layer, "activation")
-            and layer.activation.__name__ != "linear"
+            and layer.activation is not None
[Review comment, Contributor]
I wonder, in what cases can this actually be None? One has to explicitly define the model like that, against all Keras tutorials.

[Reply, Author]
I have found out that if the QBatchNormalization layer is used without specifying an activation function, the default layer.activation is None, which breaks the if statement when the linear check is performed. Now that I am thinking about this again, I am not sure whether another case should be added here to force a linear activation.

             and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
+            and layer.activation.__name__ != "linear"
         ):
             tmp_activation = layer.activation
             layer.activation = None
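The failure mode discussed in the review thread above can be reproduced in a few lines; FakeLayer is a hypothetical stand-in for a QKeras QBatchNormalization created without an activation.

# Hedged sketch: .activation can be None, and None has no __name__, so the None
# check has to run before the 'linear' comparison (short-circuit evaluation).
class FakeLayer:
    activation = None


layer = FakeLayer()
ok = (
    hasattr(layer, "activation")
    and layer.activation is not None
    and layer.activation.__name__ != "linear"
)
print(ok)  # False, and no AttributeError is raised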
7 changes: 5 additions & 2 deletions hls4ml/utils/config.py
@@ -79,8 +79,11 @@ def _get_precision_from_quantizer(quantizer):
rnd = "AP_RND_CONV"
overflow = "AP_SAT"
if quantizer['class_name'] in ('quantized_relu', 'quantized_relu_po2'):
signed = False
integer -= 1
if quantizer['config']['negative_slope'] != 0.0:
signed = True
else:
signed = False
integer -= 1
elif quantizer['class_name'] == 'quantized_tanh':
overflow = "AP_SAT_SYM" if quantizer['config']['symmetric'] else "AP_SAT"
integer = 1
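The reasoning behind the branch above, as a hedged standalone sketch; relu_precision and the returned dict are illustrative, not the actual _get_precision_from_quantizer interface.

# Plain quantized_relu outputs are non-negative, so an unsigned type with one less
# integer bit suffices; with a nonzero negative_slope the output can be negative,
# so a signed type is needed and the integer width is left untouched.
def relu_precision(bits, integer, negative_slope):
    if negative_slope != 0.0:
        return dict(width=bits, integer=integer, signed=True)
    return dict(width=bits, integer=integer - 1, signed=False)


print(relu_precision(8, 4, 0.0))   # {'width': 8, 'integer': 3, 'signed': False}
print(relu_precision(8, 4, 0.25))  # {'width': 8, 'integer': 4, 'signed': True}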
62 changes: 62 additions & 0 deletions test/pytest/test_q_leaky_relu.py
@@ -0,0 +1,62 @@
from pathlib import Path

import numpy as np
import pytest
from qkeras.qlayers import QActivation
from qkeras.quantizers import quantized_relu
from tensorflow.keras.models import Sequential

import hls4ml

test_root_path = Path(__file__).parent


def randX(batch_size, N):
return np.random.rand(batch_size, N)


@pytest.fixture(scope='module')
def randX_1000_1():
return randX(1000, 1)


@pytest.mark.parametrize(
'quantizer',
[
(quantized_relu(4, negative_slope=0.5)),
(quantized_relu(4, 2, negative_slope=0.5)),
(quantized_relu(8, negative_slope=0.125)),
(quantized_relu(8, 4, negative_slope=1.0)),
(quantized_relu(10, negative_slope=0.25)),
(quantized_relu(10, 5, negative_slope=0.5)),
(quantized_relu(10, 5, negative_slope=0.25)),
],
)
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
def test_quantizer(randX_1000_1, quantizer, backend, io_type):
'''
Test a single quantized_relu quantizer with a nonzero negative_slope as an Activation function.
Use numpy's assert_allclose to check that the relative difference between the converted hls4ml layer and QKeras is below 1e-5.
'''
X = randX_1000_1
X = np.round(X * 2**10) * 2**-10 # make it an exact ap_fixed<16,6>
model = Sequential()
model.add(QActivation(input_shape=(1,), activation=quantizer, name='quantizer'))
model.compile()

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
output_dir = str(
test_root_path
/ 'hls4mlprj_qkeras_quantizer_{}_{}_{}_{}_{}'.format(
quantizer.__class__.__name__, quantizer.bits, quantizer.integer, backend, io_type
)
)
hls_model = hls4ml.converters.convert_from_keras_model(
model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
)
hls_model.compile()

y_qkeras = model.predict(X)
y_hls4ml = hls_model.predict(X)
np.testing.assert_allclose(y_hls4ml, y_qkeras, rtol=1e-5, atol=0)
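This file can be run on its own with a standard pytest invocation, e.g. pytest test/pytest/test_q_leaky_relu.py, assuming a working hls4ml test environment with QKeras installed.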
1 change: 1 addition & 0 deletions test/pytest/test_qkeras.py
@@ -273,6 +273,7 @@ def randX_1000_1():
         (quantized_relu(8, 4)),
         (quantized_relu(10)),
         (quantized_relu(10, 5)),
+        (quantized_relu(10, 5, negative_slope=0.25)),
[Review comment, Contributor]
Aren't you already testing this in the other test?

[Reply, Author]
To create the test suite requested for the PR I added this line to the existing qkeras test. However, due to problems with the environment required for the pytests, some issues unrelated to my fix came up, so I wrote a separate script just to show that the "new" layer passes an accuracy test. If the test_qkeras.py test passes when called in the automatic workflow, my other script can be removed.

],
)
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])