diff --git a/qkeras/base_quantizer.py b/qkeras/base_quantizer.py
new file mode 100644
index 00000000..161985cd
--- /dev/null
+++ b/qkeras/base_quantizer.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow.compat.v2 as tf
+import tensorflow.keras.backend as K
+
+
+def _create_variable_name(attr_name, var_name=None):
+  """Creates variable name.
+
+  Arguments:
+    attr_name: string. attribute name
+    var_name: string. variable name
+
+  Returns:
+    string. variable name
+  """
+
+  if var_name:
+    return var_name + "/" + attr_name
+
+  # This naming scheme avoids a collision problem: a layer with more than
+  # one quantizer would otherwise create several qnoise_factor variables
+  # that all share the name "qnoise_factor".
+  return attr_name + "_" + str(K.get_uid(attr_name))
+
+
+class BaseQuantizer(tf.Module):
+  """Base quantizer.
+
+  Defines behavior all quantizers should follow.
+  """
+
+  def __init__(self):
+    self.built = False
+
+  def build(self, var_name=None, use_variables=False):
+    if use_variables:
+      if hasattr(self, "qnoise_factor"):
+        self.qnoise_factor = tf.Variable(
+            lambda: tf.constant(self.qnoise_factor, dtype=tf.float32),
+            name=_create_variable_name("qnoise_factor", var_name=var_name),
+            dtype=tf.float32,
+            trainable=False,
+        )
+    self.built = True
+
+  def _set_trainable_parameter(self):
+    pass
+
+  def update_qnoise_factor(self, qnoise_factor):
+    """Update qnoise_factor."""
+    if isinstance(self.qnoise_factor, tf.Variable):
+      # self.qnoise_factor is a tf.Variable.
+      # This is to update self.qnoise_factor during training.
+      self.qnoise_factor.assign(qnoise_factor)
+    else:
+      if isinstance(qnoise_factor, tf.Variable):
+        # self.qnoise_factor is a numpy variable, and qnoise_factor is a
+        # tf.Variable.
+        self.qnoise_factor = qnoise_factor.eval()
+      else:
+        # self.qnoise_factor and qnoise_factor are numpy variables.
+        # This is to set self.qnoise_factor before building
+        # (creating tf.Variable) it.
+        self.qnoise_factor = qnoise_factor
+
+  # Override not to expose the quantizer variables.
+  @property
+  def variables(self):
+    return ()
+
+  # Override not to expose the quantizer variables.
+  @property
+  def trainable_variables(self):
+    return ()
+
+  # Override not to expose the quantizer variables.
+  @property
+  def non_trainable_variables(self):
+    return ()
diff --git a/qkeras/qtools/DnC/divide_and_conquer.py b/qkeras/qtools/DnC/divide_and_conquer.py
index 710589d1..902b826f 100644
--- a/qkeras/qtools/DnC/divide_and_conquer.py
+++ b/qkeras/qtools/DnC/divide_and_conquer.py
@@ -27,15 +27,16 @@
 import enum
 import logging
-from typing import List, Any, Union
+from typing import Any, List, Union
 
 import numpy as np
 import tensorflow as tf
 
+from qkeras import base_quantizer
 from qkeras import quantizers
+from qkeras.qtools import generate_layer_data_type_map
 from qkeras.qtools import qgraph
 from qkeras.qtools import qtools_util
-from qkeras.qtools import generate_layer_data_type_map
 from qkeras.qtools.DnC import dnc_layer_cost_ace
@@ -49,8 +50,11 @@ class CostMode(enum.Enum):
 class DivideConquerGraph:
   """This class creates model graph structure and methods to access layers."""
 
-  def __init__(self, model: tf.keras.Model,
-               source_quantizers: quantizers.BaseQuantizer = None):
+  def __init__(
+      self,
+      model: tf.keras.Model,
+      source_quantizers: base_quantizer.BaseQuantizer = None,
+  ):
     self._model = model
     self._source_quantizer_list = source_quantizers or [
         quantizers.quantized_bits(8, 0, 1)]
diff --git a/qkeras/qtools/quantized_operators/fused_bn_factory.py b/qkeras/qtools/quantized_operators/fused_bn_factory.py
index 242215a3..ade5e060 100644
--- a/qkeras/qtools/quantized_operators/fused_bn_factory.py
+++ b/qkeras/qtools/quantized_operators/fused_bn_factory.py
@@ -23,13 +23,13 @@
 import math
 
 import numpy as np
-import copy
+
+from qkeras import base_quantizer
+from qkeras.qtools import qtools_util
 from qkeras.qtools.quantized_operators import adder_factory
 from qkeras.qtools.quantized_operators import divider_factory
 from qkeras.qtools.quantized_operators import multiplier_factory
 from qkeras.qtools.quantized_operators import quantizer_impl
-from qkeras.qtools import qtools_util
-from qkeras import quantizers
 
 
 class FusedBNFactory:
   """determine which quantizer implementation to use.
@@ -48,14 +48,15 @@ class FusedBNFactory:
   """
 
   def make_quantizer(
-      self, prev_output_quantizer: quantizer_impl.IQuantizer,
+      self,
+      prev_output_quantizer: quantizer_impl.IQuantizer,
       beta_quantizer: quantizer_impl.IQuantizer,
       mean_quantizer: quantizer_impl.IQuantizer,
       inverse_quantizer: quantizer_impl.IQuantizer,
       prev_bias_quantizer: quantizer_impl.IQuantizer,
       use_beta: bool,
       use_bias: bool,
-      qkeras_inverse_quantizer:quantizers.BaseQuantizer
+      qkeras_inverse_quantizer: base_quantizer.BaseQuantizer,
   ):
     """Makes a fused_bn quantizer.
diff --git a/qkeras/quantizer_imports.py b/qkeras/quantizer_imports.py
new file mode 100644
index 00000000..a8a935f5
--- /dev/null
+++ b/qkeras/quantizer_imports.py
@@ -0,0 +1,31 @@
+# Copyright 2025 Google LLC
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Imports for QKeras quantizers."""
+
+from .quantizers import bernoulli
+from .quantizers import binary
+from .quantizers import quantized_bits
+from .quantizers import quantized_hswish
+from .quantizers import quantized_linear
+from .quantizers import quantized_po2
+from .quantizers import quantized_relu
+from .quantizers import quantized_relu_po2
+from .quantizers import quantized_sigmoid
+from .quantizers import quantized_tanh
+from .quantizers import quantized_ulaw
+from .quantizers import stochastic_binary
+from .quantizers import stochastic_ternary
+from .quantizers import ternary
diff --git a/qkeras/quantizers.py b/qkeras/quantizers.py
index ddc76dea..34bedd93 100644
--- a/qkeras/quantizers.py
+++ b/qkeras/quantizers.py
@@ -28,6 +28,7 @@
 import tensorflow.keras.backend as K
 from tensorflow.keras.utils import deserialize_keras_object
 
+from . import base_quantizer
 from . import quantizer_registry
 # from .google_internals.experimental_quantizers import parametric_quantizer_d_xmax
 # from .google_internals.experimental_quantizers import quantized_bits_learnable_scale
@@ -667,26 +668,6 @@
 def _floor_through(x):
   return x + tf.stop_gradient(-x + tf.floor(x))
 
 
-def _create_variable_name(attr_name, var_name=None):
-  """Creates variable name.
-  Arguments:
-    attr_name: string. attribute name
-    var_name: string. variable name
-
-  Returns:
-    string. variable name
-  """
-
-  if var_name:
-    return var_name + "/" + attr_name
-
-  # This naming scheme is to solve a problem of a layer having more than
-  # one quantizer can have multiple qnoise_factor variables with the same
-  # name of "qnoise_factor".
-  return attr_name + "_" + str(K.get_uid(attr_name))
-
-
 #
 # Activation functions for quantized networks.
 #
@@ -696,213 +677,153 @@ def _create_variable_name(attr_name, var_name=None):
 #
 
 
-class BaseQuantizer(tf.Module):
-  """Base quantizer
-
-  Defines behavior all quantizers should follow.
-  """
-
-  def __init__(self):
-    self.built = False
-
-  def build(self, var_name=None, use_variables=False):
-    if use_variables:
-      if hasattr(self, "qnoise_factor"):
-        self.qnoise_factor = tf.Variable(
-            lambda: tf.constant(self.qnoise_factor, dtype=tf.float32),
-            name=_create_variable_name("qnoise_factor", var_name=var_name),
-            dtype=tf.float32,
-            trainable=False)
-    self.built = True
-
-  def _set_trainable_parameter(self):
-    pass
-
-  def update_qnoise_factor(self, qnoise_factor):
-    """Update qnoise_factor."""
-    if isinstance(self.qnoise_factor, tf.Variable):
-      # self.qnoise_factor is a tf.Variable.
-      # This is to update self.qnoise_factor during training.
-      self.qnoise_factor.assign(qnoise_factor)
-    else:
-      if isinstance(qnoise_factor, tf.Variable):
-        # self.qnoise_factor is a numpy variable, and qnoise_factor is a
-        # tf.Variable.
-        self.qnoise_factor = qnoise_factor.eval()
-      else:
-        # self.qnoise_factor and qnoise_factor are numpy variables.
-        # This is to set self.qnoise_factor before building
-        # (creating tf.Variable) it.
-        self.qnoise_factor = qnoise_factor
-
-  # Override not to expose the quantizer variables.
-  @property
-  def variables(self):
-    return ()
-
-  # Override not to expose the quantizer variables.
-  @property
-  def trainable_variables(self):
-    return ()
-
-  # Override not to expose the quantizer variables.
-  @property
-  def non_trainable_variables(self):
-    return ()
-
-
 @quantizer_registry.register_quantizer
-class quantized_linear(BaseQuantizer):
+class quantized_linear(base_quantizer.BaseQuantizer):
   """Linear quantization with fixed number of bits.
 
-  This quantizer maps inputs to the nearest value of a fixed number of
-  outputs that are evenly spaced, with possible scaling and stochastic
-  rounding. This is an updated version of the legacy quantized_bits.
-
-  The core computation is:
-    1. Divide the tensor by a quantization scale
-    2. Clip the tensor to a specified range
-    3. Round to the nearest integer
-    4. Multiply the rounded result by the quantization scale
-
-  This clip range is determined by
-    - The number of bits we have to represent the number
-    - Whether we want to have a symmetric range or not
-    - Whether we want to keep negative numbers or not
-
-  The quantization scale is defined by either the quantizer parameters or the
-  data passed to the __call__ method. See documentation for the `alpha`
-  parameter to find out more.
-
-  For backprop purposes, the quantizer uses the straight-through estimator
-  for the rounding step (https://arxiv.org/pdf/1903.05662.pdf). Thus the
-  gradient of the __call__ method is 1 on the interval
-  [quantization_scale * clip_min, quantization_scale * clip_max] and 0
-  elsewhere.
-
-  The quantizer also supports a number of other optional features:
-    - Stochastic rounding (see the `stochastic_rounding` parameter)
-    - Quantization noise (see the `qnoise_factor` parameter)
-
-  Notes on the various "scales" in quantized_linear:
-
-  - The quantization scale is the scale used in the core computation (see
-    above). You can access it via the `quantization_scale` attribute.
-  - The data type scale is the scale is determined by the type of data
-    stored on hardware on a small device running a true quantized model.
-    It is the quantization scale needed to represent `bits` bits, `integer`
-    of which are integer bits, and one bit is reserved for the sign if
-    `keep_negative` is True. It can be calculated as
-    2 ** (integer - bits + keep_negative). You can access it via the
-    `data_type_scale` attribute.
-  - The `scale` attribute stores the quotient of the quantization scale and
-    the data type scale. This is also the scale that can be directly
-    specified by the user, via the `alpha` parameter.
-
-  These three quantities are related by the equation
-  scale = quantization_scale / data_type_scale.
-
-  See the diagram below of scale usage in a quantized conv layer.
-
-  +------------------------------------------------------------------+
-  | data_type_scale ---------------> stored_weights                  |
-  | (determines decimal point)              |                        |
-  |                                         V                        |
-  |                                      conv op                     |
-  |                                         |                        |
-  |                                         V                        |
-  |                                    accumulator                   |
-  |                                         |                        |
-  | determines quantization                 V                        |
-  | range and precision --------->  quantization_scale               |
-  | (per channel)                           |                        |
-  |                                         V                        |
-  |                                     activation                   |
-  +------------------------------------------------------------------+
-
-  # TODO: The only fundamentally necessary scale is the quantization scale.
-  # We should consider removing the data type scale and scale attributes,
-  # but know that this will require rewriting much of how qtools and HLS4ML
-  # use these scale attributes.
-
-  Note on binary quantization (bits=1):
-    The core computation is modified here when `keep_negative` is True to
-    perform a scaled sign function. This is needed because the core
-    computation as defined above requires that 0 be mapped to 0, which does
-    not allow us to keep both positive and negative outputs for binary
-    quantization. Special shifting operations are used to achieve this.
-
-  Example usage:
-
-  # 8-bit quantization with 3 integer bits
-  >>> q = quantized_linear(8, 3)
-  >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
-  >>> q(x).numpy()
-  array([0., 0., 1., 2., 2.], dtype=float32)
-
-  # 2-bit quantization with "auto" and tensor alphas
-  >>> q_auto = quantized_linear(2, alpha="auto")
-  >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
-  >>> q_auto(x).numpy()
-  array([0., 0., 0., 2., 2.], dtype=float32)
-  >>> q_auto.scale.numpy()
-  array([4.], dtype=float32)
-  >>> q_auto.quantization_scale.numpy()
-  array([2.], dtype=float32)
-  >>> q_fixed = quantized_linear(2, alpha=q_auto.scale)
-  >>> q_fixed(x)
-  array([0., 0., 0., 2., 2.], dtype=float32)
-
-  Args:
-    bits (int): Number of bits to represent the number. Defaults to 8.
-    integer (int): Number of bits to the left of the decimal point, used for
-      data_type_scale. Defaults to 0.
-    symmetric (bool): If true, we will have the same number of values
-      for positive and negative numbers. Defaults to True.
-    alpha (str, Tensor, None): Instructions for determining the quantization
-      scale. Defaults to None.
-      - If None: the quantization scale is the data type scale, determined
-        by `integer`, `bits`, and `keep_negative`.
-      - If "auto", the quantization scale is calculated as the minimum
-        floating point scale per-channel that does not clip the max of x.
-      - If "auto_po2", the quantization scale is chosen as the
-        power of two per-channel that minimizes squared error between the
-        quantized x and the original x.
-      - If Tensor: The quantization scale is the Tensor passed in
-        multiplied by the data type scale.
-    keep_negative (bool): If false, we clip negative numbers. Defaults to
-      True.
-    use_stochastic_rounding (bool): If true, we perform stochastic rounding
-      (https://arxiv.org/pdf/1502.02551.pdf).
-    scale_axis (int, None): Which axis to calculate scale from. If None, we
-      perform per-channel scaling based off of the image data format. Note
-      that each entry of a rank-1 tensor is considered its own channel by
-      default. See `_get_scaling_axis` for more details. Defaults to None.
-    qnoise_factor (float): A scalar from 0 to 1 that represents the level of
-      quantization noise to add. This controls the amount of the
-      quantization noise to add to the outputs by changing the weighted
-      sum of (1 - qnoise_factor) * unquantized_x + qnoise_factor *
-      quantized_x. Defaults to 1.0, which means that the result is fully
-      quantized.
-    use_variables (bool): If true, we use tf.Variables to store certain
-      parameters. See the BaseQuantizer implementation for more details.
-      Defaults to False. If set to True, be sure to use the special attribute
-      update methods detailed in the BaseQuantizer.
-    var_name (str or None): A variable name shared between the tf.Variables
-      created in on initialization, if use_variables is true. If None, the
-      variable names are generated automatically based on the parameter names
-      along with a uid. Defaults to None.
+
+  This quantizer maps inputs to the nearest value of a fixed number of
+  outputs that are evenly spaced, with possible scaling and stochastic
+  rounding. This is an updated version of the legacy quantized_bits.
+
+  The core computation is:
+    1. Divide the tensor by a quantization scale
+    2. Clip the tensor to a specified range
+    3. Round to the nearest integer
+    4. Multiply the rounded result by the quantization scale
+
+  This clip range is determined by
+    - The number of bits we have to represent the number
+    - Whether we want to have a symmetric range or not
+    - Whether we want to keep negative numbers or not
+
+  The quantization scale is defined by either the quantizer parameters or the
+  data passed to the __call__ method. See documentation for the `alpha`
+  parameter to find out more.
+
+  For backprop purposes, the quantizer uses the straight-through estimator
+  for the rounding step (https://arxiv.org/pdf/1903.05662.pdf). Thus the
+  gradient of the __call__ method is 1 on the interval
+  [quantization_scale * clip_min, quantization_scale * clip_max] and 0
+  elsewhere.
+
+  The quantizer also supports a number of other optional features:
+    - Stochastic rounding (see the `use_stochastic_rounding` parameter)
+    - Quantization noise (see the `qnoise_factor` parameter)
+
+  Notes on the various "scales" in quantized_linear:
+
+  - The quantization scale is the scale used in the core computation (see
+    above). You can access it via the `quantization_scale` attribute.
+  - The data type scale is determined by the type of data stored on
+    hardware on a small device running a true quantized model. It is the
+    quantization scale needed to represent `bits` bits, `integer` of which
+    are integer bits, with one bit reserved for the sign if `keep_negative`
+    is True. It can be calculated as 2 ** (integer - bits + keep_negative).
+    You can access it via the `data_type_scale` attribute.
+  - The `scale` attribute stores the quotient of the quantization scale and
+    the data type scale. This is also the scale that can be directly
+    specified by the user, via the `alpha` parameter.
+
+  These three quantities are related by the equation
+  scale = quantization_scale / data_type_scale.
+
+  See the diagram below of scale usage in a quantized conv layer.
+
+  +------------------------------------------------------------------+
+  | data_type_scale ---------------> stored_weights                  |
+  | (determines decimal point)              |                        |
+  |                                         V                        |
+  |                                      conv op                     |
+  |                                         |                        |
+  |                                         V                        |
+  |                                    accumulator                   |
+  |                                         |                        |
+  | determines quantization                 V                        |
+  | range and precision --------->  quantization_scale               |
+  | (per channel)                           |                        |
+  |                                         V                        |
+  |                                     activation                   |
+  +------------------------------------------------------------------+
+
+  # TODO: The only fundamentally necessary scale is the quantization scale.
+  # We should consider removing the data type scale and scale attributes,
+  # but know that this will require rewriting much of how qtools and HLS4ML
+  # use these scale attributes.
+
+  Note on binary quantization (bits=1):
+    The core computation is modified here when `keep_negative` is True to
+    perform a scaled sign function. This is needed because the core
+    computation as defined above requires that 0 be mapped to 0, which does
+    not allow us to keep both positive and negative outputs for binary
+    quantization. Special shifting operations are used to achieve this.
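+
+  Worked example of the scale relations above (illustrative numbers only,
+  following the formula already stated, not an additional API guarantee):
+  with bits=8, integer=3, and keep_negative=True, data_type_scale is
+  2 ** (3 - 8 + 1) = 2 ** -4 = 0.0625, so a quantization_scale of 0.25
+  corresponds to scale = 0.25 / 0.0625 = 4.0.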
+
+  Example usage:
+
+  # 8-bit quantization with 3 integer bits
+  >>> q = quantized_linear(8, 3)
+  >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
+  >>> q(x).numpy()
+  array([0., 0., 1., 2., 2.], dtype=float32)
+
+  # 2-bit quantization with "auto" and tensor alphas
+  >>> q_auto = quantized_linear(2, alpha="auto")
+  >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
+  >>> q_auto(x).numpy()
+  array([0., 0., 0., 2., 2.], dtype=float32)
+  >>> q_auto.scale.numpy()
+  array([4.], dtype=float32)
+  >>> q_auto.quantization_scale.numpy()
+  array([2.], dtype=float32)
+  >>> q_fixed = quantized_linear(2, alpha=q_auto.scale)
+  >>> q_fixed(x)
+  array([0., 0., 0., 2., 2.], dtype=float32)
-
-  Returns:
-    function: Function that computes linear quantization.
+
+  Args:
+    bits (int): Number of bits to represent the number. Defaults to 8.
+    integer (int): Number of bits to the left of the decimal point, used for
+      data_type_scale. Defaults to 0.
+    symmetric (bool): If true, we will have the same number of values for
+      positive and negative numbers. Defaults to True.
+    alpha (str, Tensor, None): Instructions for determining the quantization
+      scale. Defaults to None.
+      - If None: the quantization scale is the data type scale, determined
+        by `integer`, `bits`, and `keep_negative`.
+      - If "auto", the quantization scale is calculated as the minimum
+        floating point scale per-channel that does not clip the max of x.
+      - If "auto_po2", the quantization scale is chosen as the power of two
+        per-channel that minimizes squared error between the quantized x and
+        the original x.
+      - If Tensor: The quantization scale is the Tensor passed in multiplied
+        by the data type scale.
+    keep_negative (bool): If false, we clip negative numbers. Defaults to True.
+    use_stochastic_rounding (bool): If true, we perform stochastic rounding
+      (https://arxiv.org/pdf/1502.02551.pdf).
+    scale_axis (int, None): Which axis to calculate scale from. If None, we
+      perform per-channel scaling based off of the image data format. Note that
+      each entry of a rank-1 tensor is considered its own channel by default.
+      See `_get_scaling_axis` for more details. Defaults to None.
+    qnoise_factor (float): A scalar from 0 to 1 that represents the level of
+      quantization noise to add. This controls the amount of the quantization
+      noise to add to the outputs by changing the weighted sum of
+      (1 - qnoise_factor) * unquantized_x + qnoise_factor * quantized_x.
+      Defaults to 1.0, which means that the result is fully quantized.
+    use_variables (bool): If true, we use tf.Variables to store certain
+      parameters. See the base_quantizer.BaseQuantizer implementation for more
+      details. Defaults to False. If set to True, be sure to use the special
+      attribute update methods detailed in base_quantizer.BaseQuantizer.
+    var_name (str or None): A variable name shared between the tf.Variables
+      created on initialization, if use_variables is true. If None, the
+      variable names are generated automatically based on the parameter names
+      along with a uid. Defaults to None.
-
-  Raises:
-    ValueError:
-      - If `bits` is not positive, or is too small to represent `integer`.
-      - If `integer` is negative.
-      - If `alpha` is a string but not one of ("auto", "auto_po2").
+
+  Returns:
+    function: Function that computes linear quantization.
-  """
+
+  Raises:
+    ValueError:
+      - If `bits` is not positive, or is too small to represent `integer`.
+      - If `integer` is negative.
+      - If `alpha` is a string but not one of ("auto", "auto_po2").
+ """ # string options for alpha parameter ALPHA_STRING_OPTIONS = ("auto", "auto_po2") @@ -1249,7 +1170,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class quantized_bits(BaseQuantizer): # pylint: disable=invalid-name +class quantized_bits(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Legacy quantizer: Quantizes the number to a number of bits. In general, we want to use a quantization function like: @@ -1565,7 +1486,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class bernoulli(BaseQuantizer): # pylint: disable=invalid-name +class bernoulli(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes a Bernoulli sample with probability sigmoid(x). This computation uses ST approximation. @@ -1680,7 +1601,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class ternary(BaseQuantizer): # pylint: disable=invalid-name +class ternary(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes an activation function returning -alpha, 0 or +alpha. Right now we assume two type of behavior. For parameters, we should @@ -1987,7 +1908,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class binary(BaseQuantizer): # pylint: disable=invalid-name +class binary(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes the sign(x) returning a value between -alpha and alpha. Although we cannot guarantee E[dL/dy] = E[dL/dx] if we do not use the @@ -2294,7 +2215,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class quantized_relu(BaseQuantizer): # pylint: disable=invalid-name +class quantized_relu(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes a quantized relu to a number of bits. Modified from: @@ -2516,7 +2437,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class quantized_ulaw(BaseQuantizer): # pylint: disable=invalid-name +class quantized_ulaw(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes a u-law quantization. Attributes: @@ -2590,7 +2511,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class quantized_tanh(BaseQuantizer): # pylint: disable=invalid-name +class quantized_tanh(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes a quantized tanh to a number of bits. Modified from: @@ -2660,7 +2581,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class quantized_sigmoid(BaseQuantizer): # pylint: disable=invalid-name +class quantized_sigmoid(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Computes a quantized sigmoid to a number of bits. Attributes: @@ -2849,7 +2770,7 @@ def _get_min_max_exponents(non_sign_bits, need_exponent_sign_bit, @quantizer_registry.register_quantizer -class quantized_po2(BaseQuantizer): # pylint: disable=invalid-name +class quantized_po2(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Quantizes to the closest power of 2. Attributes: @@ -2986,7 +2907,7 @@ def get_config(self): @quantizer_registry.register_quantizer -class quantized_relu_po2(BaseQuantizer): # pylint: disable=invalid-name +class quantized_relu_po2(base_quantizer.BaseQuantizer): # pylint: disable=invalid-name """Quantizes x to the closest power of 2 when x > 0 Attributes: