From 5b8135bc047eaa87bdb989af3c42562fa6ab038a Mon Sep 17 00:00:00 2001 From: Sioni Summers Date: Wed, 1 Apr 2020 16:51:11 +0200 Subject: [PATCH] (Re)quantize existing QKeras model with model_quantize --- qkeras/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/qkeras/utils.py b/qkeras/utils.py index 20d7c9f6..b311c63f 100644 --- a/qkeras/utils.py +++ b/qkeras/utils.py @@ -270,7 +270,7 @@ def model_quantize(model, # Dense becomes QDense # Activation converts activation functions - if layer["class_name"] == "Dense": + if layer["class_name"] in ["Dense", "QDense"]: layer["class_name"] = "QDense" # needs to add kernel/bias quantizers kernel_quantizer = get_config( @@ -290,7 +290,7 @@ def model_quantize(model, else: quantize_activation(layer_config, activation_bits) - elif layer["class_name"] in ["Conv1D", "Conv2D"]: + elif layer["class_name"] in ["Conv1D", "Conv2D", "QConv1D", "QConv2D"]: q_name = "Q" + layer["class_name"] layer["class_name"] = q_name # needs to add kernel/bias quantizers @@ -311,7 +311,7 @@ def model_quantize(model, else: quantize_activation(layer_config, activation_bits) - elif layer["class_name"] == "DepthwiseConv2D": + elif layer["class_name"] in ["DepthwiseConv2D", "QDepthwiseConv2D"]: layer["class_name"] = "QDepthwiseConv2D" # needs to add kernel/bias quantizers depthwise_quantizer = get_config(quantizer_config, layer, @@ -330,7 +330,7 @@ def model_quantize(model, else: quantize_activation(layer_config, activation_bits) - elif layer["class_name"] == "Activation": + elif layer["class_name"] in ["Activation", "QActivation"]: quantizer = get_config(quantizer_config, layer, "QActivation") # this is to avoid softmax from quantizing in autoq if quantizer is None: @@ -351,7 +351,7 @@ def model_quantize(model, else: quantize_activation(layer_config, activation_bits) - elif layer["class_name"] == "BatchNormalization": + elif layer["class_name"] in ["BatchNormalization", "QBatchNormalization"]: layer["class_name"] = "QBatchNormalization" # needs to add kernel/bias quantizers gamma_quantizer = get_config(