Fixes for quantised RNNs
bo3z committed Jan 22, 2025
1 parent 92c8880 commit bcfddcd
Showing 3 changed files with 55 additions and 30 deletions.
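For orientation (not part of the diff): the net effect of this commit is that the generated per-layer configs carry separate types for the recurrent kernel and recurrent bias, so a QKeras model whose recurrent quantizer differs from its kernel quantizer keeps both precisions. Below is a minimal sketch of what such a generated LSTM config might look like after the change; the struct name, widths and sizes are hypothetical, and the field names follow recr_config_template and lstm_config from the diffs below.

#include "ap_fixed.h"
#include "nnet_utils/nnet_recurrent.h"

// Hypothetical generated config for a quantised LSTM layer (widths are
// illustrative only, not taken from this commit).
struct config2 : nnet::lstm_config {
    typedef ap_fixed<16, 6> accum_t;
    typedef ap_fixed<8, 1> weight_t;           // kernel matrix, from the 'kernel' quantizer
    typedef ap_fixed<6, 1> recurrent_weight_t; // recurrent matrix, from the 'recurrent' quantizer
    typedef ap_fixed<8, 1> bias_t;             // kernel bias
    typedef ap_fixed<6, 1> recurrent_bias_t;   // recurrent bias
    static const unsigned n_in = 10;
    static const unsigned n_state = 16;
    // mult_config1/mult_config2 and the activation configs would be filled in
    // from the templates in recurrent_templates.py.
};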
26 changes: 23 additions & 3 deletions hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -4,7 +4,7 @@

# recurrent multiplication template

recr_mult_config_template = """struct config{index} : nnet::dense_config {{
recr_mult_config_template_1 = """struct config{index} : nnet::dense_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned strategy = nnet::{strategy};
@@ -22,6 +22,24 @@
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

recr_mult_config_template_2 = """struct config{index} : nnet::dense_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned strategy = nnet::{strategy};
static const unsigned reuse_factor = {reuse};
static const unsigned n_zeros = {nzeros};
static const unsigned n_nonzeros = {nonzeros};
static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
static const bool store_weights_in_bram = false;
typedef {accum_t.name} accum_t;
typedef {recurrent_bias_t.name} bias_t;
typedef {recurrent_weight_t.name} weight_t;
template<class data_T, class res_T, class CONFIG_T>
using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
template<class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""
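Aside (not part of the diff): both mult configs bound the DSP usage with the same multiplier_limit expression, DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor. A minimal worked sketch of that arithmetic, with assumed sizes rather than values from this commit, and with DIV_ROUNDUP assumed to be the usual ceiling division:

// Assumed sizes for illustration: recurrent matrix of a GRU with n_state = 16,
// so n_in = 16 and n_out = 3 * 16 = 48 (three gates), reuse_factor = 4, no pruning.
constexpr unsigned div_roundup(unsigned n, unsigned d) { return (n + d - 1) / d; }

constexpr unsigned n_in = 16;
constexpr unsigned n_out = 48;
constexpr unsigned reuse_factor = 4;
constexpr unsigned n_zeros = 0;
constexpr unsigned multiplier_limit =
    div_roundup(n_in * n_out, reuse_factor) - n_zeros / reuse_factor; // = 768 / 4 = 192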

# activation templates

activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
@@ -45,7 +63,9 @@
recr_config_template = """struct config{index} : nnet::{recr_type}_config {{
typedef {accum_t.name} accum_t;
typedef {weight_t.name} weight_t; // Matrix
typedef {recurrent_weight_t.name} recurrent_weight_t; // Matrix
typedef {bias_t.name} bias_t; // Vector
typedef {recurrent_bias_t.name} recurrent_bias_t; // Vector
typedef {config_mult_t1} mult_config1;
typedef {config_mult_t2} mult_config2;
typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE};
@@ -77,8 +97,8 @@ def __init__(self):
self.template = recr_config_template
self.act_template = activ_config_template
self.recr_act_template = recr_activ_config_template
self.mult1_template = recr_mult_config_template
self.mult2_template = recr_mult_config_template
self.mult1_template = recr_mult_config_template_1
self.mult2_template = recr_mult_config_template_2

def format(self, node):
params = self._default_config_params(node)
2 changes: 1 addition & 1 deletion hls4ml/converters/keras/qkeras.py
@@ -80,7 +80,7 @@ def parse_qrnn_layer(keras_layer, input_names, input_shapes, data_reader):
layer, output_shape = parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader)

layer['weight_quantizer'] = get_quantizer_from_config(keras_layer, 'kernel')
layer['recurrent_quantizer'] = get_quantizer_from_config(keras_layer, 'recurrent')
layer['recurrent_weight_quantizer'] = get_quantizer_from_config(keras_layer, 'recurrent')
layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')

return layer, output_shape
57 changes: 31 additions & 26 deletions hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h
@@ -12,7 +12,10 @@ namespace nnet {
struct lstm_config {
// Internal data type definitions
typedef float weight_t;
typedef float recurrent_weight_t;
typedef float bias_t;
typedef float recurrent_bias_t;
typedef float accum_t;

// Layer Sizes
static const unsigned n_in = 2;
@@ -47,9 +50,9 @@ struct lstm_config {
template <class data_T, class res_T, typename CONFIG_T>
void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
res_T s_newstate[CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {
// Initialize the state variable -- will maintain state between function calls

typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 4];
@@ -86,11 +89,11 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG
inputacc_c[iacc] = tmpres[index] + tmpres_state[index];
}

CONFIG_T::template activation_recr<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_LSTM>::activation(
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_LSTM>::activation(
inputacc_ifo, tmpres_ifo);

// Now for the confusion matrix
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
inputacc_c, tmpres_c);

// Operation: s=g*i+sold*f (update state with buffer to avoid timing issues)
@@ -99,7 +102,7 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG
s_newstate[iacc] = tmpres_c[iacc] * tmpres_ifo[iacc] + s_newstate[iacc] * tmpres_ifo[iacc + (CONFIG_T::n_state)];
}
// Operation: h=act(s)*o
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
CONFIG_T::template activation<res_T, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
s_newstate, s_actstate);

for (int iacc = 0; iacc < CONFIG_T::n_state; iacc++) {
@@ -112,9 +115,9 @@ template <class data_T, class res_T, typename CONFIG_T>
void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
res_T s_newstate[CONFIG_T::n_state],
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {
static res_T h_state[CONFIG_T::n_state];
static res_T s_state[CONFIG_T::n_state];
// Initialize the state variable -- will maintain state between function calls
@@ -163,11 +166,11 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
inputacc_c[iacc] = tmpres[index] + tmpres_state[index];
}

CONFIG_T::template activation_recr<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_LSTM>::activation(
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_LSTM>::activation(
inputacc_ifo, tmpres_ifo);

// Now for the confusion matrix
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
inputacc_c, tmpres_c);

// Operation: s=g*i+sold*f (update state with buffer to avoid timing issues)
@@ -177,7 +180,7 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
s_newstate[iacc] = s_state[iacc];
}
// Operation: h=act(s)*o
CONFIG_T::template activation<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
CONFIG_T::template activation<res_T, typename CONFIG_T::accum_t, typename CONFIG_T::ACT_CONFIG_T>::activation(
s_state, s_actstate);

for (int iacc = 0; iacc < CONFIG_T::n_state; iacc++) {
@@ -190,9 +193,9 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate
template <class data_T, class res_T, typename CONFIG_T>
void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state],
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {

res_T h_newstate[CONFIG_T::n_state];
res_T s_newstate[CONFIG_T::n_state];
@@ -235,9 +238,9 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CO
template <class data_T, class res_T, typename CONFIG_T>
void lstm_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream,
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 4]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) {

typename res_T::value_type h_newstate[CONFIG_T::n_state];
typename res_T::value_type s_newstate[CONFIG_T::n_state];
@@ -300,7 +303,9 @@ void lstm_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream
struct gru_config {
// Internal data type definitions
typedef float weight_t;
typedef float recurrent_weight_t;
typedef float bias_t;
typedef float recurrent_bias_t;
typedef float accum_t;

// Layer Sizes
Expand All @@ -327,9 +332,9 @@ template <class data_T, class res_T, typename CONFIG_T>
void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], // TODO - Check the layout of the param
// weights - refer page in copy!!
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {
// Initialize the state variable -- will maintain state between function calls
typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 3];
typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state * 3];
@@ -361,7 +366,7 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
}

// Activation function Sub layer -- START
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
typename CONFIG_T::ACT_CONFIG_GRU>::activation(inputacc_zr, tmpres_zr);

// Activation function Sub layer -- END
@@ -383,7 +388,7 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
}

// Now run the activation on this guy
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
typename CONFIG_T::ACT_CONFIG_T>::activation(inputacc_h, tmpres_h);

// Mix the stat with the previous state
@@ -400,9 +405,9 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_
template <class data_T, class res_T, typename CONFIG_T>
void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state],
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {
// Initialize the state variable -- will maintain state between function calls

static res_T h_state[CONFIG_T::n_state];
@@ -444,7 +449,7 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
}

// Activation function Sub layer -- START
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
CONFIG_T::template activation_recr<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
typename CONFIG_T::ACT_CONFIG_GRU>::activation(inputacc_zr, tmpres_zr);

// Activation function Sub layer -- END
@@ -466,7 +471,7 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
}

// Now run the activation on this guy
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::weight_t,
CONFIG_T::template activation<typename CONFIG_T::accum_t, typename CONFIG_T::accum_t,
typename CONFIG_T::ACT_CONFIG_T>::activation(inputacc_h, tmpres_h);

// Mix the stat with the previous state
@@ -484,9 +489,9 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[
template <class data_T, class res_T, typename CONFIG_T>
void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state],
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {

res_T h_state[CONFIG_T::n_state];
data_T data_in[CONFIG_T::n_in];
@@ -525,9 +530,9 @@ void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CON
template <class data_T, class res_T, typename CONFIG_T>
void gru_stack(hls::stream<data_T> &data_stream, hls::stream<res_T> &res_stream,
typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in],
typename CONFIG_T::weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state],
typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3],
typename CONFIG_T::bias_t param_br[CONFIG_T::n_state * 3]) {
typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) {

typename res_T::value_type h_newstate[CONFIG_T::n_state];
#pragma HLS ARRAY_PARTITION variable=h_newstate complete
