Commit 96e6f39

No public description
PiperOrigin-RevId: 713012271
Change-Id: Ib9b64b6ddde9ad843fc7612812324b36b7b24fff
Akshaya Purohit authored and copybara-github committed Jan 7, 2025
1 parent 84f3adf commit 96e6f39
Showing 1 changed file with 8 additions and 5 deletions.
qkeras/utils.py
@@ -99,19 +99,20 @@
 ]


-def find_bn_fusing_layer_pair(model):
+def find_bn_fusing_layer_pair(model, custom_objects={}):
   """Finds layers that can be fused with the following batchnorm layers.

   Args:
     model: input model
+    custom_objects: Dict of model specific objects needed for cloning.

   Returns:
     Dict that marks all the layer pairs that need to be fused.

   Note: supports sequential and non-sequential model
   """

-  fold_model = clone_model(model)
+  fold_model = clone_model(model, custom_objects)
   (graph, _) = qgraph.GenerateGraphFromModel(
       fold_model, "quantized_bits(8, 0, 1)", "quantized_bits(8, 0, 1)")

@@ -219,7 +220,7 @@ def apply_quantizer(quantizer, input_weight):


 # Model utilities: before saving the weights, we want to apply the quantizers
-def model_save_quantized_weights(model, filename=None):
+def model_save_quantized_weights(model, filename=None, custom_objects={}):
   """Quantizes model for inference and save it.

   Takes a model with weights, apply quantization function to weights and
@@ -241,17 +242,19 @@ def model_save_quantized_weights(model, filename=None):
     model: model with weights to be quantized.
     filename: if specified, we will save the hdf5 containing the quantized
       weights so that we can use them for inference later on.
+    custom_objects: Dict of model specific objects needed to load/store.

   Returns:
     dictionary containing layer name and quantized weights that can be used
     by a hardware generator.
   """

   saved_weights = {}

   # Find the conv/dense layers followed by Batchnorm layers
-  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(model)
+  (fusing_layer_pair_dict, bn_layers_to_skip) = find_bn_fusing_layer_pair(
+      model, custom_objects
+  )

   print("... quantizing model")
   for layer in model.layers:
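For reference, a minimal usage sketch of the updated API: the new custom_objects dict is passed to model_save_quantized_weights, which forwards it to the internal clone_model call inside find_bn_fusing_layer_pair. The toy model below is illustrative, not part of the commit; only the qkeras imports and the updated signature come from the source.

# Minimal usage sketch (assumption: a small QKeras model built purely
# for illustration; only the qkeras imports and the updated signature
# are from the source).
import tensorflow as tf

from qkeras import QDense, quantized_bits
from qkeras.utils import model_save_quantized_weights

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(16,)),
    QDense(
        8,
        kernel_quantizer=quantized_bits(8, 0, 1),
        bias_quantizer=quantized_bits(8, 0, 1),
        name="qdense",
    ),
    tf.keras.layers.BatchNormalization(name="bn"),
])

# custom_objects lets the internal clone_model call deserialize
# model-specific classes while searching for conv/dense + BatchNorm
# pairs to fuse.
saved_weights = model_save_quantized_weights(
    model,
    filename="quantized_weights.h5",
    custom_objects={"QDense": QDense, "quantized_bits": quantized_bits},
)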
