diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py index 0a3b03237a4..91b22f67ced 100644 --- a/docs/api/source/conf.py +++ b/docs/api/source/conf.py @@ -145,6 +145,7 @@ def collect_api_entities() -> APIInfo: "nncf.tensor.functions.torch_linalg", "nncf.tensor.functions.torch_io", "nncf.tensor.functions.numpy_io", + "nncf.tensor.functions.ov_numeric", ] with mock(mock_modules): diff --git a/nncf/common/logging/logger.py b/nncf/common/logging/logger.py index 36b4e24e596..306d0a8ca6e 100644 --- a/nncf/common/logging/logger.py +++ b/nncf/common/logging/logger.py @@ -11,6 +11,7 @@ import logging import sys +from functools import lru_cache from typing import Set NNCF_LOGGER_NAME = "nncf" @@ -79,3 +80,13 @@ def warn_bkc_version_mismatch(backend: str, bkc_version: str, current_version: s f"while current {backend} version is {current_version}. " f"If you encounter issues, consider switching to {backend}{bkc_version}" ) + + +@lru_cache(None) +def log_once(level: int, message: str) -> None: + """ + Logs a message only once. + :param level: Logging level, e.g. logging.WARNING. + :param message: The message to log. + """ + nncf_logger.log(level, message) diff --git a/nncf/common/utils/decorators.py b/nncf/common/utils/decorators.py index ef91b360c27..bf644b52931 100644 --- a/nncf/common/utils/decorators.py +++ b/nncf/common/utils/decorators.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from importlib import import_module from typing import Any, Callable, Dict, List @@ -51,3 +52,62 @@ def wrapped_f(*args: Any, **kwargs: Any): # type: ignore return wrapped_f return wrap + + +class ResultsCacheContainer: + """ + A container for results decorated with @cache_results decorator. + """ + + def __init__(self) -> None: + # Stores the results of the decorated function + self._cache: Dict[Any, Any] = {} + # Stores the number of times the cached result was accessed + self._access_count: Dict[Any, int] = {} + + def clear(self) -> None: + self._cache.clear() + self._access_count.clear() + + def is_empty(self) -> bool: + return len(self._cache) == 0 + + def __getitem__(self, item: Any) -> Any: + self._access_count[item] += 1 + return self._cache[item] + + def __setitem__(self, key: Any, value: Any) -> None: + self._access_count[key] = 0 + self._cache[key] = value + + def __contains__(self, item: Any) -> bool: + return item in self._cache + + +def cache_results(cache: ResultsCacheContainer) -> Callable: # type: ignore + """ + Decorator to cache the results of a function. + + Decorated function additionally accepts a `disable_caching` argument do disable caching if needed. If it is True, + the result will not be stored saved to a cache. Also, if there is a corresponding result in the cache, it will be + recomputed. + :param cache: A cache container where results will be stored. 
+ """ + + def decorator(func: Callable) -> Callable: # type: ignore + def wrapper(*args, disable_caching: bool = False, **kwargs) -> Any: # type: ignore + if disable_caching: + return func(*args, **kwargs) + sig = inspect.signature(func) + new_kwargs = {name: arg for name, arg in zip(sig.parameters, args)} + new_kwargs.update(kwargs) + cache_key = (func.__name__, frozenset(new_kwargs.items())) + if cache_key in cache: + return cache[cache_key] + result = func(*args, **kwargs) + cache[cache_key] = result + return result + + return wrapper + + return decorator diff --git a/nncf/import_utils.py b/nncf/import_utils.py new file mode 100644 index 00000000000..2e97e1d58fa --- /dev/null +++ b/nncf/import_utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + +_openvino_available = importlib.util.find_spec("openvino") is not None +_openvino_version = "N/A" +if _openvino_available: + try: + from openvino.runtime import get_version # type: ignore + + version = get_version() + # avoid invalid format + if "-" in version: + ov_major_version, dev_info = version.split("-", 1) + commit_id = dev_info.split("-")[0] + version = f"{ov_major_version}-{commit_id}" + _openvino_version = version + except ImportError: + _openvino_available = False + + +def is_openvino_available() -> bool: + """ + Check if OpenVINO is available. + :return: True if openvino package is installed, False otherwise. + """ + return _openvino_available diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py index 470c0a23c3b..2947c1418e5 100644 --- a/nncf/openvino/graph/node_utils.py +++ b/nncf/openvino/graph/node_utils.py @@ -13,6 +13,7 @@ import numpy as np import openvino.runtime as ov +import openvino.runtime.op as op import openvino.runtime.opset13 as opset import nncf @@ -41,6 +42,8 @@ from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype +from nncf.tensor import Tensor +from nncf.tensor import TensorBackend InplaceInsertionFnType = Callable[[ov.Node, int, str], ov.Node] @@ -97,26 +100,27 @@ def get_number_if_op(model: ov.Model) -> int: """ def cnt_if_op(model: ov.Model, cnt: int) -> int: - for op in model.get_ops(): - if get_node_metatype(op) == OVIfMetatype: + for model_op in model.get_ops(): + if get_node_metatype(model_op) == OVIfMetatype: cnt += 1 - cnt = cnt_if_op(op.get_function(0), cnt) - cnt = cnt_if_op(op.get_function(1), cnt) + cnt = cnt_if_op(model_op.get_function(0), cnt) + cnt = cnt_if_op(model_op.get_function(1), cnt) return cnt return cnt_if_op(model, 0) -def get_const_value(const_node: ov.Node) -> np.ndarray: +def get_const_value(const_node: ov.Node, cast_bf16_to_fp32: Optional[bool] = True) -> np.ndarray: """ Returns the constant tensor for the node. This method is applicable only for the floating-point constant data. 
:param const_node: OpenVINO node. + :param cast_bf16_to_fp32: Whether to cast bf16 node data to fp32 or not. If False and the node contains bf16 data, + the resulting bf16 value will be returned encoded inside a numpy.float16 array. :return: The constant value. """ - if const_node.get_element_type() == ov.Type.bf16: - # Fixed FP32 data type as the result for BF16 constant + if const_node.get_element_type() == ov.Type.bf16 and cast_bf16_to_fp32: return const_node.get_data(dtype=np.float32) return const_node.data @@ -631,3 +635,42 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple channel_axis = activations_layout.index(OVLayoutElem.C_IN) return channel_axis + + +def convert_op(node: ov.Node, target_dtype: ov.Type) -> ov.Node: + """ + Return a subgraph which converts the given node output to the target data type. If the output is already in the + target data type then the given node is returned. + + :param node: The input node to convert. + :param target_dtype: The target data type to convert the input node to. + :return: The converted node. + """ + if node.get_element_type() == target_dtype: + return node + return opset.convert(node, target_dtype) + + +def non_convertable_divide_op(a: ov.Node, b: ov.Node) -> ov.Node: + """ + Creates a "non-convertable" divide operation. It won't be converted to a*(1/b). + """ + divide_node = a / b + divide_node.get_rt_info()["nonconvertable_divide_0"] = True + return divide_node + + +def create_ov_const_from_tensor(x: Tensor, dtype: ov.Type, name: Optional[str] = None) -> op.Constant: + """ + Create an OpenVINO Constant node from the given tensor. + :param x: Data tensor. Supports NumPy and OV tensor backends. If x backend is OV, the constant node is created + directly from underlying OV tensor. + :param dtype: Data type of the constant. + :param name: Optional name of the constant. + :return: OpenVINO Constant node. 
+ """ + if x.backend == TensorBackend.ov: + assert x.data.get_element_type() == dtype + return opset.constant(x.data, name=name, shared_memory=True) + const = opset.constant(x.data, dtype=dtype, name=name) + return const diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index e3cb0bea28f..67ce43d3279 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -31,10 +31,9 @@ from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight from nncf.quantization.passes import transform_to_inference_graph from nncf.tensor import TensorDataType from nncf.tensor import functions as fns @@ -262,10 +261,9 @@ def apply( g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale) g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale) else: - g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization( - weights_to_fake_quantize, reduction_axis, awq_config + g_decompressed_weighs = quantize_dequantize_weight( + weights_to_fake_quantize, awq_config, reduction_axis ) - g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp) sacts = gacts / fns.unsqueeze(cur_scale, 1) cur_out = fns.matmul(g_decompressed_weighs, sacts) diff --git a/nncf/quantization/algorithms/weight_compression/config.py b/nncf/quantization/algorithms/weight_compression/config.py index c395b17a3b9..6995e31e58f 100644 --- a/nncf/quantization/algorithms/weight_compression/config.py +++ b/nncf/quantization/algorithms/weight_compression/config.py @@ -40,12 +40,23 @@ def num_bits(self): """ return 8 if self.mode in [CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM] else 4 + @property + def is_asym_mode(self): + return self.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT8_ASYM] + + @property def is_integer(self): """ :return: True if compression type in integer, else False. 
""" return self.mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] + def __hash__(self): + return hash((self.mode.value, self.group_size)) + + def __str__(self): + return f"{self.mode.value}_{self.group_size}" + @dataclass class WeightCompressionParameters: diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 44637b17629..7f6255d120b 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -267,7 +267,6 @@ def _quantize_weights( activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations) scale, zero_point = ScaleEstimation.calculate_quantization_params( - self._backend_entity, wc_statistics, weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes, diff --git a/nncf/quantization/algorithms/weight_compression/lora_correction.py b/nncf/quantization/algorithms/weight_compression/lora_correction.py index efb7896da44..2255ff30027 100644 --- a/nncf/quantization/algorithms/weight_compression/lora_correction.py +++ b/nncf/quantization/algorithms/weight_compression/lora_correction.py @@ -24,6 +24,7 @@ from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.weight_lowering import CompressedWeight from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization @@ -105,7 +106,7 @@ def is_applicable(self, wc_params: WeightCompressionParameters): return wc_params.compression_config.num_bits == 4 def calculate_adapters( - self, weight: Tensor, compressed_weight: Tensor, wc_params: WeightCompressionParameters + self, weight: Tensor, compressed_weight: CompressedWeight, wc_params: WeightCompressionParameters ) -> Tuple[Tensor, Tensor, List[float]]: """ Calculates low rank matrices for a given original and compressed weights. 
@@ -134,7 +135,7 @@ def calculate_adapters( @staticmethod def calculate_low_rank_matrices( weight: Tensor, - compressed_weight: Tensor, + compressed_weight: CompressedWeight, compression_config: WeightCompressionConfig, reduction_axes: Tuple[int, ...], lora_correction_params: AdvancedLoraCorrectionParameters, diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index a58c408217f..50ce2bf24a8 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -354,7 +354,7 @@ def _calc_weight_sensitivity( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - compressed_weights, scale, zero_point = do_int_quantization(weight, reduction_axes, backup_config) + compressed_weights, scale, zero_point = do_int_quantization(weight, backup_config, reduction_axes) decompressed_weight = do_int_dequantization(compressed_weights, scale, zero_point) decompressed_weight = decompressed_weight.reshape(orig_shape) return fns.linalg.norm(decompressed_weight - weight, ord="fro").item() diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 0dd54927b50..39c22993466 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -11,8 +11,6 @@ from typing import Dict, Iterable, List, Optional, Tuple import openvino as ov -from openvino import Type -from openvino.properties.hint import inference_precision from openvino.runtime import opset13 as opset import nncf @@ -31,6 +29,8 @@ from nncf.openvino.graph.metatypes import openvino_metatypes as om from nncf.openvino.graph.metatypes.groups import ATOMIC_ACTIVATIONS_OPERATIONS from nncf.openvino.graph.model_transformer import OVModelTransformer +from nncf.openvino.graph.node_utils import convert_op +from nncf.openvino.graph.node_utils import create_ov_const_from_tensor from nncf.openvino.graph.node_utils import get_const_value from nncf.openvino.graph.node_utils import get_weight_channel_axes from nncf.openvino.graph.transformations.command_creation import OVCommandCreator @@ -49,9 +49,13 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm +from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters +from nncf.quantization.algorithms.weight_compression.openvino_modeling import clear_ov_model_cache from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight from nncf.tensor import Tensor +from nncf.tensor.definitions import TensorBackend from nncf.tensor.definitions import TensorDataType +from nncf.tensor.functions.ov_numeric import DTYPE_MAP_REV class OVWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): @@ -130,17 +134,8 @@ def get_weight_dtype( self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph ) -> TensorDataType: ov_type_name = node_with_weight.layer_attributes.constant_attributes[weight_port_id]["dtype"] - dtype_map = { - "f16": TensorDataType.float16, - "bf16": TensorDataType.bfloat16, - "f32": 
TensorDataType.float32, - "f64": TensorDataType.float64, - "i8": TensorDataType.int8, - "i32": TensorDataType.int32, - "i64": TensorDataType.int64, - "u8": TensorDataType.uint8, - } - return dtype_map.get(ov_type_name) + ov_type = getattr(ov.Type, ov_type_name) + return DTYPE_MAP_REV[ov_type] @staticmethod def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Tuple: @@ -244,24 +239,31 @@ def _create_compression_subgraph( raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.") original_shape = weight.shape - compressed_weight = compress_weight(weight, reduction_axes, compression_config, layer_scales, layer_zero_points) + compressed_weight = compress_weight( + weight, + reduction_axes, + compression_config, + layer_scales, + layer_zero_points, + OVModelParameters(recompile=True, release_memory=False), + ) - compressed_const = opset.constant(compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name) + compressed_const = create_ov_const_from_tensor( + compressed_weight.tensor, compression_dtype, name=const_node_name + ) converted_const = opset.convert(compressed_const, ov.Type.f16) - if compressed_weight.zero_point is not None and compressed_weight.tensor.dtype == TensorDataType.uint8: - zero_point_const = opset.constant( - compressed_weight.zero_point.data, - dtype=compression_dtype, - name=f"{const_node_name}/zero_point", + + if compressed_weight.zero_point is not None: + zero_point_const = create_ov_const_from_tensor( + compressed_weight.zero_point, compression_dtype, name=f"{const_node_name}/zero_point" ) - converted_zero_point = opset.convert(zero_point_const, ov.Type.f16) + zero_point_const = opset.convert(zero_point_const, ov.Type.f16) converted_const = opset.subtract( - converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract" + converted_const, zero_point_const, name=f"{const_node_name}/zero_point/subtract" ) - scale_const = opset.constant(compressed_weight.scale.data, dtype=scale_dtype, name=f"{const_node_name}/scale") - if scale_dtype != ov.Type.f16: - scale_const = opset.convert(scale_const, ov.Type.f16) + scale_const = create_ov_const_from_tensor(compressed_weight.scale, scale_dtype, name=f"{const_node_name}/scale") + scale_const = convert_op(scale_const, ov.Type.f16) mul = opset.multiply( converted_const, @@ -291,7 +293,12 @@ def transform_model( const_node = self.name_to_node_mapping[const_node_name] const_node_output = const_node.output(0) const_dtype = const_node_output.get_element_type() - weight = Tensor(get_const_value(const_node)) + weight = get_const_value(const_node, cast_bf16_to_fp32=False) + # Creation of ov.Tensor is required for two reasons: + # 1. To be able to process BF16 weight properly + # 2. 
To indicate that it is allowed for the compressed constant to be returned as int4/uint4 if needed + weight = ov.Tensor(weight, weight.shape, const_dtype) + weight = Tensor(weight) should_add_convert_node = False if const_dtype != ov.Type.f16: @@ -319,8 +326,16 @@ def transform_model( mul_output = mul.output(0) for target_input in const_node.output(0).get_target_inputs(): target_input.replace_source_output(mul_output) - if lora_correction_algo is not None and lora_correction_algo.is_applicable(wc_params): + if weight.backend == TensorBackend.ov: + weight = weight.as_numpy_tensor() + if compressed_weight.tensor.backend == TensorBackend.ov: + compressed_weight.tensor = compressed_weight.tensor.as_numpy_tensor() + if ( + compressed_weight.zero_point is not None + and compressed_weight.zero_point.backend == TensorBackend.ov + ): + compressed_weight.zero_point = compressed_weight.zero_point.as_numpy_tensor() adapters = lora_correction_algo.calculate_adapters(weight, compressed_weight, wc_params) self.insert_adapters(wc_params, *adapters, int8_lora=lora_correction_algo.use_int8_adapters) @@ -335,57 +350,8 @@ def dump_parameters( ) -> None: dump_parameters(model, parameters, algo_name, path) - @staticmethod - def get_compress_decompress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None): - parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline( - config, w_shape, s_shape, z_p_shape, True - ) - - if len(parameters) == 3: - _, s, zp = parameters - result = (clamp - zp) * s - else: - s = parameters[1] - result = clamp * s - - model = ov.Model([result], parameters) - - compiled_model = ov.compile_model(model, device_name="CPU", config={inference_precision: Type.f32}) - - return lambda parameters: compiled_model(parameters)[0] - - @staticmethod - def get_compress_pipeline(config: WeightCompressionConfig, w_shape, s_shape, z_p_shape=None, return_nodes=False): - mode = config.mode - assert mode in [ - CompressWeightsMode.INT4_SYM, - CompressWeightsMode.INT4_ASYM, - ], f"Only int4 supported, but given={mode}" - num_bits = config.num_bits - - asym_quant = mode in [CompressWeightsMode.INT4_ASYM] - level_low = 0 if asym_quant else -(2 ** (num_bits - 1)) - level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1 - - w = opset.parameter(w_shape, name="w") - s = opset.parameter(s_shape, name="s") - parameters = [w, s] - compressed_w = w / s - if z_p_shape is not None: - zp = opset.parameter(z_p_shape, name="zp") - parameters.append(zp) - compressed_w += zp - - result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights") - - if return_nodes: - return parameters, result - - model = ov.Model([result], parameters) - - compiled_model = ov.compile_model(model, device_name="CPU", config={inference_precision: Type.f32}) - - return lambda parameters: compiled_model(parameters)[0] + def __del__(self): + clear_ov_model_cache() class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend): diff --git a/nncf/quantization/algorithms/weight_compression/openvino_modeling.py b/nncf/quantization/algorithms/weight_compression/openvino_modeling.py new file mode 100644 index 00000000000..e2044609ab7 --- /dev/null +++ b/nncf/quantization/algorithms/weight_compression/openvino_modeling.py @@ -0,0 +1,517 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import openvino as ov +from openvino._pyopenvino.op import Parameter +from openvino._pyopenvino.properties.hint import inference_precision +from openvino.runtime import Node +from openvino.runtime import opset13 as opset + +from nncf.common.utils.decorators import ResultsCacheContainer +from nncf.common.utils.decorators import cache_results +from nncf.openvino.graph.node_utils import convert_op +from nncf.openvino.graph.node_utils import non_convertable_divide_op +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.functions.ov_numeric import DTYPE_MAP as DTYPE_MAP_OV + +TensorList = List[Tensor] +ModelCallable = Callable[[TensorList], TensorList] +ReductionAxes = Union[int, Tuple[int, ...]] + + +OV_MODEL_CACHE = ResultsCacheContainer() + + +@dataclass(init=False) +class OVModelParameters: + """ + A class to hold parameters for building and inferring an OpenVINO model. + """ + + def __init__( + self, + input_dtypes: Optional[Dict[str, TensorDataType]] = None, + output_dtypes: Optional[Dict[str, TensorDataType]] = None, + dynamic_shapes: bool = True, + recompile: bool = False, + release_memory: bool = True, + share_inputs: bool = True, + share_outputs: bool = True, + return_ov_tensors: bool = False, + convertable_division: bool = False, + ): + """ + :param input_dtypes: Optional dictionary mapping input names to their data types. + :param output_dtypes: Optional dictionary mapping output names to their data types. + :param dynamic_shapes: Whether to use dynamic shapes for the model. When dynamic shapes are used and + recompile is False, it allows to save on the number of models stored in a model cache. + :param recompile: Whether to recompile the model before every inference. Otherwise, compiled models are cached. + :param release_memory: Whether to release memory after every inference. If memory is released, it will be + reallocated during every inference, reducing performance to some extent. + :param share_inputs: Whether to share input tensors. Avoids cloning inputs for inference. + :param share_outputs: Whether to share output tensors. Avoids cloning outputs after the inference. + :param return_ov_tensors: Whether to return results as OpenVINO tensors or NumPy arrays. + :param convertable_division: Whether to use convertable division for division operations. If True, division a/b + will be transformed at runtime to a*(1/b). 
+ """ + self.input_dtypes = input_dtypes or {} + self.output_dtypes = output_dtypes or {} + self.dynamic_shapes = dynamic_shapes + self.recompile = recompile + self.release_memory = release_memory + self.share_inputs = share_inputs + self.share_outputs = share_outputs + self.return_ov_tensors = return_ov_tensors + self.convertable_division = convertable_division + + def __copy__(self): + return OVModelParameters( + input_dtypes=self.input_dtypes.copy(), + output_dtypes=self.output_dtypes.copy(), + dynamic_shapes=self.dynamic_shapes, + recompile=self.recompile, + release_memory=self.release_memory, + share_inputs=self.share_inputs, + share_outputs=self.share_outputs, + return_ov_tensors=self.return_ov_tensors, + convertable_division=self.convertable_division, + ) + + def __deepcopy__(self, memo): + return OVModelParameters( + input_dtypes=copy.deepcopy(self.input_dtypes, memo), + output_dtypes=copy.deepcopy(self.output_dtypes, memo), + dynamic_shapes=self.dynamic_shapes, + recompile=self.recompile, + release_memory=self.release_memory, + share_inputs=self.share_inputs, + share_outputs=self.share_outputs, + return_ov_tensors=self.return_ov_tensors, + convertable_division=self.convertable_division, + ) + + def __hash__(self): + return hash( + ( + frozenset(self.input_dtypes.items()), + frozenset(self.output_dtypes.items()), + self.dynamic_shapes, + self.recompile, + self.release_memory, + self.share_inputs, + self.share_outputs, + self.return_ov_tensors, + self.convertable_division, + ) + ) + + +ModelAsNodes = Tuple[List[Parameter], List[Node], OVModelParameters] + + +def clear_ov_model_cache(): + OV_MODEL_CACHE.clear() + + +def _infer_ov_model( + ov_model_params: OVModelParameters, compiled_model: ov.CompiledModel, inputs: TensorList +) -> TensorList: + """ + Run compiled OpenVINO model inference on the given inputs. + :param ov_model_params: OV model related parameters. + :param compiled_model: Compiled OpenVINO model. + :param inputs: Input tensors. + :return: List of output tensors. Tensor backend is OV if return_ov_tensors is True, else NumPy. + """ + # Check that input dtypes match the expected dtypes + for i, inp in enumerate(compiled_model.inputs): + input_name = inp.any_name + actual_dtype = inputs[i].dtype + expected_dtype = ov_model_params.input_dtypes[input_name] + if actual_dtype != expected_dtype: + raise ValueError(f"Expected input '{input_name}' to be {expected_dtype}. But found: {actual_dtype}.") + + # Infer the model + inputs = [inp.data for inp in inputs] + if ov_model_params.return_ov_tensors: + infer_request = compiled_model.create_infer_request() + infer_request.infer( + inputs, share_inputs=ov_model_params.share_inputs, share_outputs=ov_model_params.share_outputs + ) + outputs = [infer_request.get_output_tensor(i) for i in range(len(infer_request.results))] + else: + outputs = compiled_model( + inputs, share_inputs=ov_model_params.share_inputs, share_outputs=ov_model_params.share_outputs + ) + outputs = [outputs[i] for i in range(len(outputs))] + outputs = [Tensor(it) for it in outputs] + + if ov_model_params.release_memory: + compiled_model.release_memory() + + return outputs + + +def _prepare_compression_model_inputs( + ov_model_params, + weight_shape: Tuple, + scale_shape: Optional[Tuple], + zero_point_shape: Optional[Tuple], + reduction_axes: Optional[ReductionAxes], +) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]: + """ + Do some input checks and convert static shapes to dynamic shapes if needed. 
+ """ + if scale_shape is None and zero_point_shape is not None: + raise Exception("Zero point shape can only be provided if scale shape is provided.") + if scale_shape is None and reduction_axes is None: + raise ValueError("Reduction axes must be provided if scale shape is not provided.") + + # Set dynamic shapes if needed + if ov_model_params.dynamic_shapes: + weight_shape = (-1,) * len(weight_shape) + if scale_shape is not None: + scale_shape = (-1,) * (len(scale_shape) - 1) + (1,) + if zero_point_shape is not None: + zero_point_shape = (-1,) * (len(zero_point_shape) - 1) + (1,) + + return weight_shape, scale_shape, zero_point_shape + + +def get_compress_weight_model( + ov_model_params: OVModelParameters, + config: WeightCompressionConfig, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + zero_point_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_nodes: Optional[bool] = False, +) -> Union[ModelCallable, ModelAsNodes]: + """ + Get a model that compresses weights using the given configuration. + :param ov_model_params: OV model parameters. + :param config: Compression configuration. + :param weight_shape: Shape of the weight to compress. Weight is assumed to be already reshaped as needed. + :param scale_shape: Optional shape of the scale. If not provided, scale will be computed by the OV model. + Otherwise, it is expected that the scale tensor is given as an input to the model. + :param zero_point_shape: Optional shape of the zero point tensor. If not provided and the mode is asymmetric, + zero point will be computed by the OV model. Otherwise, it is expected that the zero point tensor is provided + as an input. + :param reduction_axes: Optional axes to reduce the weight tensor. Not needed if scale (and z.p.) are provided as + inputs. + :param return_nodes: Whether to return the OV model inputs parameters and results nodes instead of the model + callable. + :return: A model callable that compresses weights using the given configuration. Or a model as nodes, if + `return_nodes` is True. + """ + + weight_shape, scale_shape, zero_point_shape = _prepare_compression_model_inputs( + ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes + ) + + return _build_compress_model( + config, + ov_model_params, + weight_shape, + scale_shape, + zero_point_shape, + reduction_axes, + return_nodes=return_nodes, + disable_caching=ov_model_params.recompile, + ) + + +def get_compress_decompress_weight_model( + ov_model_params: OVModelParameters, + config: WeightCompressionConfig, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + zero_point_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_compressed_weight: Optional[bool] = False, +) -> ModelCallable: + """ + Get a model that performs compression and decompression of the given weight. + :param ov_model_params: OV model parameters. + :param config: Compression configuration. + :param weight_shape: Shape of the weight. Weight is assumed to be already reshaped as needed. + :param scale_shape: Optional shape of the scale. If not provided, scale will be computed by the OV model. + Otherwise, it is expected that the scale tensor is given as an input to the model. + :param zero_point_shape: Optional shape of the zero point tensor. If not provided and the mode is asymmetric, + zero point will be computed by the OV model. Otherwise, it is expected that the zero point is provided as an + input. 
+ :param reduction_axes: Optional axes to reduce the weight tensor. Not needed if scale (and z.p.) are provided as + inputs. + :param return_compressed_weight: Whether to also return compressed weight, scale, (and zero point) besides the + decompressed weight. + :return: A model callable that returns a decompressed weight, and optionally compressed weight, scale, + (and zero point) if `return_compressed_weight` is True. + """ + + weight_shape, scale_shape, zero_point_shape = _prepare_compression_model_inputs( + ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes + ) + + return _build_compress_decompress_model( + config, + ov_model_params, + weight_shape, + scale_shape, + zero_point_shape, + reduction_axes, + return_compressed_weight, + disable_caching=ov_model_params.recompile, + ) + + +@cache_results(OV_MODEL_CACHE) +def _build_compress_model( + config: WeightCompressionConfig, + ov_model_params: OVModelParameters, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + zero_point_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_nodes: bool = False, +) -> Union[ModelCallable, ModelAsNodes]: + is_asym_mode = config.is_asym_mode + + default_input_dtypes = { + "scale": TensorDataType.float32, + "zero_point": TensorDataType.int32, + } + default_output_dtypes = { + "compressed_weight": TensorDataType.uint8 if is_asym_mode else TensorDataType.int8, + "scale": TensorDataType.float32, + "zero_point": TensorDataType.int32, + } + + # Update input and output dtypes with the default values + ov_model_params = copy.deepcopy(ov_model_params) + ov_model_params.input_dtypes = {**default_input_dtypes, **ov_model_params.input_dtypes} + ov_model_params.output_dtypes = {**default_output_dtypes, **ov_model_params.output_dtypes} + + if "weight" not in ov_model_params.input_dtypes: + raise ValueError("Input weight dtype is required!") + + weight_dtype = ov_model_params.input_dtypes["weight"] + input_scale_dtype = ov_model_params.input_dtypes["scale"] + input_zero_point_dtype = ov_model_params.input_dtypes["zero_point"] + compressed_weight_dtype = ov_model_params.output_dtypes["compressed_weight"] + output_scale_dtype = ov_model_params.output_dtypes["scale"] + output_zero_point_dtype = ov_model_params.output_dtypes["zero_point"] + + # Validate input dtypes + valid_weight_dtypes = [TensorDataType.float32, TensorDataType.float16, TensorDataType.bfloat16] + if weight_dtype not in valid_weight_dtypes: + raise ValueError( + f"Weight must be one of the following data types: {valid_weight_dtypes}. But found: {weight_dtype}." + ) + if scale_shape is not None and input_scale_dtype != TensorDataType.float32: + raise ValueError(f"Input scale must be of float32 data type. But found: {input_scale_dtype}.") + if zero_point_shape is not None and input_zero_point_dtype not in [TensorDataType.int32, TensorDataType.float32]: + raise ValueError(f"Input zero point must be of int32/float32 data type. But found: {input_zero_point_dtype}.") + + # Validate output dtypes + valid_compressed_weight_dtypes = [ + TensorDataType.float32, + TensorDataType.int32, + TensorDataType.int8, + TensorDataType.uint8, + TensorDataType.int4, + TensorDataType.uint4, + ] + if compressed_weight_dtype not in valid_compressed_weight_dtypes: + raise ValueError( + f"Compressed weight must be one of the following data types: {valid_compressed_weight_dtypes}. " + f"But found: {compressed_weight_dtype}." 
+ ) + if scale_shape is None and output_scale_dtype != TensorDataType.float32: + raise ValueError(f"Output scale must be of float32 data type. But found: {output_scale_dtype}.") + if is_asym_mode and zero_point_shape is None and output_zero_point_dtype not in valid_compressed_weight_dtypes: + raise ValueError( + f"Output zero point must be of one of the following data types: {valid_compressed_weight_dtypes}. " + f"But found: {output_zero_point_dtype}." + ) + + # Build OV model + weight = opset.parameter(weight_shape, name="weight", dtype=DTYPE_MAP_OV[weight_dtype]) + ov_parameters = [weight] + + num_bits = config.num_bits + eps = np.finfo(np.float32).eps + level_low = 0 if is_asym_mode else -(2 ** (num_bits - 1)) + level_high = 2**num_bits - 1 if is_asym_mode else 2 ** (num_bits - 1) - 1 + + divide_op = opset.divide if ov_model_params.convertable_division else non_convertable_divide_op + + min_values = None + if scale_shape is not None: + # Scale is given as an input + scale = opset.parameter(scale_shape, name="scale", dtype=DTYPE_MAP_OV[input_scale_dtype]) + ov_parameters.append(scale) + else: + # Compute scale + if is_asym_mode: + # [a1, r, a2] -> [a1, 1, a2] + min_values = opset.reduce_min(weight, reduction_axes=reduction_axes, keep_dims=True) + max_values = opset.reduce_max(weight, reduction_axes=reduction_axes, keep_dims=True) + min_values, max_values = opset.convert(min_values, ov.Type.f32), opset.convert(max_values, ov.Type.f32) + + levels = level_high - level_low + 1 + scale = divide_op(max_values - min_values, opset.constant(levels - 1, ov.Type.f32)) + scale = opset.select(opset.abs(scale) < eps, eps, scale) + else: + w_abs_min = opset.abs(opset.reduce_min(weight, reduction_axes=reduction_axes, keep_dims=True)) + w_max = opset.reduce_max(weight, reduction_axes=reduction_axes, keep_dims=True) + w_abs_min, w_max = opset.convert(w_abs_min, ov.Type.f32), opset.convert(w_max, ov.Type.f32) + + scale = opset.select(w_abs_min >= w_max, w_abs_min, opset.negative(w_max)) + scale = divide_op(scale, opset.constant(-level_low, ov.Type.f32)) + scale = opset.select(opset.abs(scale) < eps, eps, scale) + + zero_point = None + if zero_point_shape is not None: + # Zero point is given as an input + zero_point = opset.parameter(zero_point_shape, name="zero_point", dtype=DTYPE_MAP_OV[input_zero_point_dtype]) + ov_parameters.append(zero_point) + # Cast to float32 for an addition later + zero_point = convert_op(zero_point, ov.Type.f32) + elif is_asym_mode: + # Compute zero point + scaled_min_values = divide_op(min_values, scale) + zero_point = opset.constant(level_low, ov.Type.f32) - opset.round(scaled_min_values) + zero_point = opset.clamp(zero_point, level_low, level_high) + + weight = convert_op(weight, ov.Type.f32) + compressed_weight = divide_op(weight, scale) + + if is_asym_mode: + compressed_weight += zero_point + + compressed_weight = opset.round(compressed_weight) + compressed_weight = opset.clamp(opset.round(compressed_weight), level_low, level_high) + compressed_weight = convert_op(compressed_weight, DTYPE_MAP_OV[compressed_weight_dtype]) + + ov_results = [compressed_weight] + if len(ov_parameters) == 1: + ov_results.append(scale) + if zero_point is not None: + zero_point = convert_op(zero_point, DTYPE_MAP_OV[output_zero_point_dtype]) + ov_results.append(zero_point) + + if return_nodes: + return ov_parameters, ov_results, ov_model_params + + model = ov.Model(ov_results, ov_parameters) + compiled_model = ov.compile_model(model, device_name="CPU", config={inference_precision(): ov.Type.f32}) + + 
return partial(_infer_ov_model, ov_model_params, compiled_model) + + +@cache_results(OV_MODEL_CACHE) +def _build_compress_decompress_model( + config: WeightCompressionConfig, + ov_model_params: OVModelParameters, + weight_shape: Tuple, + scale_shape: Optional[Tuple] = None, + zero_point_shape: Optional[Tuple] = None, + reduction_axes: Optional[ReductionAxes] = None, + return_compressed_weight: Optional[bool] = False, +) -> ModelCallable: + default_output_dtypes = {"decompressed_weight": TensorDataType.float32} + if not return_compressed_weight: + # If compressed weight is not returned to a user, we can keep it in float32 to avoid additional conversion + default_output_dtypes["compressed_weight"] = TensorDataType.float32 + ov_model_params = copy.deepcopy(ov_model_params) + ov_model_params.output_dtypes = {**default_output_dtypes, **ov_model_params.output_dtypes} + + decompressed_weight_dtype = ov_model_params.output_dtypes["decompressed_weight"] + if decompressed_weight_dtype != TensorDataType.float32: + raise ValueError(f"Decompressed weight must be of float32 data type. But found: {decompressed_weight_dtype}.") + + # Get compression model as input/result nodes and potentially modified ov model parameters + ov_parameters, ov_results, ov_model_params = get_compress_weight_model( + ov_model_params, config, weight_shape, scale_shape, zero_point_shape, reduction_axes, return_nodes=True + ) + + if config.is_asym_mode: + if len(ov_parameters) == 1: + # weight -> compressed_weight, scale, zero_point + compressed_weight, scale, zero_point = ov_results + else: + # weight, scale, zero_point -> compressed_weight + compressed_weight = ov_results[0] + scale, zero_point = ov_parameters[1:] + + compressed_weight = convert_op(compressed_weight, ov.Type.i32) - convert_op(zero_point, ov.Type.i32) + else: + if len(ov_parameters) == 1: + # weight -> compressed_weight, scale + compressed_weight, scale = ov_results + else: + # weight, scale -> compressed_weight + compressed_weight = ov_results[0] + scale = ov_parameters[1] + + decompressed_weight = opset.multiply(scale, convert_op(compressed_weight, ov.Type.f32)) + + ov_results = [decompressed_weight] + ov_results if return_compressed_weight else [decompressed_weight] + model = ov.Model(ov_results, ov_parameters) + compiled_model = ov.compile_model(model, device_name="CPU", config={inference_precision(): ov.Type.f32}) + + return partial(_infer_ov_model, ov_model_params, compiled_model) + + +def get_astype_model(ov_model_params: OVModelParameters, input_shape: Tuple) -> ModelCallable: + """ + Return a model that cast the input of the given shape to the given data type. Especially useful for + casting from/to data types not supported by NumPy such as bfloat16, uint4 and int4. + These data types are represented as the following data types in numpy: + - bfloat16 -> np.float16, + - uint4 -> uint8, + - int4 -> int8. + :param ov_model_params: OV model related parameters. + :param input_shape: Shape of the tensor to cast. + :return: A model callable that casts the input tensor to the given data type. 
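For reference, a NumPy transcription of what the asymmetric branch of the compression model built above computes for float32 inputs; this is an approximation for illustration, not part of the patch.

import numpy as np

def int4_asym_reference(weight: np.ndarray, reduction_axes, num_bits: int = 4):
    level_low, level_high = 0, 2**num_bits - 1
    eps = np.finfo(np.float32).eps
    min_values = weight.min(axis=reduction_axes, keepdims=True).astype(np.float32)
    max_values = weight.max(axis=reduction_axes, keepdims=True).astype(np.float32)
    scale = (max_values - min_values) / (level_high - level_low)
    scale = np.where(np.abs(scale) < eps, eps, scale)
    zero_point = np.clip(level_low - np.round(min_values / scale), level_low, level_high)
    compressed = np.clip(np.round(weight / scale + zero_point), level_low, level_high)
    return compressed.astype(np.uint8), scale, zero_point.astype(np.int32)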
+ """ + if ov_model_params.dynamic_shapes: + input_shape = (-1,) * len(input_shape) + return _build_astype_model(ov_model_params, input_shape, disable_caching=ov_model_params.recompile) + + +@cache_results(OV_MODEL_CACHE) +def _build_astype_model(ov_model_params: OVModelParameters, arg_shape: Tuple) -> ModelCallable: + input_dtypes = ov_model_params.input_dtypes + if input_dtypes is None: + raise ValueError("Input dtypes must be provided.") + output_dtypes = ov_model_params.output_dtypes + if output_dtypes is None: + raise ValueError("Output dtypes must be provided.") + if "input" not in input_dtypes: + raise ValueError("Input dtype is required.") + if "output" not in output_dtypes: + raise ValueError("Output dtype is required.") + + arg = opset.parameter(arg_shape, dtype=DTYPE_MAP_OV[input_dtypes["input"]], name="input") + res = opset.convert(arg, DTYPE_MAP_OV[output_dtypes["output"]]) + model = ov.Model([res], [arg]) + compiled_model = ov.compile_model(model, device_name="CPU", config={inference_precision(): ov.Type.f32}) + + return partial(_infer_ov_model, ov_model_params, compiled_model) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index fb13ace6341..80ba13cf428 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -21,16 +21,16 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic +from nncf.import_utils import is_openvino_available from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats -from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_normalized_weight_and_fp4_scale -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization from nncf.tensor import Tensor from nncf.tensor import TensorDataType @@ -44,8 +44,6 @@ class ScaleEstimation: Scale estimation algorithm implementation. 
""" - compress_decompress_cache = {} - def __init__( self, model: TModel, @@ -148,7 +146,6 @@ def apply( weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph) scales[weight_name], zero_points[weight_name] = self.calculate_quantization_params( - self._backend_entity, stats, weight, wp.reduction_axes, @@ -163,7 +160,6 @@ def apply( @staticmethod def calculate_quantization_params( - backend_entity: WeightCompressionAlgoBackend, statistics: WCTensorStatistic, weight: Tensor, reduction_axes: Tuple[int, ...], @@ -183,7 +179,6 @@ def calculate_quantization_params( 1. Initial scale rectification based on activation statistics. 2. A grid search to further refine the scale parameters. - :param backend_entity: The backend-specific implementation of the weight compression algorithm. :param statistics: The input activations of the layer reduced over batch and sequence length dimensions, together with original activation tensor shapes. :param weight: The weight tensor that is being quantized. @@ -219,12 +214,14 @@ def calculate_quantization_params( ) compressed_weights = do_nf4_quantization(norm_weight, scale, is_normalized_weight=True) q_weights = do_nf4_dequantization(compressed_weights, scale, reduction_axis) + q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) zp = None else: - compressed_weights, scale, zp = do_int_quantization(original_weight, reduction_axis, cur_config) + q_weights, compressed_weights, scale, zp = quantize_dequantize_weight( + original_weight, cur_config, reduction_axis, return_compressed_weight=True + ) if zp is not None: zp = zp.astype(scale.dtype) - q_weights = do_int_dequantization(compressed_weights, scale, zp, reduction_axis) s = fns.unsqueeze(s, 0) s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size) @@ -243,7 +240,6 @@ def calculate_quantization_params( importance = importance / (denum + eps) X, _ = reshape_weight_for_grouped_quantization(X, 0, group_size) - q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) best_diffs = None result_scale = None @@ -256,41 +252,35 @@ def calculate_quantization_params( if weight_penalty > 0.0: min_max_scale_diffs += weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1) - zp_shape = zp.shape if zp is not None else None - key = (config.mode, config.num_bits) + q_weights.shape + scale.shape - if zp is not None: - key += zp_shape - if config.mode != CompressWeightsMode.NF4: - if key in ScaleEstimation.compress_decompress_cache: - compress_decompress_model = ScaleEstimation.compress_decompress_cache[key]["compress_decompress_model"] - compress_model = ScaleEstimation.compress_decompress_cache[key]["compress_model"] - else: - compress_decompress_model = backend_entity.get_compress_decompress_pipeline( - config, q_weights.shape, scale.shape, zp_shape - ) - compress_model = backend_entity.get_compress_pipeline(config, q_weights.shape, scale.shape, zp_shape) - ScaleEstimation.compress_decompress_cache[key] = { - "compress_decompress_model": compress_decompress_model, - "compress_model": compress_model, - } scale_sign = scale / fns.abs(scale) zero_scale = 0.001 zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) - input_tensors = [original_weight.data, None] - if zp is not None: - input_tensors.append(zp.data) + if is_openvino_available(): + # This is required for alignment with a previous OpenVINO models implementation + from 
nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters + + ov_model_params = OVModelParameters(dynamic_shapes=False, convertable_division=True) + else: + ov_model_params = None + # iterative rectification of initial scale for i in range(initial_steps): near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) near_to_ideal_scale = near_to_ideal_scale * scale_sign - input_tensors[1] = near_to_ideal_scale.data if config.mode == CompressWeightsMode.NF4: g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale) out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale) else: - out = compress_decompress_model(input_tensors) + out = quantize_dequantize_weight( + original_weight, + config, + precomputed_scale=near_to_ideal_scale, + precomputed_zero_point=zp, + ov_model_params=ov_model_params, + ) + q_weights_ = fns.zeros_like(original_weight) + out q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) @@ -313,13 +303,18 @@ def calculate_quantization_params( else: near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale result_scale = near_to_ideal_scale - input_tensors[1] = near_to_ideal_scale.data if i < initial_steps - 1: if config.mode == CompressWeightsMode.NF4: out = do_nf4_quantization(original_weight, near_to_ideal_scale) else: - out = compress_model(input_tensors) + out, _, _ = do_int_quantization( + original_weight, + config, + precomputed_scale=near_to_ideal_scale, + precomputed_zero_point=zp, + ov_model_params=ov_model_params, + ) compressed_weights = fns.zeros_like(original_weight) + out target, zero_mask = get_target_zero_mask(compressed_weights, zp) zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) @@ -329,11 +324,16 @@ def calculate_quantization_params( factor = 1.0 - 0.05 * scale_step scaled_scale = factor * scale - input_tensors[1] = scaled_scale.data if config.mode == CompressWeightsMode.NF4: out = do_nf4_quantization(original_weight, scaled_scale) else: - out = compress_model(input_tensors) + out, _, _ = do_int_quantization( + original_weight, + config, + precomputed_scale=scaled_scale, + precomputed_zero_point=zp, + ov_model_params=ov_model_params, + ) compressed_weights = fns.zeros_like(original_weight) + out target, zero_mask = get_target_zero_mask(compressed_weights, zp) @@ -341,12 +341,17 @@ def calculate_quantization_params( near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance) near_to_ideal_scale = near_to_ideal_scale * scale_sign - input_tensors[1] = near_to_ideal_scale.data if config.mode == CompressWeightsMode.NF4: g_compressed_weighs = do_nf4_quantization(original_weight, near_to_ideal_scale) out = do_nf4_dequantization(g_compressed_weighs, near_to_ideal_scale) else: - out = compress_decompress_model(input_tensors) + out = quantize_dequantize_weight( + original_weight, + config, + precomputed_scale=near_to_ideal_scale, + precomputed_zero_point=zp, + ov_model_params=ov_model_params, + ) q_weights_ = fns.zeros_like(original_weight) + out q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 4e34e4816d6..92bacbe8cf0 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -8,21 +8,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 
express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy +import logging from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import numpy as np import nncf +from nncf.common.logging.logger import log_once +from nncf.import_utils import is_openvino_available from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.fake_quantize import calculate_scale_zero_point from nncf.tensor import Tensor from nncf.tensor import functions as fns +from nncf.tensor.definitions import TensorBackend from nncf.tensor.definitions import TensorDataType -ReductionAxes = Tuple[int, ...] +ReductionAxes = Union[int, Tuple[int, ...]] NF4_QUANTILES = np.array( [ @@ -249,7 +253,9 @@ def calculate_normalized_weight_and_fp4_scale( def calculate_integer_quantization_params( - weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig + weight: Tensor, + reduction_axes: ReductionAxes, + config: WeightCompressionConfig, ) -> Tuple[Tensor, Tensor]: """ Calculates the scale and zero point for uniform quantization (INT4, INT8), when the range of values is divided into @@ -260,14 +266,13 @@ def calculate_integer_quantization_params( :param config: Weight compression configuration. :return: Scale and zero point tensors. """ - mode = config.mode - assert config.is_integer(), "The function supports integer quantization only" + assert config.is_integer, "The function supports integer quantization only" num_bits = config.num_bits if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]: + if config.is_asym_mode: level_low = 0 level_high = 2**num_bits - 1 min_values = fns.min(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] @@ -286,7 +291,6 @@ def calculate_quantized_weight( config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None, - invert_scale=False, ) -> Tensor: """ Quantizes the weight tensor using the provided scale and zero point. @@ -295,7 +299,6 @@ def calculate_quantized_weight( :param config: Weight compression configuration. :param scale: Scale tensor used for quantization. :param zero_point: Zero point tensor used for quantization. - :param invert_scale: applies inversion for scale and then multiply by weights instead of division. :return: Quantized weight tensor of uint8 or int8 type. 
""" if weight.dtype != TensorDataType.float32: @@ -304,16 +307,12 @@ def calculate_quantized_weight( scale = scale.astype(TensorDataType.float32) num_bits = config.num_bits - asym_quant = config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM] + asym_quant = config.is_asym_mode dtype = TensorDataType.uint8 if asym_quant else TensorDataType.int8 level_low = 0 if asym_quant else -(2 ** (num_bits - 1)) level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1 - if invert_scale: - scale = fns.power(scale, -1) - compressed_weights = weight * scale - else: - compressed_weights = weight / scale + compressed_weights = weight / scale if zero_point is not None: compressed_weights += zero_point.astype(weight.dtype) compressed_weights = fns.round(compressed_weights) @@ -322,70 +321,10 @@ def calculate_quantized_weight( return compressed_weights -def do_int_quantization( +def get_integer_quantization_error( weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig, - precomputed_scale: Tensor = None, - precomputed_zero_point: Tensor = None, - invert_scale=False, -) -> Tuple[Tensor, Tensor, Tensor]: - """ - The method quantizes the given weights to integer data type uniformly in accordance with the compression config. - The config defines a quantization mode: - INT8_SYM mode refers to signed int8 symmetric weight compression without zero point - - quantization to [-128, 127] range. - INT8_ASYM mode refers to unsigned int8 asymmetric weight compression with a typical non-fixed zero-point - - quantization to [0, 255] range. - INT4_ASYM mode refers to unsigned int4 asymmetric weight compression with a typical non-fixed zero-point - - quantization to [0, 15] range. - INT4_SYM mode refers to signed int4 symmetric weight compression without zero point - - quantization to [-8, 7] range. - NF4 or E2M1 mode requires a dedicated procedure and it is not supported in this method. - One of the parameter of compression config is a group size. Quantization is per-channel, if group size equals to -1, - otherwise it's per-group, i.e. group size number of weights in the channel dimension share quantization parameters - (scales). - - :param weight: Weight array to compress. - :param reduction_axes: Axes, along which to reduce (collect) different statistics (e.g. min, max). - :param config: Information on how to compress (quantize) a specific weight. - :param precomputed_scale: Precomputed scale. - :param precomputed_zero_point: Precomputed zero point. - :param invert_scale: applies inversion for scale and then multiply by weights instead of division. - Need as reference implementation for OV. - :return: The compressed weights tensor of uint8 (asymmetric mode) or int8 (symmetric mode) type, - scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization. - """ - assert config.is_integer(), "The function supports integer quantization only" - group_size = config.group_size - is_asym = config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM] - if is_asym and (precomputed_scale is None) != (precomputed_zero_point is None): - raise ValueError( - "If precomputed quantization parameters are provided, both scale and zero point are required " - "for asymmetric quantization." 
- ) - - if weight.dtype != TensorDataType.float32: - weight = weight.astype(TensorDataType.float32) - - if group_size != -1: - # weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2] - weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, group_size) - - scale, zero_point = None, None - if precomputed_scale is None or (is_asym and precomputed_zero_point is None): - scale, zero_point = calculate_integer_quantization_params(weight, reduction_axes, config) - if precomputed_scale is not None: - scale = precomputed_scale - if precomputed_zero_point is not None: - zero_point = precomputed_zero_point - - compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point, invert_scale) - return compressed_weights, scale, zero_point - - -def get_integer_quantization_error( - weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig ) -> float: """ Calculates a quantity characterizing the difference between floating point weights and fake quantized @@ -401,8 +340,7 @@ def get_integer_quantization_error( if weight.dtype != TensorDataType.float32: weight = weight.astype(TensorDataType.float32) - compressed_weights, scale, zero_point = do_int_quantization(weight, reduction_axes, config) - decompressed_weight = do_int_dequantization(compressed_weights, scale, zero_point) + decompressed_weight = quantize_dequantize_weight(weight, config, reduction_axes) decompressed_weight = decompressed_weight.reshape(orig_shape) diff = (decompressed_weight - weight) ** 2 @@ -417,6 +355,7 @@ def compress_weight( config: WeightCompressionConfig, precomputed_scale: Tensor = None, precomputed_zero_point: Tensor = None, + ov_model_params: Optional = None, ): """ Compress weight using compression configuration. @@ -426,15 +365,19 @@ def compress_weight( :param config: Compression configuration. :param precomputed_scale: Precomputed scale. :param precomputed_zero_point: Precomputed zero point. + :param ov_model_params: OpenVINO model parameters for acceleration. :return: The compressed weight and decompression parameters as instance of CompressedWeight """ - if not config.is_integer(): + if not config.is_integer: + if weight.backend == TensorBackend.ov: + weight = weight.as_numpy_tensor() + compressed_weight, scale = calculate_normalized_weight_and_fp4_scale( weight, reduction_axes, config.group_size, precomputed_scale, config.mode ) return CompressedWeight(compressed_weight, scale) compressed_weight, scale, zero_point = do_int_quantization( - weight, reduction_axes, config, precomputed_scale, precomputed_zero_point + weight, config, reduction_axes, precomputed_scale, precomputed_zero_point, ov_model_params ) return CompressedWeight(compressed_weight, scale, zero_point) @@ -472,10 +415,206 @@ def do_int_dequantization( original shapes. If equals to -1: weights are not reshaped, assumed not a group quantization. Default to -1. :return: dequantized/decompressed weights. 
""" - decompressed_weight = compressed_weights - zero_point if zero_point is not None else compressed_weights + decompressed_weight = ( + compressed_weights.astype(TensorDataType.int32) - zero_point if zero_point is not None else compressed_weights + ) decompressed_weight = decompressed_weight.astype(scale.dtype) * scale if reduction_axis > -1: decompressed_weight = ungroup_weights(decompressed_weight, reduction_axis) return decompressed_weight + + +def do_int_quantization( + weight: Tensor, + config: WeightCompressionConfig, + reduction_axes: Optional[ReductionAxes] = None, + precomputed_scale: Tensor = None, + precomputed_zero_point: Tensor = None, + ov_model_params: Optional = None, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Performs integer quantization on the given weight tensor. + + :param weight: The weight tensor to quantize. + :param config: The weight compression configuration. + :param reduction_axes: Axes along which to reduce (collect) statistics (e.g., min, max). Not required if + precomputed scale (and zero point) are provided. + :param precomputed_scale: Optional precomputed scale tensor. + :param precomputed_zero_point: Optional precomputed zero point tensor. + :param ov_model_params: OpenVINO model parameters for acceleration. + :return: A tuple containing the compressed weights, scale, and zero point tensors. + """ + assert config.is_integer, "The function supports integer quantization only" + if config.is_asym_mode and (precomputed_scale is None) != (precomputed_zero_point is None): + raise ValueError( + "If precomputed quantization parameters are provided, both scale and zero point are required " + "for asymmetric quantization." + ) + + accelerate_through_ov = is_openvino_available() and weight.backend != TensorBackend.torch + if not is_openvino_available() and weight.backend != TensorBackend.torch: + log_once(logging.INFO, "Compression time may be improved after installing OpenVINO") + + # When reduction axes are not provided, assuming that the weights are already reshaped + if config.group_size != -1 and reduction_axes is not None: + # weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2] + weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size) + + if not accelerate_through_ov: + # Reference implementation + if weight.backend == TensorBackend.ov: + weight = weight.as_numpy_tensor() + if weight.dtype != TensorDataType.float32: + weight = weight.astype(TensorDataType.float32) + + scale, zero_point = None, None + if precomputed_scale is None or (config.is_asym_mode and precomputed_zero_point is None): + if reduction_axes is None: + raise ValueError("Reduction axes are required for computing the scale and (zero point) vectors.") + scale, zero_point = calculate_integer_quantization_params(weight, reduction_axes, config) + if precomputed_scale is not None: + scale = precomputed_scale + if precomputed_zero_point is not None: + zero_point = precomputed_zero_point + + compressed_weights = calculate_quantized_weight(weight, config, scale, zero_point) + return compressed_weights, scale, zero_point + + from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters + from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_compress_weight_model + + weight_shape = weight.shape + scale_shape = None if precomputed_scale is None else precomputed_scale.shape + zero_point_shape = None if precomputed_zero_point is None else precomputed_zero_point.shape + + 
ov_model_params = OVModelParameters() if ov_model_params is None else copy.deepcopy(ov_model_params) + ov_model_params.input_dtypes["weight"] = weight.dtype + if precomputed_scale is not None: + ov_model_params.input_dtypes["scale"] = precomputed_scale.dtype + if precomputed_zero_point is not None: + ov_model_params.input_dtypes["zero_point"] = precomputed_zero_point.dtype + if config.num_bits == 4 and weight.backend == TensorBackend.ov: + # Return ov tensors in target precision to seamlessly insert them into openvino model later + ov_model_params.return_ov_tensors = weight.backend == TensorBackend.ov + compressed_weight_dtype = TensorDataType.uint4 if config.is_asym_mode else TensorDataType.int4 + ov_model_params.output_dtypes.update( + {"compressed_weight": compressed_weight_dtype, "zero_point": compressed_weight_dtype} + ) + + model = get_compress_weight_model( + ov_model_params, + config, + weight_shape, + scale_shape, + zero_point_shape, + reduction_axes, + ) + + if precomputed_scale is None: + # weight -> compressed_weight, scale, (zero_point) + results = model([weight]) + if config.is_asym_mode: + compressed_weight, scale, zero_point = results + else: + compressed_weight, scale = results + zero_point = None + + # Scale is always in fp32 so there is no need to store it in ov.Tensor + if scale.backend == TensorBackend.ov: + scale = scale.as_numpy_tensor() + else: + # weight, scale, (zero_point) -> compressed_weight + inputs = ( + [weight, precomputed_scale] + if precomputed_zero_point is None + else [weight, precomputed_scale, precomputed_zero_point] + ) + compressed_weight = model(inputs)[0] + scale, zero_point = precomputed_scale, precomputed_zero_point + + return compressed_weight, scale, zero_point + + +def quantize_dequantize_weight( + weight: Tensor, + config: WeightCompressionConfig, + reduction_axes: Optional[ReductionAxes] = None, + precomputed_scale: Optional[Tensor] = None, + precomputed_zero_point: Optional[Tensor] = None, + return_compressed_weight: Optional[bool] = False, + ov_model_params: Optional = None, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: + """ + First quantizes the given weight tensor and then dequantizes it back to obtain float32 values. + :param weight: The weight tensor to quantize-dequantize. + :param config: Compression configuration. + :param reduction_axes: Axes along which to reduce (collect) statistics (e.g., min, max). Not required if + precomputed scale (and zero point) are provided. + :param precomputed_scale: Optional precomputed scale tensor. + :param precomputed_zero_point: Optional precomputed zero point tensor. + :param return_compressed_weight: If True, besides decompressed weight will also return compressed weight, scale, + (and zero point). + :param ov_model_params: OpenVINO model parameters for acceleration. + :return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight, scale, + (and zero point). 
+ """ + accelerate_through_ov = is_openvino_available() and weight.backend != TensorBackend.torch + if not is_openvino_available() and weight.backend != TensorBackend.torch: + log_once(logging.INFO, "Compression time may be improved after installing OpenVINO") + + if not accelerate_through_ov: + # Reference implementation + compressed_weight, scale, zero_point = do_int_quantization( + weight, config, reduction_axes, precomputed_scale, precomputed_zero_point + ) + decompressed_weight = do_int_dequantization(compressed_weight, scale, zero_point) + if return_compressed_weight: + return decompressed_weight, compressed_weight, scale, zero_point + else: + return decompressed_weight + + from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters + from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_compress_decompress_weight_model + + # When reduction axes are not provided, assuming that the weights are already reshaped + if config.group_size != -1 and reduction_axes is not None: + # weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2] + weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size) + + weight_shape = weight.shape + scale_shape = precomputed_scale.shape if precomputed_scale is not None else None + zero_point_shape = precomputed_zero_point.shape if precomputed_zero_point is not None else None + + ov_model_params = OVModelParameters() if ov_model_params is None else copy.deepcopy(ov_model_params) + ov_model_params.input_dtypes["weight"] = weight.dtype + if precomputed_scale is not None: + ov_model_params.input_dtypes["scale"] = precomputed_scale.dtype + if precomputed_zero_point is not None: + ov_model_params.input_dtypes["zero_point"] = precomputed_zero_point.dtype + + model = get_compress_decompress_weight_model( + ov_model_params, config, weight_shape, scale_shape, zero_point_shape, reduction_axes, return_compressed_weight + ) + + inputs = [weight] + if precomputed_scale is not None: + inputs.append(precomputed_scale) + if precomputed_zero_point is not None: + inputs.append(precomputed_zero_point) + + compressed_weight, scale, zero_point = None, precomputed_scale, precomputed_zero_point + results = model(inputs) + if len(results) == 1: + decompressed_weight = results[0] + elif len(results) == 2: + decompressed_weight, compressed_weight = results + elif len(results) == 3: + decompressed_weight, compressed_weight, scale = results + else: + decompressed_weight, compressed_weight, scale, zero_point = results + if return_compressed_weight: + return decompressed_weight, compressed_weight, scale, zero_point + else: + return decompressed_weight diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py index ead4463dc86..fb7d893853d 100644 --- a/nncf/quantization/fake_quantize.py +++ b/nncf/quantization/fake_quantize.py @@ -341,7 +341,11 @@ def _calculate_scaled_parameters( def calculate_scale_zero_point( - input_low: Tensor, input_high: Tensor, level_low: int, level_high: int, narrow_range: bool + input_low: Tensor, + input_high: Tensor, + level_low: int, + level_high: int, + narrow_range: bool, ) -> Tuple[Tensor, Tensor]: """ Calculates scale and zero_point values for the quantizer. 
diff --git a/nncf/tensor/definitions.py b/nncf/tensor/definitions.py index 520782adf88..db0fa5668aa 100644 --- a/nncf/tensor/definitions.py +++ b/nncf/tensor/definitions.py @@ -21,6 +21,7 @@ class TensorBackend(Enum): numpy = auto() torch = auto() + ov = auto() class TensorDataType(Enum): @@ -36,6 +37,8 @@ class TensorDataType(Enum): int32 = auto() int64 = auto() uint8 = auto() + uint4 = auto() + int4 = auto() def is_float(self): """ diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py index 1defabc92d6..48303f463eb 100644 --- a/nncf/tensor/functions/__init__.py +++ b/nncf/tensor/functions/__init__.py @@ -79,5 +79,8 @@ def _initialize_backends(): import nncf.tensor.functions.torch_linalg import nncf.tensor.functions.torch_numeric # noqa F401 + with contextlib.suppress(ImportError): + import nncf.tensor.functions.ov # noqa: F401 + _initialize_backends() diff --git a/nncf/tensor/functions/numeric.py b/nncf/tensor/functions/numeric.py index f27ee169914..55965514342 100644 --- a/nncf/tensor/functions/numeric.py +++ b/nncf/tensor/functions/numeric.py @@ -927,3 +927,14 @@ def tensor( :return: A tensor created from the given data. """ return Tensor(get_numeric_backend_fn("tensor", backend)(data, dtype=dtype, device=device)) + + +@functools.singledispatch +@tensor_guard +def as_numpy_tensor(a: Tensor) -> Tensor: + """ + Change backend of the tensor to numpy. + :param a: Tensor to change backend for. + :return: Tensor in numpy backend. + """ + return Tensor(as_numpy_tensor(a.data)) diff --git a/nncf/tensor/functions/numpy_numeric.py b/nncf/tensor/functions/numpy_numeric.py index a4f7657189c..9c2f65e3a34 100644 --- a/nncf/tensor/functions/numpy_numeric.py +++ b/nncf/tensor/functions/numpy_numeric.py @@ -391,6 +391,11 @@ def _(a: np.ndarray, v: np.ndarray, side: str = "left", sorter: Optional[np.ndar return np.searchsorted(a, v, side, sorter) +@register_numpy_types(numeric.as_numpy_tensor) +def _(a: np.ndarray) -> np.ndarray: + return a + + def zeros( shape: Tuple[int, ...], *, diff --git a/nncf/tensor/functions/ov_numeric.py b/nncf/tensor/functions/ov_numeric.py new file mode 100644 index 00000000000..4adcacd2c08 --- /dev/null +++ b/nncf/tensor/functions/ov_numeric.py @@ -0,0 +1,111 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Tuple, Union + +import numpy as np +import openvino as ov + +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.functions import numeric + +from ..definitions import TensorBackend +from ..definitions import TensorDeviceType +from .numpy_numeric import DTYPE_MAP as DTYPE_MAP_NP + +DTYPE_MAP = { + TensorDataType.float16: ov.Type.f16, + TensorDataType.bfloat16: ov.Type.bf16, + TensorDataType.float32: ov.Type.f32, + TensorDataType.float64: ov.Type.f64, + TensorDataType.int8: ov.Type.i8, + TensorDataType.int32: ov.Type.i32, + TensorDataType.int64: ov.Type.i64, + TensorDataType.uint8: ov.Type.u8, + TensorDataType.uint4: ov.Type.u4, + TensorDataType.int4: ov.Type.i4, +} + +DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} + + +@numeric.device.register(ov.Tensor) +def _(a: ov.Tensor) -> TensorDeviceType: + return TensorDeviceType.CPU + + +@numeric.backend.register(ov.Tensor) +def _(a: ov.Tensor) -> TensorBackend: + return TensorBackend.ov + + +@numeric.astype.register(ov.Tensor) +def _(a: ov.Tensor, dtype: TensorDataType) -> ov.Tensor: + if a.get_element_type() in [ov.Type.bf16, ov.Type.i4, ov.Type.u4] or dtype in [ + TensorDataType.bfloat16, + TensorDataType.int4, + TensorDataType.uint4, + ]: + # Cannot cast to/from bfloat16, uint4, int4 directly + return _astype_ov(a, dtype) + return ov.Tensor(a.data.astype(DTYPE_MAP_NP[dtype])) + + +@numeric.dtype.register(ov.Tensor) +def _(a: ov.Tensor) -> TensorDataType: + return DTYPE_MAP_REV[a.get_element_type()] + + +@numeric.size.register(ov.Tensor) +def _(a: ov.Tensor) -> int: + return a.size + + +@numeric.reshape.register(ov.Tensor) +def _(a: ov.Tensor, shape: Union[int, Tuple[int, ...]]) -> ov.Tensor: + return ov.Tensor(a.data.reshape(shape), shape, a.get_element_type()) + + +@numeric.as_numpy_tensor.register(ov.Tensor) +def _(a: ov.Tensor) -> np.ndarray: + # Cannot convert bfloat16, uint4, int4 to numpy directly + a_dtype = DTYPE_MAP_REV[a.get_element_type()] + if a_dtype in [TensorDataType.bfloat16, TensorDataType.uint4, TensorDataType.int4]: + dtype = TensorDataType.float32 + if a_dtype == TensorDataType.uint4: + dtype = TensorDataType.uint8 + elif a_dtype == TensorDataType.int4: + dtype = TensorDataType.int8 + a = _astype_ov(a, dtype) + return a.data + + +def _astype_ov(a: ov.Tensor, dtype: TensorDataType) -> ov.Tensor: + """ + Cast to a different data type using an OpenVINO model. + :param a: Tensor to cast. + :param dtype: Data type to cast to. + :return: Casted openvino tensor. 
+ """ + from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters + from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_astype_model + + a_dtype = DTYPE_MAP_REV[a.get_element_type()] + + ov_model_params = OVModelParameters( + input_dtypes={"input": a_dtype}, + output_dtypes={"output": dtype}, + recompile=True, + release_memory=False, + return_ov_tensors=True, + ) + model = get_astype_model(ov_model_params, tuple(a.shape)) + return model([Tensor(a)])[0].data diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py index f80abb6a148..3ee387371e5 100644 --- a/nncf/tensor/tensor.py +++ b/nncf/tensor/tensor.py @@ -196,6 +196,9 @@ def item(self) -> float: def clone(self) -> float: return _call_function("clone", self) + def as_numpy_tensor(self) -> Tensor: + return _call_function("as_numpy_tensor", self) + def _call_function(func_name: str, *args): """ diff --git a/nncf/tensorflow/sparsity/rb/algorithm.py b/nncf/tensorflow/sparsity/rb/algorithm.py index 395f621b551..ee9985b053a 100644 --- a/nncf/tensorflow/sparsity/rb/algorithm.py +++ b/nncf/tensorflow/sparsity/rb/algorithm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Intel Corporation +# Copyright (c) 2025 Intel Corporation # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tests/common/utils/test_cache_results_decorator.py b/tests/common/utils/test_cache_results_decorator.py new file mode 100644 index 00000000000..8b9167443c6 --- /dev/null +++ b/tests/common/utils/test_cache_results_decorator.py @@ -0,0 +1,132 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nncf.common.utils.decorators import ResultsCacheContainer +from nncf.common.utils.decorators import cache_results + +TEST_CACHE_CONTAINER = ResultsCacheContainer() + + +@cache_results(TEST_CACHE_CONTAINER) +def cached_addition(a, b): + return a + b + + +CALL_SEQUENCE = [ + ( + (1, 2), + False, + 3, + False, + 1, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 3}, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 0}, + ), + ( + (1, 2), + False, + 3, + False, + 1, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 3}, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 1}, + ), + ( + (2, 3), + True, + 5, + False, + 1, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 3}, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 1}, + ), + ( + (3, 4), + False, + 7, + False, + 2, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 3, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 7, + }, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 1, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 0, + }, + ), + ( + (1, 2), + False, + 3, + False, + 2, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 3, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 7, + }, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 2, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 0, + }, + ), + ( + (3, 4), + False, + 7, + False, + 2, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 3, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 7, + }, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 2, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 1, + }, + ), + ( + (3, 4), + True, + 7, + False, + 2, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 3, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 7, + }, + { + ("cached_addition", frozenset({("a", 1), ("b", 2)})): 2, + ("cached_addition", frozenset({("a", 3), ("b", 4)})): 1, + }, + ), + ((3, 4), True, 7, True, 0, {}, {}), + ( + (1, 2), + False, + 3, + False, + 1, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 3}, + {("cached_addition", frozenset({("a", 1), ("b", 2)})): 0}, + ), +] + + +def test_caching_results(): + for inputs, disable_caching, output, clear_cache, cache_size, ref_cache, ref_access_count in CALL_SEQUENCE: + if clear_cache: + TEST_CACHE_CONTAINER.clear() + kwargs = {"disable_caching": True} if disable_caching else {} + assert cached_addition(*inputs, **kwargs) == output + assert len(TEST_CACHE_CONTAINER._cache) == cache_size + assert TEST_CACHE_CONTAINER._cache == ref_cache + assert TEST_CACHE_CONTAINER._access_count == ref_access_count diff --git a/tests/cross_fw/test_templates/template_test_nncf_tensor.py b/tests/cross_fw/test_templates/template_test_nncf_tensor.py index d92313ebaa6..93c86134c44 100644 --- a/tests/cross_fw/test_templates/template_test_nncf_tensor.py +++ b/tests/cross_fw/test_templates/template_test_nncf_tensor.py @@ -1505,7 +1505,15 @@ def test_expand_dims_error(self, x, axis, match): def test_fn_zeros(self): shape = (2, 2) for dtype in TensorDataType: - if dtype == TensorDataType.bfloat16 and self.backend() == TensorBackend.numpy: + if ( + self.backend() == TensorBackend.numpy + and dtype == TensorDataType.bfloat16 + or dtype + in [ + TensorDataType.int4, + TensorDataType.uint4, + ] + ): continue tensor_a = fns.zeros(shape, backend=self.backend(), dtype=dtype, device=self.device()) assert isinstance(tensor_a, Tensor) @@ -1526,7 +1534,15 @@ def test_fn_zeros(self): ) def 
test_fn_eye(self, n, m, ref): for dtype in TensorDataType: - if dtype == TensorDataType.bfloat16 and self.backend() == TensorBackend.numpy: + if ( + self.backend() == TensorBackend.numpy + and dtype == TensorDataType.bfloat16 + or dtype + in [ + TensorDataType.int4, + TensorDataType.uint4, + ] + ): continue tensor_a = fns.eye(n, m, backend=self.backend(), dtype=dtype, device=self.device()) assert isinstance(tensor_a, Tensor) diff --git a/tests/openvino/native/quantization/test_ov_modeling_compression.py b/tests/openvino/native/quantization/test_ov_modeling_compression.py new file mode 100644 index 00000000000..229817739bf --- /dev/null +++ b/tests/openvino/native/quantization/test_ov_modeling_compression.py @@ -0,0 +1,271 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from collections import defaultdict +from contextlib import contextmanager +from enum import Enum +from unittest.mock import patch + +import numpy as np +import openvino as ov +import pytest + +from nncf import CompressWeightsMode +from nncf.common.utils.decorators import ResultsCacheContainer +from nncf.common.utils.decorators import cache_results +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig +from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters +from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_compress_decompress_weight_model +from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_compress_weight_model +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import quantize_dequantize_weight +from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.definitions import TensorBackend +from nncf.tensor.functions.numpy_numeric import DTYPE_MAP as DTYPE_MAP_NP +from nncf.tensor.functions.numpy_numeric import DTYPE_MAP_REV as DTYPE_MAP_REV_NP +from nncf.tensor.functions.ov_numeric import DTYPE_MAP as DTYPE_MAP_OV + + +class ComputationBackend(Enum): + NumPy = "numpy" + OV = "ov" + + +class QuantizationTask(Enum): + Q = "quantize" + Q_DQ = "quantize_dequantize" + Q_DQ_RQ = "quantize_dequantize_return_quantized" + + +COMPRESSION_CONFIGS = [ + WeightCompressionConfig(CompressWeightsMode.INT8_ASYM), + WeightCompressionConfig(CompressWeightsMode.INT8_SYM), + WeightCompressionConfig(CompressWeightsMode.INT4_ASYM), + WeightCompressionConfig(CompressWeightsMode.INT4_SYM), + WeightCompressionConfig(CompressWeightsMode.INT4_ASYM, group_size=2), + WeightCompressionConfig(CompressWeightsMode.INT4_SYM, group_size=2), +] + +REDUCTION_AXES = (1,) + +RANDOM_TENSOR_CACHE_CONTAINER = ResultsCacheContainer() + + +@cache_results(RANDOM_TENSOR_CACHE_CONTAINER) +def get_random_float_tensor(shape, dtype, 
backend, seed=0): + np.random.seed(seed) + data = np.random.normal(size=shape) + data = data.astype(np.float16 if dtype == TensorDataType.float16 else np.float32) + + if backend == TensorBackend.ov or dtype == TensorDataType.bfloat16: + data = Tensor(ov.Tensor(data, shape, DTYPE_MAP_OV[DTYPE_MAP_REV_NP[data.dtype]])) + if dtype == TensorDataType.bfloat16: + data = data.astype(TensorDataType.bfloat16) + if backend == TensorBackend.numpy: + data = data.as_numpy_tensor() if dtype == TensorDataType.bfloat16 else Tensor(data) + return Tensor(data) + + +@cache_results(RANDOM_TENSOR_CACHE_CONTAINER) +def get_random_integer_tensor(shape, low, high, dtype, backend, seed=0): + np.random.seed(seed) + data = np.random.randint(low, high, size=shape).astype(DTYPE_MAP_NP[dtype]) + if backend == TensorBackend.ov: + data = ov.Tensor(data, shape, DTYPE_MAP_OV[dtype]) + return Tensor(data) + + +@contextmanager +def openvino_available(available: bool): + import nncf.import_utils + + original_value = nncf.import_utils._openvino_available + nncf.import_utils._openvino_available = available + yield + nncf.import_utils._openvino_available = original_value + + +@pytest.mark.parametrize("weight_shape", [(100000, 4)], ids=[""]) +@pytest.mark.parametrize("config", COMPRESSION_CONFIGS, ids=[str(c) for c in COMPRESSION_CONFIGS]) +@pytest.mark.parametrize( + ("quantization_task", "tensor_backend"), + [ + (QuantizationTask.Q, TensorBackend.numpy), + (QuantizationTask.Q, "auto"), + # NumPy backend should support OV tensors as inputs only for quantization task + (QuantizationTask.Q, TensorBackend.ov), + (QuantizationTask.Q_DQ, TensorBackend.numpy), + (QuantizationTask.Q_DQ, "auto"), + (QuantizationTask.Q_DQ_RQ, TensorBackend.numpy), + (QuantizationTask.Q_DQ_RQ, "auto"), + ], +) +@pytest.mark.parametrize("dtype", [TensorDataType.float32, TensorDataType.float16, TensorDataType.bfloat16]) +@pytest.mark.parametrize("precompute_s_zp", [False, True], ids=["no-precompute", "precompute"]) +def test_quantization_alignment(weight_shape, config, quantization_task, tensor_backend, dtype, precompute_s_zp): + d1, d2 = weight_shape + group_size = config.group_size + zero_point_shape = scale_shape = (d1, 1) if group_size == -1 else (d1, d2 // group_size, 1) + level_low, level_high = 0, 2**config.num_bits - 1 + + results = defaultdict(dict) + # Iterate over two implementations + for cb in [ComputationBackend.NumPy, ComputationBackend.OV]: + # A context manager to enable/disable ov implementation + with openvino_available(cb == ComputationBackend.OV): + # OV tensor backend for weight is only supported for quantization task + if quantization_task == QuantizationTask.Q and ( + tensor_backend == TensorBackend.ov or cb == ComputationBackend.OV and tensor_backend == "auto" + ): + weight_tensor_backend = TensorBackend.ov + else: + weight_tensor_backend = TensorBackend.numpy + + # Generate input tensors + weight = get_random_float_tensor(weight_shape, dtype, weight_tensor_backend) + precomputed_scale, precomputed_zero_point = None, None + if precompute_s_zp: + # When scale (and z.p) are precomputed, all inputs are assumed to be reshaped beforehand + if group_size != -1: + weight, _ = reshape_weight_for_grouped_quantization(weight, REDUCTION_AXES, group_size) + + precomputed_scale = get_random_float_tensor(scale_shape, TensorDataType.float32, TensorBackend.numpy) + if config.is_asym_mode: + precomputed_zero_point = get_random_integer_tensor( + zero_point_shape, level_low, level_high, TensorDataType.int32, TensorBackend.numpy + ) + + if 
quantization_task == QuantizationTask.Q: + fn_to_call = do_int_quantization + fn_to_patch = get_compress_weight_model + else: + fn_to_call = quantize_dequantize_weight + fn_to_patch = get_compress_decompress_weight_model + patch_path = f"{inspect.getmodule(fn_to_patch).__name__}.{fn_to_patch.__name__}" + with patch(patch_path, side_effect=fn_to_patch) as mock: + # When scale (and z.p) are precomputed, all inputs are assumed to be already reshaped and reduction + # axes are not needed + reduction_axes = None if precompute_s_zp else REDUCTION_AXES + + kwargs = {} + if cb == ComputationBackend.OV: + ov_model_params = OVModelParameters() + kwargs["ov_model_params"] = ov_model_params + if quantization_task == QuantizationTask.Q_DQ_RQ: + kwargs["return_compressed_weight"] = True + + outputs = fn_to_call( + weight, config, reduction_axes, precomputed_scale, precomputed_zero_point, **kwargs + ) + + decompressed_weight, compressed_weight, scale, zero_point = (None,) * 4 + if quantization_task == QuantizationTask.Q: + compressed_weight, scale, zero_point = outputs + elif quantization_task == QuantizationTask.Q_DQ: + decompressed_weight = outputs + else: + decompressed_weight, compressed_weight, scale, zero_point = outputs + + if cb == ComputationBackend.NumPy: + mock.assert_not_called() + else: + mock.assert_called_once() + + if quantization_task != QuantizationTask.Q_DQ and precompute_s_zp: + # In case of precomputed scale or zero point, the returned scale and z.p. should equal the given ones + np.testing.assert_allclose(precomputed_scale.data, scale.data, atol=0, rtol=0) + if config.is_asym_mode: + np.testing.assert_allclose(precomputed_zero_point.data, zero_point.data, atol=0, rtol=0) + + # Save results for comparison between implementations + if quantization_task != QuantizationTask.Q: + results[cb]["decompressed_weight"] = decompressed_weight + if quantization_task != QuantizationTask.Q_DQ: + results[cb]["compressed_weight"] = compressed_weight.as_numpy_tensor() + results[cb]["scale"] = scale + if config.is_asym_mode: + results[cb]["zero_point"] = zero_point.as_numpy_tensor() + + _check_backends_and_dtypes( + quantization_task, + cb, + weight_tensor_backend, + config, + precompute_s_zp, + compressed_weight, + scale, + zero_point, + decompressed_weight, + ) + + _check_values(results) + + +def _check_backends_and_dtypes( + quantization_task, + cb, + weight_tensor_backend, + config, + precompute_s_zp, + compressed_weight, + scale, + zero_point, + decompressed_weight, +): + if quantization_task != QuantizationTask.Q_DQ: + # Scale should always be float32 and numpy backend + assert scale.dtype == TensorDataType.float32 + assert scale.backend == TensorBackend.numpy + + if ( + quantization_task == QuantizationTask.Q + and cb == ComputationBackend.OV + and weight_tensor_backend == TensorBackend.ov + and config.num_bits == 4 + ): + # For 4 bit compression in case of ov implementation and ov backend the compressed weight and the computed + # zero point must be in ov backend and have (u)int4 dtype in order to be able to insert them into OV model + # without re-packing + assert compressed_weight.backend == TensorBackend.ov + assert compressed_weight.dtype == (TensorDataType.uint4 if config.is_asym_mode else TensorDataType.int4) + if config.is_asym_mode and not precompute_s_zp: + assert zero_point.backend == TensorBackend.ov + assert zero_point.dtype == TensorDataType.uint4 + else: + if quantization_task != QuantizationTask.Q_DQ: + # Otherwise compressed weight and zero point must be returned in numpy 
backend, compressed weight must + # be of (u)int8 data type, zero point -- in int32 + assert compressed_weight.backend == TensorBackend.numpy + assert compressed_weight.dtype == (TensorDataType.uint8 if config.is_asym_mode else TensorDataType.int8) + if config.is_asym_mode and not precompute_s_zp: + assert zero_point.backend == TensorBackend.numpy + assert zero_point.dtype == TensorDataType.int32 + if quantization_task != QuantizationTask.Q: + assert decompressed_weight.backend == TensorBackend.numpy + assert decompressed_weight.dtype == TensorDataType.float32 + + +def _check_values(results): + # Check that the computed tensors are equal between implementations + keys = set(results[ComputationBackend.OV]).union(set(results[ComputationBackend.NumPy])) + for key in keys: + numpy_result = results[ComputationBackend.NumPy][key] + ov_result = results[ComputationBackend.OV][key] + + # Note: For static-shaped OV models doing asymmetric compression with convertable divisions there maybe + # misalignments equal to 1 quant between OV and NumPy. For more details see ticket 156511. + + np.testing.assert_allclose( + ov_result.data, numpy_result.data, atol=0, rtol=0, err_msg=f"Results do not align for {key}." + ) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index e39a621a4a8..6d0e514f75f 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -36,8 +36,6 @@ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA -from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend -from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization @@ -1078,34 +1076,6 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids): assert ref_e8m0_nodes == names_e8m0 -@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)) -def test_np_ov_compression_decompression(mode): - sz = 60 - w = np.arange(-sz, sz).reshape(2, sz).astype(np.float32) / 9.0 - w = Tensor(w) - - config = WeightCompressionConfig(mode) - - compressed_weighs, scale, zp = do_int_quantization(w, -1, config, invert_scale=True) - decompressed_weighs = do_int_dequantization(compressed_weighs, scale, zp) - - compressed_weighs = compressed_weighs.data - decompressed_weighs = decompressed_weighs.data - zp_shape = zp.shape if zp is not None else None - - compress = OVWeightCompressionAlgoBackend.get_compress_pipeline(config, w.shape, scale.shape, zp_shape) - compress_decompress = OVWeightCompressionAlgoBackend.get_compress_decompress_pipeline( - config, w.shape, scale.shape, zp_shape - ) - - params = [w.data, scale.data, zp.data] if zp is not None else [w.data, scale.data] - compressed_weighs_ov = compress(params) - decompressed_weighs_ov = compress_decompress(params) - - assert np.allclose(compressed_weighs, 
compressed_weighs_ov) - assert np.allclose(decompressed_weighs, decompressed_weighs_ov) - - @pytest.mark.parametrize( ("mode", "data"), ( @@ -1122,7 +1092,7 @@ def test_compressed_weighs_range(mode, data): w = Tensor(data) config = WeightCompressionConfig(mode=mode) - compressed_weighs, _, _ = do_int_quantization(w, -1, config) + compressed_weighs, _, _ = do_int_quantization(w, config, -1) assert np.allclose(np.abs(compressed_weighs.data), np.abs(w.data)) @@ -1145,8 +1115,6 @@ def test_compressed_weighs_range(mode, data): ], ) def test_int_quantization_with_precomputed_parameters(config, precompute_scale, precompute_zero_point, raises): - is_asym = config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT8_ASYM] - precomputed_scale, precomputed_zero_point = None, None weight = Tensor(((np.arange(11) - 5) / 10).astype(np.float32)[:, None]) if precompute_scale: @@ -1156,18 +1124,18 @@ def test_int_quantization_with_precomputed_parameters(config, precompute_scale, if raises: with pytest.raises(ValueError) as exc_info: - _, scale, zero_point = do_int_quantization(weight, -1, config, precomputed_scale, precomputed_zero_point) + _, scale, zero_point = do_int_quantization(weight, config, -1, precomputed_scale, precomputed_zero_point) assert exc_info.value == ( "If precomputed quantization parameters are provided, both scale and zero point " "are required for asymmetric quantization." ) return else: - _, scale, zero_point = do_int_quantization(weight, -1, config, precomputed_scale, precomputed_zero_point) + _, scale, zero_point = do_int_quantization(weight, config, -1, precomputed_scale, precomputed_zero_point) if precompute_scale: assert np.allclose(scale.data, precomputed_scale.data) - if is_asym: + if config.is_asym_mode: if precompute_zero_point: assert np.allclose(zero_point.data, precomputed_zero_point.data) else: diff --git a/tests/openvino/native/test_node_utils.py b/tests/openvino/native/test_node_utils.py index c6654218d42..a2e319d6a20 100644 --- a/tests/openvino/native/test_node_utils.py +++ b/tests/openvino/native/test_node_utils.py @@ -22,6 +22,7 @@ from nncf.openvino.graph.node_utils import get_weight_channel_axes from nncf.openvino.graph.node_utils import get_weighted_layer_attributes from nncf.openvino.graph.node_utils import is_node_with_bias +from nncf.openvino.graph.node_utils import non_convertable_divide_op from tests.openvino.native.models import ConvModel from tests.openvino.native.models import ConvNotBiasModel from tests.openvino.native.models import MatMul2DModel @@ -29,30 +30,63 @@ @pytest.mark.parametrize( - "precisions", + "precisions,cast_bf16_to_fp32", [ # base FP32 precision - { - "type_for_const": ov.Type.f32, - "ref_type": np.float32, - }, + ( + { + "type_for_const": ov.Type.f32, + "ref_type": np.float32, + }, + True, + ), # base FP16 precision - { - "type_for_const": ov.Type.f16, - "ref_type": np.float16, - }, + ( + { + "type_for_const": ov.Type.f16, + "ref_type": np.float16, + }, + True, + ), # base BF16 precision should be casted to FP32 - { - "type_for_const": ov.Type.bf16, - "ref_type": np.float32, - }, + ( + { + "type_for_const": ov.Type.bf16, + "ref_type": np.float32, + }, + True, + ), + # base FP32 precision, cast_bf16_to_fp32=False has no effect + ( + { + "type_for_const": ov.Type.f32, + "ref_type": np.float32, + }, + False, + ), + # base FP16 precision, cast_bf16_to_fp32=False has no effect + ( + { + "type_for_const": ov.Type.f16, + "ref_type": np.float16, + }, + False, + ), + # with cast_bf16_to_fp32=False BF16 constant is retrieved as FP16 
+ ( + { + "type_for_const": ov.Type.bf16, + "ref_type": np.float16, + }, + False, + ), ], ) -def test_get_const_value(precisions): +def test_get_const_value(precisions, cast_bf16_to_fp32): const_data = np.ones((1, 2, 3), dtype=np.float32) weight_const = opset.constant(const_data, dtype=precisions["type_for_const"]) - const_value = get_const_value(weight_const) + const_value = get_const_value(weight_const, cast_bf16_to_fp32=cast_bf16_to_fp32) assert const_value.dtype == precisions["ref_type"] @@ -114,3 +148,21 @@ def test_get_weight_channel_axes_for_matmul(weights_port_id, transpose, shape, d assert len(actual_channel_axes) == len(expected_channel_axes) assert all(a == b for a, b in zip(actual_channel_axes, expected_channel_axes)) + + +@pytest.mark.parametrize( + "a,b,convertable,ref_result", + [ + (0.058599039912223816, 15, True, 0.003906603), + (0.058599039912223816, 15, False, 0.003906602505594492), + ], +) +def test_non_convertable_division(a, b, convertable, ref_result): + a, b, ref_result = tuple(map(lambda x: np.array([x], np.float32), [a, b, ref_result])) + a_param = opset.parameter((-1,), ov.Type.f32) + b_param = opset.parameter((-1,), ov.Type.f32) + division = (a_param / b_param) if convertable else non_convertable_divide_op(a_param, b_param) + model = ov.Model([division], [a_param, b_param]) + compiled_model = ov.compile_model(model, device_name="CPU") + actual_result = compiled_model([a, b])[0] + np.testing.assert_allclose(actual_result, ref_result, atol=0, rtol=0) diff --git a/tests/openvino/native/test_openvino_modeling.py b/tests/openvino/native/test_openvino_modeling.py new file mode 100644 index 00000000000..b72ca228aed --- /dev/null +++ b/tests/openvino/native/test_openvino_modeling.py @@ -0,0 +1,319 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
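A condensed sketch of the get_const_value(cast_bf16_to_fp32=...) behavior exercised by the test_get_const_value parametrization above, assuming openvino is installed; the constant name and shape are illustrative only.

import numpy as np
import openvino as ov
import openvino.runtime.opset13 as opset

from nncf.openvino.graph.node_utils import get_const_value

bf16_const = opset.constant(np.ones((1, 2, 3), dtype=np.float32), dtype=ov.Type.bf16)

# Default behavior: bf16 constant data is upcast to fp32.
assert get_const_value(bf16_const).dtype == np.float32
# With cast_bf16_to_fp32=False the bf16 data is retrieved as fp16 instead.
assert get_const_value(bf16_const, cast_bf16_to_fp32=False).dtype == np.float16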
+import numpy as np +import pytest + +from nncf import CompressWeightsMode +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig +from nncf.quantization.algorithms.weight_compression.openvino_modeling import OV_MODEL_CACHE +from nncf.quantization.algorithms.weight_compression.openvino_modeling import OVModelParameters +from nncf.quantization.algorithms.weight_compression.openvino_modeling import _infer_ov_model +from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_astype_model +from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_compress_decompress_weight_model +from nncf.quantization.algorithms.weight_compression.openvino_modeling import get_compress_weight_model +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.definitions import TensorBackend +from nncf.tensor.functions.numpy_numeric import DTYPE_MAP as DTYPE_MAP_NP + + +class ModelGetter: + def __init__(self, get_model_fn, ov_model_params_kwargs, get_model_kwargs): + self._get_model_fn = get_model_fn + self._ov_model_params_kwargs = ov_model_params_kwargs + self._get_model_kwargs = get_model_kwargs + + def get(self, ov_model_params_kwargs=None, get_model_kwargs=None): + ov_model_params_kwargs = ov_model_params_kwargs or {} + get_model_kwargs = get_model_kwargs or {} + return self._get_model_fn( + OVModelParameters(**{**self._ov_model_params_kwargs, **ov_model_params_kwargs}), + **{**self._get_model_kwargs, **get_model_kwargs}, + ) + + +MODEL_GETTERS = [ + ModelGetter( + get_model_fn=get_compress_weight_model, + ov_model_params_kwargs=dict( + input_dtypes={ + "weight": TensorDataType.float32, + "scale": TensorDataType.float32, + "zero_point": TensorDataType.int32, + }, + output_dtypes={"compressed_weight": TensorDataType.uint8}, + ), + get_model_kwargs=dict( + config=WeightCompressionConfig(CompressWeightsMode.INT8_ASYM), + weight_shape=(10, 4), + scale_shape=(10, 1), + zero_point_shape=(10, 1), + ), + ), + ModelGetter( + get_model_fn=get_compress_weight_model, + ov_model_params_kwargs=dict( + input_dtypes={"weight": TensorDataType.float32}, + output_dtypes={ + "compressed_weight": TensorDataType.uint8, + "scale": TensorDataType.float32, + "zero_point": TensorDataType.int32, + }, + ), + get_model_kwargs=dict( + config=WeightCompressionConfig(CompressWeightsMode.INT8_ASYM), + weight_shape=(10, 4), + reduction_axes=(1,), + ), + ), + ModelGetter( + get_model_fn=get_compress_decompress_weight_model, + ov_model_params_kwargs=dict( + input_dtypes={ + "weight": TensorDataType.float32, + "scale": TensorDataType.float32, + "zero_point": TensorDataType.int32, + }, + output_dtypes={ + "decompressed_weight": TensorDataType.float32, + }, + ), + get_model_kwargs=dict( + config=WeightCompressionConfig(CompressWeightsMode.INT8_ASYM), + weight_shape=(10, 4), + scale_shape=(10, 1), + zero_point_shape=(10, 1), + ), + ), + ModelGetter( + get_model_fn=get_compress_decompress_weight_model, + ov_model_params_kwargs=dict( + input_dtypes={ + "weight": TensorDataType.float32, + }, + output_dtypes={ + "decompressed_weight": TensorDataType.float32, + "compressed_weight": TensorDataType.int32, + "scale": TensorDataType.float32, + "zero_point": TensorDataType.int32, + }, + ), + get_model_kwargs=dict( + config=WeightCompressionConfig(CompressWeightsMode.INT8_ASYM), + weight_shape=(10, 4), + reduction_axes=(1,), + return_compressed_weight=True, + ), + ), + ModelGetter( + get_model_fn=get_astype_model, + ov_model_params_kwargs=dict( + 
input_dtypes={ + "input": TensorDataType.float32, + }, + output_dtypes={ + "output": TensorDataType.bfloat16, + }, + ), + get_model_kwargs=dict( + input_shape=(10, 4), + ), + ), +] + + +@pytest.mark.parametrize( + "model_getter,input_shapes,ref_cache_size", + [ + ( + MODEL_GETTERS[0], + [ + dict(weight_shape=(10, 4), scale_shape=(10, 1), zero_point_shape=(10, 1)), + dict(weight_shape=(20, 6), scale_shape=(20, 1), zero_point_shape=(20, 1)), + dict(weight_shape=(20, 8), scale_shape=(20, 1), zero_point_shape=(20, 1)), + dict(weight_shape=(10, 4, 4), scale_shape=(10, 4, 1), zero_point_shape=(10, 4, 1)), + dict(weight_shape=(10, 8, 4), scale_shape=(10, 8, 1), zero_point_shape=(10, 8, 1)), + ], + {False: 5, True: 2}, + ), + ( + MODEL_GETTERS[1], + [ + dict(weight_shape=(10, 4)), + dict(weight_shape=(20, 6)), + dict(weight_shape=(20, 8)), + dict(weight_shape=(10, 4, 4)), + dict(weight_shape=(10, 8, 4)), + ], + {False: 5, True: 2}, + ), + ( + MODEL_GETTERS[2], + [ + dict(weight_shape=(10, 4), scale_shape=(10, 1), zero_point_shape=(10, 1)), + dict(weight_shape=(20, 6), scale_shape=(20, 1), zero_point_shape=(20, 1)), + dict(weight_shape=(20, 8), scale_shape=(20, 1), zero_point_shape=(20, 1)), + dict(weight_shape=(10, 4, 4), scale_shape=(10, 4, 1), zero_point_shape=(10, 4, 1)), + dict(weight_shape=(10, 8, 4), scale_shape=(10, 8, 1), zero_point_shape=(10, 8, 1)), + ], + {False: 10, True: 4}, + ), + ( + MODEL_GETTERS[3], + [ + dict(weight_shape=(10, 4)), + dict(weight_shape=(20, 6)), + dict(weight_shape=(20, 8)), + dict(weight_shape=(10, 4, 4)), + dict(weight_shape=(10, 8, 4)), + ], + {False: 10, True: 4}, + ), + ( + MODEL_GETTERS[4], + [ + dict(input_shape=(10, 1)), + dict(input_shape=(10, 2)), + dict(input_shape=(20, 3)), + dict(input_shape=(10, 4, 4)), + dict(input_shape=(10, 8, 4)), + ], + {False: 5, True: 2}, + ), + ], +) +@pytest.mark.parametrize("dynamic_shapes", [False, True]) +def test_dynamic_shapes(model_getter, input_shapes, ref_cache_size, dynamic_shapes): + # Check that model cache contains fewer elements with dynamic shapes enabled + OV_MODEL_CACHE.clear() + for shape_kwargs in input_shapes: + model_getter.get(ov_model_params_kwargs=dict(dynamic_shapes=dynamic_shapes), get_model_kwargs=shape_kwargs) + assert len(OV_MODEL_CACHE._cache) == ref_cache_size[dynamic_shapes] + + +@pytest.mark.parametrize("model_getter", MODEL_GETTERS) +@pytest.mark.parametrize("recompile", [True, False]) +def test_recompile(model_getter, recompile): + # Check that with recompilation ov models are not cached + OV_MODEL_CACHE.clear() + model_getter.get(ov_model_params_kwargs=dict(recompile=recompile)) + ref_size = 0 if recompile else (2 if model_getter._get_model_fn == get_compress_decompress_weight_model else 1) + assert len(OV_MODEL_CACHE._cache) == ref_size + + +@pytest.mark.parametrize("model_getter", MODEL_GETTERS) +@pytest.mark.parametrize("return_ov_tensors", [True, False]) +def test_return_ov_tensors(model_getter, return_ov_tensors): + # Check that ov tensors are returned + OV_MODEL_CACHE.clear() + inputs = [] + for input_name, input_dtype in model_getter._ov_model_params_kwargs["input_dtypes"].items(): + input_shape = model_getter._get_model_kwargs.get(f"{input_name}_shape") + inputs.append(Tensor(np.zeros(input_shape, dtype=DTYPE_MAP_NP[input_dtype]))) + + model_run_fn = model_getter.get(ov_model_params_kwargs=dict(return_ov_tensors=return_ov_tensors)) + outputs = model_run_fn(inputs) + + assert all([out.backend == (TensorBackend.ov if return_ov_tensors else TensorBackend.numpy) for out in outputs]) + + 
+@pytest.mark.parametrize("release_memory", [True, False]) +def test_release_memory(mocker, release_memory): + compiled_model = mocker.Mock() + compiled_model.release_memory = mocker.Mock() + + input_mock = mocker.Mock() + input_mock.any_name = "input" + compiled_model.inputs = [input_mock] + + output_mock = mocker.Mock() + compiled_model.return_value = [output_mock] + + ov_model_params = OVModelParameters(input_dtypes={"input": TensorDataType.float32}, release_memory=release_memory) + input_tensor = mocker.Mock() + input_tensor.dtype = TensorDataType.float32 + input_tensor.data = [1, 2, 3] + inputs = [input_tensor] + + _infer_ov_model(ov_model_params, compiled_model, inputs=inputs) + if release_memory: + compiled_model.release_memory.assert_called_once() + else: + compiled_model.release_memory.assert_not_called() + + +@pytest.mark.parametrize("share_inputs", [True, False]) +@pytest.mark.parametrize("share_outputs", [True, False]) +@pytest.mark.parametrize("return_ov_tensors", [True, False]) +def test_share_inputs_outputs(mocker, share_inputs, share_outputs, return_ov_tensors): + compiled_model = mocker.Mock() + + input_mock = mocker.Mock() + input_mock.any_name = "input" + compiled_model.inputs = [input_mock] + + output_mock = mocker.Mock() + + if return_ov_tensors: + infer_request = mocker.Mock() + compiled_model.create_infer_request.return_value = infer_request + + infer_request.infer = mocker.Mock() + infer_request.results = [output_mock] + + infer_request.get_output_tensor.return_value = output_mock + else: + compiled_model.return_value = [output_mock] + + ov_model_params = OVModelParameters( + input_dtypes={"input": TensorDataType.float32}, + return_ov_tensors=return_ov_tensors, + share_inputs=share_inputs, + share_outputs=share_outputs, + ) + + input_tensor = mocker.Mock() + input_tensor.dtype = TensorDataType.float32 + input_tensor.data = [1, 2, 3] + inputs = [input_tensor] + + _infer_ov_model(ov_model_params, compiled_model, inputs=inputs) + + if return_ov_tensors: + infer_request.infer.assert_called_once_with( + [input_tensor.data], share_inputs=share_inputs, share_outputs=share_outputs + ) + else: + compiled_model.assert_called_once_with( + [input_tensor.data], share_inputs=share_inputs, share_outputs=share_outputs + ) + + +@pytest.mark.parametrize( + "weight,convertable_division,ref_compressed_weight", + [ + ([[0.70361328125, 0.92919921875, 0.37109375, -0.98974609375]], True, [[225, 255, 181, 0]]), + ([[0.70361328125, 0.92919921875, 0.37109375, -0.98974609375]], False, [[226, 255, 181, 0]]), + ], +) +def test_convertable_divison(weight, convertable_division, ref_compressed_weight): + ov_model_params = OVModelParameters( + input_dtypes={"weight": TensorDataType.float32}, + dynamic_shapes=not convertable_division, + convertable_division=convertable_division, + ) + config = WeightCompressionConfig(CompressWeightsMode.INT8_ASYM) + + weight = np.array(weight, np.float32) + ref_compressed_weight = np.array(ref_compressed_weight, np.uint8) + model_run_fn = get_compress_weight_model(ov_model_params, config, weight.shape, reduction_axes=(1,)) + compressed_weight = model_run_fn([Tensor(weight)])[0] + np.testing.assert_allclose(compressed_weight.data, ref_compressed_weight, atol=0, rtol=0) diff --git a/tests/openvino/native/test_tensor.py b/tests/openvino/native/test_tensor.py new file mode 100644 index 00000000000..394b22b512e --- /dev/null +++ b/tests/openvino/native/test_tensor.py @@ -0,0 +1,94 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import openvino as ov +import openvino.runtime.opset13 as opset +import pytest + +import nncf.tensor.functions as fns +from nncf.tensor import Tensor +from nncf.tensor import TensorDataType +from nncf.tensor.definitions import TensorBackend +from nncf.tensor.definitions import TensorDeviceType +from nncf.tensor.functions.numpy_numeric import DTYPE_MAP as DTYPE_MAP_NP +from nncf.tensor.functions.ov_numeric import DTYPE_MAP as DTYPE_MAP_OV + + +class TestOVNNCFTensorOperators: + @staticmethod + def to_tensor(x, backend=TensorBackend.ov, dtype=TensorDataType.float32): + if backend == TensorBackend.ov: + if dtype in [TensorDataType.bfloat16, TensorDataType.uint4, TensorDataType.int4]: + ov_const = opset.constant(x, dtype=DTYPE_MAP_OV[dtype]) + return ov.Tensor(ov_const.data, ov_const.data.shape, DTYPE_MAP_OV[dtype]) + else: + return ov.Tensor(np.array(x, dtype=DTYPE_MAP_NP[dtype])) + elif backend == TensorBackend.numpy: + if dtype in [TensorDataType.bfloat16, TensorDataType.uint4, TensorDataType.int4]: + raise ValueError(f"Can't create NumPY tensor in dtype {dtype}") + return np.array(x, dtype=DTYPE_MAP_NP[dtype]) + else: + raise ValueError("Unsupported backend") + + @staticmethod + def backend() -> TensorBackend: + return TensorBackend.ov + + def test_property_backend(self): + tensor_a = Tensor(self.to_tensor([1, 2])) + assert tensor_a.backend == self.backend() + + def test_device(self): + tensor = Tensor(self.to_tensor([1])) + assert tensor.device == TensorDeviceType.CPU + + def test_size(self): + tensor = Tensor(self.to_tensor([1, 1])) + res = tensor.size + assert res == 2 + + def test_astype(self): + tensor = Tensor(self.to_tensor([1])) + res = tensor.astype(TensorDataType.int8) + assert isinstance(res, Tensor) + assert res.dtype == TensorDataType.int8 + assert res.device == tensor.device + + def test_fn_astype(self): + tensor = Tensor(self.to_tensor([1])) + res = fns.astype(tensor, TensorDataType.int8) + assert isinstance(res, Tensor) + assert res.dtype == TensorDataType.int8 + + def test_reshape(self): + tensor = Tensor(self.to_tensor([1, 1])) + res = tensor.reshape((1, 2)) + assert tensor.shape == (2,) + assert res.shape == (1, 2) + assert res.device == tensor.device + + def test_fn_reshape(self): + tensor = Tensor(self.to_tensor([1, 1])) + res = fns.reshape(tensor, (1, 2)) + assert tensor.shape == (2,) + assert res.shape == (1, 2) + assert res.device == tensor.device + + @pytest.mark.parametrize("from_backend", [TensorBackend.numpy, TensorBackend.ov]) + def test_as_numpy_tensor(self, from_backend): + tensor1 = Tensor(self.to_tensor([1], backend=from_backend)) + assert tensor1.backend == from_backend + tensor2 = tensor1.as_numpy_tensor() + assert tensor2.backend == TensorBackend.numpy + assert tensor1.dtype == tensor2.dtype + assert tensor1.shape == tensor2.shape + assert tensor1.device == tensor2.device