diff --git a/qkeras/utils.py b/qkeras/utils.py
index fcfda16..d7262e8 100644
--- a/qkeras/utils.py
+++ b/qkeras/utils.py
@@ -99,11 +99,12 @@
 ]
 
 
-def find_bn_fusing_layer_pair(model):
+def find_bn_fusing_layer_pair(model, custom_objects={}):
   """Finds layers that can be fused with the following batchnorm layers.
 
   Args:
     model: input model
+    custom_objects: Dict of model specific objects needed for cloning.
 
   Returns:
     Dict that marks all the layer pairs that need to be fused.
@@ -111,7 +112,7 @@ def find_bn_fusing_layer_pair(model):
 
   Note: supports sequential and non-sequential model
   """
-  fold_model = clone_model(model)
+  fold_model = clone_model(model, custom_objects)
 
   (graph, _) = qgraph.GenerateGraphFromModel(
       fold_model, "quantized_bits(8, 0, 1)", "quantized_bits(8, 0, 1)")
@@ -219,7 +220,7 @@ def apply_quantizer(quantizer, input_weight):
 
 
 # Model utilities: before saving the weights, we want to apply the quantizers
-def model_save_quantized_weights(model, filename=None):
+def model_save_quantized_weights(model, filename=None, custom_objects={}):
   """Quantizes model for inference and save it.
 
   Takes a model with weights, apply quantization function to weights and
@@ -241,17 +242,19 @@ def model_save_quantized_weights(model, filename=None):
     model: model with weights to be quantized.
     filename: if specified, we will save the hdf5 containing the quantized
       weights so that we can use them for inference later on.
+    custom_objects: Dict of model specific objects needed to load/store.
 
   Returns:
     dictionary containing layer name and quantized weights that can be used
     by a hardware generator.
-
   """
 
   saved_weights = {}
 
   # Find the conv/dense layers followed by Batchnorm layers
-  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(model)
+  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(
+      model, custom_objects
+  )
 
   print("... quantizing model")
   for layer in model.layers:
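
Usage sketch (not part of the diff above): with this change, callers can pass a custom_objects dict so the clone_model() call inside find_bn_fusing_layer_pair() can rebuild models containing QKeras layers. The `model` variable and the exact set of dict entries below are illustrative assumptions; populate the dict with whatever custom classes your model actually uses.

    from qkeras import QActivation, QConv2D, QDense, quantized_bits
    from qkeras.utils import model_save_quantized_weights

    # Map class names to the custom classes needed when cloning the model
    # (assumed set; extend with any other QKeras layers/quantizers you use).
    custom_objects = {
        "QConv2D": QConv2D,
        "QDense": QDense,
        "QActivation": QActivation,
        "quantized_bits": quantized_bits,
    }

    # `model` is an already-built/trained QKeras model (hypothetical here).
    saved_weights = model_save_quantized_weights(
        model, filename="quantized_weights.h5", custom_objects=custom_objects)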