Skip to content

Commit

Permalink
use const graph for nncf.quantize
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 24, 2024
1 parent b7ba5ad commit 1c06382
Show file tree
Hide file tree
Showing 41 changed files with 11,935 additions and 7,964 deletions.
14 changes: 10 additions & 4 deletions nncf/quantization/algorithms/fast_bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ def apply(
input_fp, input_shape = self._get_fp_inputs(statistic_points, in_node_name)
output_fp = self._get_fp_outputs(statistic_points, out_node_name)

extracted_model = self._extract_submodel(model_transformer, node_name)
extracted_model = self._extract_submodel(model_transformer, in_node_name, out_node_name)
if extracted_model is None:
nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel")
continue

sub_input_name, sub_output_name = self._backend_entity.get_sub_input_output_names(extracted_model)

Expand Down Expand Up @@ -267,15 +270,18 @@ def output_filter_func(point):
output_fp.extend(Tensor(tensor_collector.get_statistics().mean_values))
return output_fp

def _extract_submodel(self, model_transformer: ModelTransformer, node_name: str) -> TModel:
def _extract_submodel(self, model_transformer: ModelTransformer, in_node_name: str, out_node_name: str) -> TModel:
"""
Extracts sub-model using backend-specific ModelTransformer.
:param model_transformer: Backend-specific ModelTransformer.
:param node_name: Name of the node that should be a center of the sub-model.
:param in_node_name: Name of the node that should be the input of the sub-model.
:param out_node_name: Name of the node that should be the output of the sub-model.
:return: Backend-specific sub-model.
"""
model_extraction_command = self._backend_entity.model_extraction_command([(node_name, 0)], [(node_name, 0)])
model_extraction_command = self._backend_entity.model_extraction_command(
[(in_node_name, 0)], [(out_node_name, 0)]
)
me_transformation_layout = TransformationLayout()
me_transformation_layout.register(model_extraction_command)
extracted_model = model_transformer.transform(me_transformation_layout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.torch.graph.transformations.command_creation import create_bias_correction_command
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_analyzer import get_fused_bias_value
from nncf.torch.model_analyzer import get_potential_fused_node
from nncf.torch.model_analyzer import is_node_with_fused_bias
from nncf.torch.model_analyzer import is_quantized_weights
from nncf.torch.model_graph_manager import get_fused_bias_value
from nncf.torch.model_graph_manager import get_potential_fused_node
from nncf.torch.model_graph_manager import is_node_with_fused_bias
from nncf.torch.model_graph_manager import is_quantized_weights
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector

Expand Down Expand Up @@ -56,8 +56,8 @@ def create_bias_correction_command(
@staticmethod
def model_extraction_command(
input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]]
) -> PTModelExtractionWithFusedBiasCommand:
return PTModelExtractionWithFusedBiasCommand(input_ids[0][0])
) -> PTModelExtractionCommand:
return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]])

@staticmethod
def mean_statistic_collector(
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import collections
import dataclasses
from copy import deepcopy
from typing import Any, Dict, List, Optional, OrderedDict, Set, TypeVar, Union
from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union

import numpy as np

Expand Down Expand Up @@ -681,7 +681,7 @@ def _get_activation_quantization_target_point(

def _get_quantization_target_points(
self, model: TModel, nncf_graph: NNCFGraph
) -> OrderedDict[TargetPoint, QuantizerConfig]:
) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]:
"""
Returns Quantization Target Points.
In the Compression Pipeline logic NNCF assumes that the compression pipeline works only on the single model.
Expand Down Expand Up @@ -1049,7 +1049,7 @@ def _is_node_after_producers(node):
quantizer_setup.discard(fq_2_q_key, True)
continue

# In the case of the two quantizers without the brancking after them,
# In the case of the two quantizers without the branching after them,
# it is necessary to check that all quantizers follow the producer nodes.
if _is_node_after_producers(fq_1_producer) and _is_node_after_producers(fq_2_producer):
fq_1_prod_shape = np.prod(nncf_graph.get_output_edges(fq_1_producer)[0].tensor_shape)
Expand Down
20 changes: 10 additions & 10 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class PTMinMaxAlgoBackend(MinMaxAlgoBackend):

@property
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleLinearMetatype, om.PTLinearMetatype, om.PTMatMulMetatype]
return [om.PTLinearMetatype, om.PTLinearMetatype, om.PTMatMulMetatype]

@property
def post_processing_metatypes(self) -> List[OperatorMetatype]:
Expand All @@ -78,18 +78,18 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTModuleLinearMetatype,
om.PTModuleConvTranspose1dMetatype,
om.PTModuleConvTranspose2dMetatype,
om.PTModuleConvTranspose3dMetatype,
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
om.PTLinearMetatype,
om.PTConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
]

@property
Expand Down Expand Up @@ -214,7 +214,7 @@ def get_statistic_collector(

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return [None]
return node.metatype.weight_port_ids

@staticmethod
def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str:
Expand Down
23 changes: 14 additions & 9 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_data
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
Expand All @@ -52,14 +54,14 @@ class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
]

@property
def matmul_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleLinearMetatype]
return [om.PTLinearMetatype]

@property
def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -98,10 +100,13 @@ def get_abs_max_channel_collector(

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor:
node_module = model.nncf.get_containing_module(node_with_weight.node_name)
if node_module.weight is None:
raise RuntimeError(f"{node_module} module has no .weight attribute.")
return Tensor(node_module.weight.data)
weight_node = get_const_node(
node_with_weight, node_with_weight.metatype.weight_port_ids[0], model.nncf.get_graph()
)
if weight_node is None:
raise RuntimeError(f"{node_with_weight} node has no weight node.")
weight_data = get_const_data(weight_node, model)
return Tensor(weight_data)

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
Expand Down Expand Up @@ -131,7 +136,7 @@ def scale_insertion_command(

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
if node.metatype == om.PTModuleLinearMetatype:
if node.metatype == om.PTLinearMetatype:
return -1
# TODO: Add activation axis calculation when MatMul will be supported
return 1
Expand Down
9 changes: 6 additions & 3 deletions nncf/torch/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.definitions import MODEL_CONST_OP_NAME
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import INPUT_NOOP_METATYPES
from nncf.torch.dynamic_graph.context import TracingContext
Expand Down Expand Up @@ -79,9 +80,11 @@ class GraphConverter:
def convert(dynamic_graph: DynamicGraph, traced_parameters) -> PTNNCFGraph:
module_id_vs_known_op_addrs_map: Dict[int, Set[Scope]] = defaultdict(set)
for dynamic_graph_node in dynamic_graph.get_all_nodes():
module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
dynamic_graph_node.op_exec_context.op_address
)
# Skip const nodes to detect shared nodes
if dynamic_graph_node.op_exec_context.operator_name != MODEL_CONST_OP_NAME:
module_id_vs_known_op_addrs_map[dynamic_graph_node.calling_module_id].add(
dynamic_graph_node.op_exec_context.op_address
)

module_id_vs_sorted_scopes_map = {
k: list(sorted([s.scope_in_model for s in v], key=str)) for k, v in module_id_vs_known_op_addrs_map.items()
Expand Down
9 changes: 5 additions & 4 deletions nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,18 @@ def requires_graph_rebuild(self):
return True


class PTModelExtractionWithFusedBiasCommand(PTCommand):
class PTModelExtractionCommand(PTCommand):
"""
Extracts sequence by name with node that contain fused bias.
Extracts submodel based on the sub-model input and output names
"""

def __init__(self, node_name: str):
def __init__(self, input_node_names: List[str], output_node_names: List[str]):
"""
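:param input_node_names: Names of the nodes that are the inputs of the extracted sub-model.
:param output_node_names: Names of the nodes that are the outputs of the extracted sub-model.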
"""
super().__init__(TransformationType.EXTRACT)
self.node_name = node_name
self.input_node_names = input_node_names
self.output_node_names = output_node_names


class PTBiasCorrectionCommand(PTTransformationCommand):
Expand Down
98 changes: 0 additions & 98 deletions nncf/torch/model_analyzer.py

This file was deleted.

Loading

0 comments on commit 1c06382

Please sign in to comment.