From 78a9a6ae041b258ac031977275a6dba7436239fb Mon Sep 17 00:00:00 2001
From: mschoenb97
Date: Wed, 4 Oct 2023 16:51:04 -0700
Subject: [PATCH] Copybara import of the project:

--
eccc2e33ae0ce15b8effded00d1b22cf80209bd4 by mschoenb97 :

Update quantizers file

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/qkeras/pull/124 from mschoenb97:small-pr eccc2e33ae0ce15b8effded00d1b22cf80209bd4
PiperOrigin-RevId: 570850707
Change-Id: Iea493190064986a86aadd9c70b03619c5d58fc29
---
 qkeras/quantizers.py | 537 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 514 insertions(+), 23 deletions(-)

diff --git a/qkeras/quantizers.py b/qkeras/quantizers.py
index 41d8b381..be90e720 100644
--- a/qkeras/quantizers.py
+++ b/qkeras/quantizers.py
@@ -138,13 +138,13 @@ def _get_scaling_axis(scale_axis: Any, len_axis: int) -> List[int]:
     if isinstance(scale_axis, list):
       axis = [i for i in range(len_axis) if i not in scale_axis]
     else:
-      axis = list(range(scale_axis))
-      axis += list(range(scale_axis+1, len_axis))
+      axis = tf.range(scale_axis)
+      axis = tf.concat([axis, tf.range(scale_axis + 1, len_axis)], axis=0)
   else:
     if K.image_data_format() == "channels_last":
-      axis = list(range(len_axis - 1))
+      axis = tf.range(tf.math.maximum(len_axis - 1, 0))
     else:
-      axis = list(range(1, len_axis))
+      axis = tf.range(1, len_axis)
   return axis


@@ -430,10 +430,10 @@ def _clip_po2_scale(scale: tf.Tensor, min_po2_exponent: Any,
   return scale


-def _get_scale(alpha: Any, x: tf.Tensor, q: tf.Tensor,
-               scale_axis: Any = None, per_channel_scale: bool = True,
-               elements_per_scale: Any = None, min_po2_exponent: Any = None,
-               max_po2_exponent: Any = None):
+def _get_least_squares_scale(
+    alpha: Any, x: tf.Tensor, q: tf.Tensor, scale_axis: Any = None,
+    per_channel_scale: bool = True, elements_per_scale: Any = None,
+    min_po2_exponent: Any = None, max_po2_exponent: Any = None):
   """Gets scaling factor for scaling the tensor per channel.

   It uses the least squares method to find the scaling factor.
@@ -499,6 +499,9 @@ def _get_scale(alpha: Any, x: tf.Tensor, q: tf.Tensor,
     scale = float(alpha)
   return scale

+def _get_scale(*args, **kwargs):
+  """Old name for _get_least_squares_scale. Kept for backwards compatibility."""
+  return _get_least_squares_scale(*args, **kwargs)

 def smooth_sigmoid(x):
   """Implements a linear approximation of a sigmoid function."""
@@ -707,12 +710,6 @@ def build(self, var_name=None, use_variables=False):
           name=_create_variable_name("qnoise_factor", var_name=var_name),
           dtype=tf.float32,
           trainable=False)
-      if hasattr(self, "integer"):
-        self.integer = tf.Variable(
-            lambda: tf.constant(self.integer, dtype=tf.int32),
-            name=_create_variable_name("integer", var_name=var_name),
-            dtype=tf.int32,
-            trainable=False)
     self.built = True

   def _set_trainable_parameter(self):
@@ -750,9 +747,503 @@ def trainable_variables(self):
   def non_trainable_variables(self):
     return ()


+class quantized_linear(BaseQuantizer):
+  """Linear quantization with fixed number of bits.
+
+  This quantizer maps inputs to the nearest value of a fixed number of
+  outputs that are evenly spaced, with possible scaling and stochastic
+  rounding. This is an updated version of the legacy quantized_bits.
+
+  The core computation is:
+    1. Divide the tensor by a quantization scale
+    2. Clip the tensor to a specified range
+    3. Round to the nearest integer
+    4. Multiply the rounded result by the quantization scale
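+
+  For example (illustrative): with a quantization scale of 0.0625 (the
+  default data type scale for bits=8, integer=3, keep_negative=True), an
+  input of 1.3 is quantized to round(1.3 / 0.0625) * 0.0625 = 21 * 0.0625
+  = 1.3125.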
+
+  The clip range is determined by:
+    - The number of bits we have to represent the number
+    - Whether we want to have a symmetric range or not
+    - Whether we want to keep negative numbers or not
+
+  The quantization scale is defined by either the quantizer parameters or the
+  data passed to the __call__ method. See documentation for the `alpha`
+  parameter to find out more.
+
+  For backprop purposes, the quantizer uses the straight-through estimator
+  for the rounding step (https://arxiv.org/pdf/1903.05662.pdf). Thus the
+  gradient of the __call__ method is 1 on the interval
+  [quantization_scale * clip_min, quantization_scale * clip_max] and 0
+  elsewhere.
+
+  The quantizer also supports a number of other optional features:
+    - Stochastic rounding (see the `use_stochastic_rounding` parameter)
+    - Quantization noise (see the `qnoise_factor` parameter)
+
+  Notes on the various "scales" in quantized_linear:
+
+    - The quantization scale is the scale used in the core computation (see
+      above). You can access it via the `quantization_scale` attribute.
+    - The data type scale is determined by the type of data stored on
+      hardware on a small device running a true quantized model. It is the
+      quantization scale needed to represent `bits` bits, `integer` of which
+      are integer bits, with one bit reserved for the sign if
+      `keep_negative` is True. It can be calculated as
+      2 ** (integer - bits + keep_negative). You can access it via the
+      `data_type_scale` attribute.
+    - The `scale` attribute stores the quotient of the quantization scale and
+      the data type scale. This is also the scale that can be directly
+      specified by the user, via the `alpha` parameter.
+
+  These three quantities are related by the equation
+  scale = quantization_scale / data_type_scale.
+
+  See the diagram below of scale usage in a quantized conv layer.
+
+  +--------------------------------------------------------+
+  | data_type_scale -----------------> stored_weights      |
+  | (determines decimal point)                |            |
+  |                                           V            |
+  |                                        conv op         |
+  |                                           |            |
+  |                                           V            |
+  |                                      accumulator       |
+  |                                           |            |
+  | determines quantization                   V            |
+  | range and precision ------------> quantization_scale   |
+  | (per channel)                             |            |
+  |                                           V            |
+  |                                       activation       |
+  +--------------------------------------------------------+
+
+  # TODO: The only fundamentally necessary scale is the quantization scale.
+  # We should consider removing the data type scale and scale attributes,
+  # but know that this will require rewriting much of how qtools and HLS4ML
+  # use these scale attributes.
+
+  Note on binary quantization (bits=1):
+    The core computation is modified here when `keep_negative` is True to
+    perform a scaled sign function. This is needed because the core
+    computation as defined above requires that 0 be mapped to 0, which does
+    not allow us to keep both positive and negative outputs for binary
+    quantization. Special shifting operations are used to achieve this.
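+
+    As a rough illustration (assuming the defaults integer=0 and
+    keep_negative=True, so the quantization scale is 1.0), a 1-bit
+    quantizer maps every input to +/- half the quantization scale:
+
+    >>> q = quantized_linear(1)
+    >>> q(tf.constant([-1.0, -0.1, 0.3, 1.0])).numpy()
+    array([-0.5, -0.5,  0.5,  0.5], dtype=float32)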
+
+  Example usage:
+
+    # 8-bit quantization with 3 integer bits
+    >>> q = quantized_linear(8, 3)
+    >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
+    >>> q(x).numpy()
+    array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
+
+    # 2-bit quantization with "auto" and tensor alphas
+    >>> q_auto = quantized_linear(2, alpha="auto")
+    >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
+    >>> q_auto(x).numpy()
+    array([0., 0., 0., 2., 2.], dtype=float32)
+    >>> q_auto.scale.numpy()
+    array([4.], dtype=float32)
+    >>> q_auto.quantization_scale.numpy()
+    array([2.], dtype=float32)
+    >>> q_fixed = quantized_linear(2, alpha=q_auto.scale)
+    >>> q_fixed(x).numpy()
+    array([0., 0., 0., 2., 2.], dtype=float32)
+
+  Args:
+    bits (int): Number of bits to represent the number. Defaults to 8.
+    integer (int): Number of bits to the left of the decimal point, used for
+      data_type_scale. Defaults to 0.
+    symmetric (bool): If true, we will have the same number of values
+      for positive and negative numbers. Defaults to True.
+    alpha (str, Tensor, None): Instructions for determining the quantization
+      scale. Defaults to None.
+      - If None: the quantization scale is the data type scale, determined
+        by `integer`, `bits`, and `keep_negative`.
+      - If "auto", the quantization scale is calculated as the minimum
+        floating point scale per-channel that does not clip the max of x.
+      - If "auto_po2", the quantization scale is chosen as the
+        power of two per-channel that minimizes squared error between the
+        quantized x and the original x.
+      - If Tensor: The quantization scale is the Tensor passed in
+        multiplied by the data type scale.
+    keep_negative (bool): If false, we clip negative numbers. Defaults to
+      True.
+    use_stochastic_rounding (bool): If true, we perform stochastic rounding
+      (https://arxiv.org/pdf/1502.02551.pdf). Defaults to False.
+    scale_axis (int, None): Which axis to calculate scale from. If None, we
+      perform per-channel scaling based off of the image data format. Note
+      that each entry of a rank-1 tensor is considered its own channel by
+      default. See `_get_scaling_axis` for more details. Defaults to None.
+    qnoise_factor (float): A scalar from 0 to 1 that represents the level of
+      quantization noise to add. This controls the amount of the
+      quantization noise to add to the outputs by changing the weighted
+      sum of (1 - qnoise_factor) * unquantized_x + qnoise_factor *
+      quantized_x. Defaults to 1.0, which means that the result is fully
+      quantized.
+    use_variables (bool): If true, we use tf.Variables to store certain
+      parameters. See the BaseQuantizer implementation for more details.
+      Defaults to False. If set to True, be sure to use the special attribute
+      update methods detailed in the BaseQuantizer.
+    var_name (str or None): A variable name shared between the tf.Variables
+      created on initialization, if use_variables is true. If None, the
+      variable names are generated automatically based on the parameter names
+      along with a uid. Defaults to None.
+
+  Returns:
+    function: Function that computes linear quantization.
+
+  Raises:
+    ValueError:
+      - If `bits` is not positive, or is too small to represent `integer`.
+      - If `integer` is negative.
+      - If `alpha` is a string but not one of ("auto", "auto_po2").
+
+  """
+
+  # string options for alpha parameter
+  ALPHA_STRING_OPTIONS = ("auto", "auto_po2")
+
+  def __init__(
+      self,
+      bits=8,
+      integer=0,
+      symmetric=1,
+      keep_negative=True,
+      alpha=None,
+      use_stochastic_rounding=False,
+      scale_axis=None,
+      qnoise_factor=1.0,
+      var_name=None,
+      use_variables=False,
+  ):
+    super(quantized_linear, self).__init__()
+
+    self.var_name = var_name
+
+    # Error checking
+    self._check_bits(bits)
+    self._check_alpha(alpha)
+
+    # Set non-modifiable attributes
+    self._bits = bits
+    self._integer = integer
+    self._keep_negative = keep_negative
+    self._use_stochastic_rounding = use_stochastic_rounding
+    self._scale_axis = scale_axis
+    self._use_variables = use_variables
+
+    # Set modifiable attributes
+    self.alpha = alpha
+    self.qnoise_factor = qnoise_factor
+    self.symmetric = symmetric
+
+    # Set default quantization scale
+    self.quantization_scale = self.default_quantization_scale
+
+  def _check_bits(self, bits):
+    """Error checking for bits parameter"""
+    err_msg = f"Bit count {bits} must be positive"
+    if bits <= 0:
+      raise ValueError(err_msg)
+
+  def _check_alpha(self, alpha):
+    """Error checking for alpha parameter"""
+
+    if isinstance(alpha, six.string_types):
+      # Check the quantizer has been given a valid alpha string
+      if alpha not in self.ALPHA_STRING_OPTIONS:
+        raise ValueError(
+            f"Invalid alpha '{alpha}' for auto alpha computation. "
+            f"Must be one of {self.ALPHA_STRING_OPTIONS}")
+    elif alpha is not None:  # alpha is a tensor
+      try:
+        # any allowable array type can be cast as a numpy array
+        np.array(alpha)
+      except TypeError:
+        raise TypeError(
+            f"alpha must be a string, an array, or None, not {type(alpha)}")
+
+  @property
+  def bits(self):
+    return self._bits
+
+  @property
+  def integer(self):
+    return self._integer
+
+  @property
+  def keep_negative(self):
+    return self._keep_negative
+
+  @property
+  def use_stochastic_rounding(self):
+    return self._use_stochastic_rounding
+
+  @property
+  def scale_axis(self):
+    return self._scale_axis
+
+  @property
+  def use_variables(self):
+    return self._use_variables
+
+  @property
+  def scale(self):
+    return self.quantization_scale / self.data_type_scale
+
+  @property
+  def data_type_scale(self):
+    """Quantization scale for the data type"""
+    # integer is sometimes cast as int32, so cast to float32 to avoid errors
+    integer = tf.cast(self.integer, tf.float32)
+    return K.pow(2.0, integer - self.bits + self.keep_negative)
+
+  @property
+  def auto_alpha(self):
+    """Returns true if using a data-dependent alpha"""
+
+    return isinstance(self.alpha, six.string_types)
+
+  @property
+  def use_sign_function(self):
+    """Return true if using sign function for quantization"""
+
+    return (self.bits == 1.0) and self.keep_negative
+
+  @property
+  def default_quantization_scale(self):
+    """Calculate the default quantization_scale"""
+
+    # Set default quantization scale
+    quantization_scale = self.data_type_scale
+
+    # Quantization scale given by alpha
+    if self.alpha is not None and not self.auto_alpha:
+      quantization_scale = self.alpha * self.data_type_scale
+
+    return quantization_scale
+
+  def get_clip_bounds(self):
+    """Get bounds of clip range"""
+
+    if self.use_sign_function:
+      clip_min = K.cast_to_floatx(-0.5)
+      clip_max = K.cast_to_floatx(0.5)
+    else:
+      unsigned_bits_po2 = K.pow(2.0, self.bits - self.keep_negative)
+      # if symmetric, clip_min is negative of clip_max. Otherwise clip_min is
+      # lowered by 1, giving us one more representable number
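+      # e.g. (illustrative) for bits=8: keep_negative=True with symmetric=1
+      # gives bounds [-127, 127], while keep_negative=False gives [0, 255]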
+      clip_min = self.keep_negative * (-unsigned_bits_po2 + self.symmetric)
+      clip_max = unsigned_bits_po2 - K.cast_to_floatx(1.0)
+
+    return clip_min, clip_max
+
+  def __call__(self, x):
+    """Core quantization function"""
+
+    # Build if not already built
+    self._build()
+
+    # Data type conversion
+    x = K.cast_to_floatx(x)
+    shape = x.shape
+
+    if self.auto_alpha:
+      # get data-dependent quantization scale
+      quantization_scale = self._get_auto_quantization_scale(x)
+    else:
+      # quantization scale determined by quantizer params, not data
+      # see default_quantization_scale property for more info
+      quantization_scale = self.quantization_scale
+
+    scaled_xq = self._scale_clip_and_round(x, quantization_scale)
+    xq = scaled_xq * quantization_scale
+
+    res = x + self.qnoise_factor * (xq - x)
+    res.set_shape(shape)
+
+    return res
+
+  def _scale_clip_and_round(self, x, quantization_scale):
+    """Scale, clip, and round x to an integer value in a limited range.
+
+    Note that the internal shift is needed for 1-bit quantization to ensure
+    that a sign function is used. Otherwise, the binary quantizer would have
+    three output values."""
+
+    # special shifting needed to compute a sign function.
+    shift = self.use_sign_function * 0.5
+
+    clip_min, clip_max = self.get_clip_bounds()
+
+    scaled_x = x / quantization_scale
+    clipped_scaled_x = K.clip(scaled_x, clip_min, clip_max)
+    # Round through to nearest integer, using straight-through estimator
+    # for gradient computations.
+    scaled_xq = _round_through(
+        clipped_scaled_x - shift,
+        use_stochastic_rounding=self.use_stochastic_rounding,
+        precision=1.0,  # using 1.0 precision so that we round to a nearby integer
+    )
+
+    return scaled_xq + shift
+
+  def _get_auto_quantization_scale(self, x):
+    """Get quantization_scale, either from self or from input x"""
+
+    # Get the minimum floating point scale that does not clip the max of x
+    # This is the quantization scale for alpha="auto"
+    quantization_scale = self._get_quantization_scale_from_max_data(x)
+
+    if self.alpha == "auto_po2":
+      quantization_scale = self._po2_autoscale(x, quantization_scale)
+
+    # update quantization_scale variable
+    # stop_gradient on quantization_scale to ignore dependence on x
+    self.quantization_scale = tf.stop_gradient(quantization_scale)
+
+    # very important that return value is a tf.Variable with shape None
+    return self.quantization_scale
+
+  def _get_quantization_scale_from_max_data(self, x):
+    """Get the minimum floating point scale that does not clip the max
+    of x"""
+
+    axis = _get_scaling_axis(self.scale_axis, tf.rank(x))
+
+    clip_min, clip_max = self.get_clip_bounds()
+    clip_range = clip_max - clip_min
+
+    # get quantization scale - depends on whether we are keeping negative
+    # divide by clip range to ensure that we clip right at the max of x
+    if self.keep_negative:
+      data_max = K.max(tf.math.abs(x), axis=axis, keepdims=True)
+      quantization_scale = (data_max * 2) / clip_range
+    else:
+      data_max = K.max(x, axis=axis, keepdims=True)
+      quantization_scale = data_max / clip_range
+
+    return tf.math.maximum(quantization_scale, K.epsilon())
+
+  def _po2_autoscale(self, x, quantization_scale):
+    """Get an approximation of the "best" po2 scale using least squares"""
+
+    # set alpha scale to a near power of two
+    quantization_scale = K.pow(2.0,
+        tf.math.round(K.log(quantization_scale + K.epsilon()) /
+                      K.log(2.0)))
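+
+    # The while loop below alternates between quantizing x with the current
+    # scale and re-fitting a power-of-two scale to the quantized values,
+    # stopping early once the scale stops changing (see loop_cond).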
+
+    def loop_body(_, quantization_scale):
+      """Loop body for least squares autoscaling"""
+
+      scaled_xq = self._scale_clip_and_round(x, quantization_scale)
+      new_quantization_scale = _get_least_squares_scale(
+          alpha="auto_po2",
+          x=x,
+          q=scaled_xq,
+          scale_axis=self.scale_axis,
+      )
+      return quantization_scale, new_quantization_scale
+
+    def loop_cond(last_quantization_scale, quantization_scale):
+      """Loop condition for least squares autoscaling - stop when the
+      scale converges"""
+
+      tensors_not_equal = tf.math.reduce_any(
+          tf.not_equal(last_quantization_scale, quantization_scale))
+      return tensors_not_equal
+
+    # Need a tensor of the same shape as quantization_scale that
+    # does not equal quantization_scale
+    dummy_quantization_scale = -tf.ones_like(quantization_scale)
+
+    # For 1-bit quantization, po2 autoscale loop is guaranteed to converge
+    # after 1 iteration
+    max_iterations = 1 if self.use_sign_function else 5
+
+    _, quantization_scale = tf.while_loop(
+        loop_cond,
+        loop_body,
+        (dummy_quantization_scale, quantization_scale),
+        maximum_iterations=max_iterations,
+    )
+
+    return quantization_scale
+
+  def _build(self):
+    """Build if not done so already"""
+
+    if not self.built:
+      self.build(var_name=self.var_name, use_variables=self.use_variables)
+
+  def max(self):
+    """Get maximum value that quantized_linear class can represent."""
+    _, clip_max = self.get_clip_bounds()
+    return clip_max * self.quantization_scale
+
+  def min(self):
+    """Get minimum value that quantized_linear class can represent."""
+    clip_min, _ = self.get_clip_bounds()
+    return clip_min * self.quantization_scale
+
+  def range(self):
+    """Returns a list of all values that quantized_linear can represent."""
+
+    if self.use_sign_function:
+      return K.cast_to_floatx([self.max(), self.min()])
+    else:
+      clip_min, clip_max = self.get_clip_bounds()
+      clip_max = tf.cast(clip_max, tf.int32)
+      clip_min = tf.cast(clip_min, tf.int32)
+      pos_array = K.cast_to_floatx(tf.range(clip_max + 1))
+      neg_array = K.cast_to_floatx(tf.range(clip_min, 0))
+
+      return self.quantization_scale * tf.concat([pos_array, neg_array], axis=0)
+
+  def __str__(self):
+
+    # Main parameters always printed in string
+    flags = [
+        str(int(self.bits)),
+        str(int(self.integer)),
+        str(int(self.symmetric))]
+    # Optional parameters only printed if not default
+    if not self.keep_negative:
+      flags.append("keep_negative=False")
+    if self.auto_alpha:
+      alpha = "'" + self.alpha + "'"
+      flags.append("alpha=" + alpha)
+    elif self.alpha is not None:
+      alpha = np.array(self.alpha)
+      flags.append("alpha=" + str(alpha))
+    if self.use_stochastic_rounding:
+      flags.append("use_stochastic_rounding=" +
+                   str(int(self.use_stochastic_rounding)))
+    return "quantized_linear(" + ",".join(flags) + ")"
+
+  def _set_trainable_parameter(self):
+    if self.alpha is None:
+      self.alpha = "auto_po2"
+      self.symmetric = True
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)
+
+  def get_config(self):
+
+    config = {
+        "bits": self.bits,
+        "integer": self.integer,
+        "symmetric": self.symmetric,
+        "alpha": self.alpha,
+        "keep_negative": self.keep_negative,
+        "use_stochastic_rounding": self.use_stochastic_rounding,
+        "qnoise_factor": self.qnoise_factor,
+    }
+    return config
+
+
 class quantized_bits(BaseQuantizer):  # pylint: disable=invalid-name
-  """Quantizes the number to a number of bits.
+  """Legacy quantizer: Quantizes the number to a number of bits.

   In general, we want to use a quantization function like:
@@ -836,6 +1327,7 @@ def __init__(self,
                min_po2_exponent=None,
                max_po2_exponent=None):
     super(quantized_bits, self).__init__()
+
     self.bits = bits
     self.integer = integer
     self.symmetric = symmetric
@@ -937,7 +1429,7 @@ def __call__(self, x):
         mask = v < levels / 2
         z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * levels / 2)
         print(idx, self.min_po2_exponent, self.max_po2_exponent, m)
-        scale = _get_scale(alpha="auto_po2", x=x, q=z,
+        scale = _get_least_squares_scale(alpha="auto_po2", x=x, q=z,
                            scale_axis=self.scale_axis,
                            elements_per_scale=self.elements_per_scale,
                            min_po2_exponent=self.min_po2_exponent,
@@ -1152,7 +1644,7 @@ def __call__(self, x):

     # if we use non stochastic binary to compute alpha,
     # this function seems to behave better
-    scale = _get_scale(self.alpha, x, q_non_stochastic)
+    scale = _get_least_squares_scale(self.alpha, x, q_non_stochastic)
     self.scale = scale
     return x + tf.stop_gradient(-x + scale * q)

@@ -1281,7 +1773,7 @@ def __call__(self, x):
           use_stochastic_rounding=self.use_stochastic_rounding,
           precision=1. / 3.)
       q = K.cast(tf.abs(v) >= thres, K.floatx()) * tf.sign(x)
-      scale = _get_scale(self.alpha, x, q)
+      scale = _get_least_squares_scale(self.alpha, x, q)
     else:
       if self.threshold is None:
         thres = self.default_threshold
@@ -1419,7 +1911,7 @@ def stochastic_output():
       for _ in range(self.number_of_unrolls):
         T = scale / 2.0
         q_ns = K.cast(tf.abs(x) >= T, K.floatx()) * K.sign(x)
-        scale = _get_scale(self.alpha, x, q_ns)
+        scale = _get_least_squares_scale(self.alpha, x, q_ns)

         x_norm = x / (x_std + K.epsilon())
         T = scale / (2.0 * (x_std + K.epsilon()))
@@ -1635,7 +2127,7 @@ def __call__(self, x):

     if self.alpha is None:
       x = K.tanh(x)
-    self.scale = _get_scale(
+    self.scale = _get_least_squares_scale(
         self.alpha,
         x,
         k_sign,
@@ -1745,7 +2237,7 @@ def stochastic_output():
       q += (1.0 - tf.abs(q))
       q_non_stochastic = tf.sign(x)
       q_non_stochastic += (1.0 - tf.abs(q_non_stochastic))
-      scale = _get_scale(self.alpha, x, q_non_stochastic)
+      scale = _get_least_squares_scale(self.alpha, x, q_non_stochastic)
      self.scale = scale
      return x + tf.stop_gradient(-x + scale * q)

@@ -2648,6 +3140,7 @@ def get_config(self):

 class quantized_hswish(quantized_bits):  # pylint: disable=invalid-name
   """Computes a quantized hard swish to a number of bits.
+  # TODO(mschoenb97): Update to inherit from quantized_linear.

   Equation of h-swisth function in mobilenet v3:
   hswish(x) = x * ReluY(x + relu_shift) / Y
@@ -2696,7 +3189,6 @@ def __init__(self,
                scale_axis=None,
                qnoise_factor=1.0,
                var_name=None,
-               use_ste=True,
                use_variables=False,
                relu_shift: int = 3,
                relu_upper_bound: int = 6):
@@ -2710,7 +3202,6 @@ def __init__(self,
         scale_axis=scale_axis,
         qnoise_factor=qnoise_factor,
         var_name=var_name,
-        use_ste=use_ste,
         use_variables=use_variables)

     self.relu_shift = relu_shift
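
A minimal usage sketch for the new quantizer (illustrative, not part of the
patch; assumes qkeras with this change applied and TensorFlow installed):

    import tensorflow as tf
    from qkeras import QDense, quantized_linear

    layer = QDense(
        16,
        kernel_quantizer=quantized_linear(8, 3, alpha="auto_po2"),
        bias_quantizer=quantized_linear(8, 3),
    )
    y = layer(tf.random.uniform((4, 32)))  # kernel quantized during the call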