Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711792163
Change-Id: I5118482d2fbd0c4a722ab65ca01a2016aa4ce44b
  • Loading branch information
Akshaya Purohit authored and copybara-github committed Jan 3, 2025
1 parent 8e7a1a4 commit 4eedf1e
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 240 deletions.
93 changes: 93 additions & 0 deletions qkeras/base_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2025 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow.compat.v2 as tf
import tensorflow.keras.backend as K


def _create_variable_name(attr_name, var_name=None):
"""Creates variable name.
Arguments:
attr_name: string. attribute name
var_name: string. variable name
Returns:
string. variable name
"""

if var_name:
return var_name + "/" + attr_name

# This naming scheme is to solve a problem of a layer having more than
# one quantizer can have multiple qnoise_factor variables with the same
# name of "qnoise_factor".
return attr_name + "_" + str(K.get_uid(attr_name))


class BaseQuantizer(tf.Module):
"""Base quantizer.
Defines behavior all quantizers should follow.
"""

def __init__(self):
self.built = False

def build(self, var_name=None, use_variables=False):
if use_variables:
if hasattr(self, "qnoise_factor"):
self.qnoise_factor = tf.Variable(
lambda: tf.constant(self.qnoise_factor, dtype=tf.float32),
name=_create_variable_name("qnoise_factor", var_name=var_name),
dtype=tf.float32,
trainable=False,
)
self.built = True

def _set_trainable_parameter(self):
pass

def update_qnoise_factor(self, qnoise_factor):
"""Update qnoise_factor."""
if isinstance(self.qnoise_factor, tf.Variable):
# self.qnoise_factor is a tf.Variable.
# This is to update self.qnoise_factor during training.
self.qnoise_factor.assign(qnoise_factor)
else:
if isinstance(qnoise_factor, tf.Variable):
# self.qnoise_factor is a numpy variable, and qnoise_factor is a
# tf.Variable.
self.qnoise_factor = qnoise_factor.eval()
else:
# self.qnoise_factor and qnoise_factor are numpy variables.
# This is to set self.qnoise_factor before building
# (creating tf.Variable) it.
self.qnoise_factor = qnoise_factor

# Override not to expose the quantizer variables.
@property
def variables(self):
return ()

# Override not to expose the quantizer variables.
@property
def trainable_variables(self):
return ()

# Override not to expose the quantizer variables.
@property
def non_trainable_variables(self):
return ()
12 changes: 8 additions & 4 deletions qkeras/qtools/DnC/divide_and_conquer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@

import enum
import logging
from typing import List, Any, Union
from typing import Any, List, Union

import numpy as np
import tensorflow as tf

from qkeras import base_quantizer
from qkeras import quantizers
from qkeras.qtools import generate_layer_data_type_map
from qkeras.qtools import qgraph
from qkeras.qtools import qtools_util
from qkeras.qtools import generate_layer_data_type_map
from qkeras.qtools.DnC import dnc_layer_cost_ace


Expand All @@ -49,8 +50,11 @@ class CostMode(enum.Enum):
class DivideConquerGraph:
"""This class creates model graph structure and methods to access layers."""

def __init__(self, model: tf.keras.Model,
source_quantizers: quantizers.BaseQuantizer = None):
def __init__(
self,
model: tf.keras.Model,
source_quantizers: base_quantizer.BaseQuantizer = None,
):
self._model = model
self._source_quantizer_list = source_quantizers or [
quantizers.quantized_bits(8, 0, 1)]
Expand Down
11 changes: 6 additions & 5 deletions qkeras/qtools/quantized_operators/fused_bn_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import math

import numpy as np
import copy

from qkeras import base_quantizer
from qkeras.qtools import qtools_util
from qkeras.qtools.quantized_operators import adder_factory
from qkeras.qtools.quantized_operators import divider_factory
from qkeras.qtools.quantized_operators import multiplier_factory
from qkeras.qtools.quantized_operators import quantizer_impl
from qkeras.qtools import qtools_util
from qkeras import quantizers

class FusedBNFactory:
"""determine which quantizer implementation to use.
Expand All @@ -48,14 +48,15 @@ class FusedBNFactory:
"""

def make_quantizer(
self, prev_output_quantizer: quantizer_impl.IQuantizer,
self,
prev_output_quantizer: quantizer_impl.IQuantizer,
beta_quantizer: quantizer_impl.IQuantizer,
mean_quantizer: quantizer_impl.IQuantizer,
inverse_quantizer: quantizer_impl.IQuantizer,
prev_bias_quantizer: quantizer_impl.IQuantizer,
use_beta: bool,
use_bias: bool,
qkeras_inverse_quantizer:quantizers.BaseQuantizer
qkeras_inverse_quantizer: base_quantizer.BaseQuantizer,
):
"""Makes a fused_bn quantizer.
Expand Down
31 changes: 31 additions & 0 deletions qkeras/quantizer_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2025 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Imports for QKeras quantizers."""

from .quantizers import bernoulli
from .quantizers import binary
from .quantizers import quantized_bits
from .quantizers import quantized_hswish
from .quantizers import quantized_linear
from .quantizers import quantized_po2
from .quantizers import quantized_relu
from .quantizers import quantized_relu_po2
from .quantizers import quantized_sigmoid
from .quantizers import quantized_tanh
from .quantizers import quantized_ulaw
from .quantizers import stochastic_binary
from .quantizers import stochastic_ternary
from .quantizers import ternary
Loading

0 comments on commit 4eedf1e

Please sign in to comment.