diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py
index 1a651f9a363..3a43d51d132 100644
--- a/nncf/torch/graph/graph.py
+++ b/nncf/torch/graph/graph.py
@@ -53,6 +53,12 @@ def get_input_shape_for_insertion_point(self, insertion_point: PTTargetPoint) ->
         return quantizer_input_shape
 
     def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
+        """
+        Returns all NNCFNodes inside the given scope.
+
+        :param scope: Given scope.
+        :return: All NNCFNodes inside the given scope.
+        """
         matching_graph_op_nodes = []
         for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items():
             module_scope = Scope.from_str(scope_str)
@@ -60,7 +66,22 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
             matching_graph_op_nodes.extend(nodes_in_module)
         return matching_graph_op_nodes
 
+    def get_op_nodes_with_scope(self, scope: Scope) -> List[NNCFNode]:
+        """
+        Returns all NNCFNodes which share the given scope.
+
+        :param scope: Given scope.
+        :return: All NNCFNodes which share the given scope.
+        """
+        return self._layer_name_vs_shared_nodes[str(scope)]
+
     def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
+        """
+        Returns a scope which corresponds to the given NNCF node name.
+
+        :param node_name: Given node name.
+        :return: A scope which corresponds to the given NNCF node name.
+        """
         matches = []
         for node_id, scope_str in self._node_ids_vs_layer_names.items():
             node = self.get_node_by_id(node_id)
diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py
index 6146803ae19..bb2bf59a122 100644
--- a/nncf/torch/graph/transformations/command_creation.py
+++ b/nncf/torch/graph/transformations/command_creation.py
@@ -9,8 +9,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Union
+from typing import List, Optional, Union
 
+import torch
 from torch import Tensor
 
 from nncf.common.graph.graph import NNCFNode
@@ -82,3 +83,26 @@ def create_shared_quantizer_insertion_command(
         compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
         priority=TransformationPriority.QUANTIZATION_PRIORITY,
     )
+
+
+def create_pt_insertion_command(
+    module: torch.nn.Module,
+    target_type: TargetType,
+    target_node_name: str,
+    priority: int,
+    input_port_id: Optional[int],
+) -> PTInsertionCommand:
+    """
+    Creates a PTInsertionCommand.
+
+    :param module: Torch module to insert.
+    :param target_type: Insertion command target type.
+    :param target_node_name: Insertion command target name.
+    :param priority: Insertion command priority.
+    :param input_port_id: Insertion command input port id.
+    :return: A PTInsertionCommand.
+    """
+    target_point = PTTargetPoint(
+        target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id
+    )
+    return PTInsertionCommand(point=target_point, fn=module, priority=priority)
diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py
index b2461277a5f..1dd2647d584 100644
--- a/nncf/torch/graph/transformations/commands.py
+++ b/nncf/torch/graph/transformations/commands.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Union
 
 import torch
 
@@ -139,7 +139,7 @@ def __init__(
         self,
         point: PTTargetPoint,
         fn: Callable,
-        priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
+        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
         hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
     ):
         super().__init__(TransformationType.INSERT, point)
@@ -164,7 +164,7 @@ def __init__(
         fn: Callable,
         op_unique_name: str,
         compression_module_type: ExtraCompressionModuleType = ExtraCompressionModuleType.EXTERNAL_OP,
-        priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
+        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
         hooks_group_name: str = DEFAULT_HOOKS_GROUP_NAME,
     ):
         super().__init__(TransformationType.INSERT, None)
diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py
index a27d338a77a..0ca8a14380f 100644
--- a/nncf/torch/nncf_network.py
+++ b/nncf/torch/nncf_network.py
@@ -41,6 +41,7 @@
 from nncf.common.utils.debug import is_debug
 from nncf.torch.debug import CombinedDebugInterface
 from nncf.torch.debug import debuggable_forward
+from nncf.torch.dynamic_graph.context import PreHookId
 from nncf.torch.dynamic_graph.context import TracingContext
 from nncf.torch.dynamic_graph.graph import DynamicGraph
 from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator
@@ -60,16 +61,21 @@
 from nncf.torch.dynamic_graph.wrappers import wrap_module_call
 from nncf.torch.dynamic_graph.wrappers import wrap_parameters
 from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
+from nncf.torch.external_hook import ExternalOpCallHook
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph_builder import GraphBuilder
 from nncf.torch.graph.graph_builder import GraphConverter
 from nncf.torch.graph.operator_metatypes import OPERATORS_WITH_WEIGHTS_METATYPES
 from nncf.torch.graph.operator_metatypes import PTSplitMetatype
+from nncf.torch.graph.transformations.command_creation import create_pt_insertion_command
 from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.layout import PTTransformationLayout
 from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
 from nncf.torch.layer_utils import _NNCFModuleMixin
+from nncf.torch.module_operations import UpdateWeight
 from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
 from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
 from nncf.torch.utils import compute_FLOPs_hook
@@ -778,6 +784,127 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable)
             result.append(scope_in_model)
         return result
 
+    def transformation_layout(self) -> PTTransformationLayout:
+        """
+        Collects all hooks applied to the NNCFNetwork, converts them to insertion
+        commands and returns them as a PTTransformationLayout. The default hooks
+        group name is used for the recovered commands, so hooks group names
+        specified during the model modification are not preserved.
+
+        :return: Transformation layout with all commands applied to the NNCFNetwork.
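+
+        Usage sketch (illustrative)::
+
+            # Recover the commands applied to an NNCFNetwork `model`:
+            layout = model.nncf.transformation_layout()
+            # The recovered commands can be re-applied to another instance
+            # of the same model via PTModelTransformer.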
+ """ + + def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): + """ + Check given external op call hook reference is correct. + + :param hook: External op call hook to check correctness. + :param info: Info to log in case op call hook references are broken. + """ + assert hasattr( + self, hook._storage_name + ), f"Storage name {hook._storage_name} is not registered. Info: {info}" + assert hook._storage_key in getattr( + self, hook._storage_name + ), f"Key {hook._storage_key} is not registered in {hook._storage_name}. Info: {info}" + + context_hooks = defaultdict(lambda: defaultdict(list)) + transformation_layout = PTTransformationLayout() + nncf_graph = self.get_graph() + nncf_node_names_map = self.get_op_address_to_op_name_map() + + # Collect pre/post layer and op with weights insertion commands + for nncf_module, module_scope in self.get_nncf_modules().items(): + for ops, target_type in ( + (nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION), + (nncf_module.post_ops, TargetType.POST_LAYER_OPERATION), + ): + for priority, module in enumerate(ops.values()): + nodes_in_scope = nncf_graph.get_op_nodes_with_scope(module_scope) + # Several NNCFNodes means that current NNCFModule was called + # several times. Only one insertion command is required to + # call hook as much times as the current NNCFModule, therefore + # we use first correspondent NNCFNode. + nncf_node = nodes_in_scope[0] + command_target_type = target_type + if isinstance(module, UpdateWeight): + command_target_type = TargetType.OPERATION_WITH_WEIGHTS + module = module.op + if not isinstance(module, ExternalOpCallHook): + command = create_pt_insertion_command( + module, command_target_type, nncf_node.node_name, priority, None + ) + transformation_layout.register(command) + continue + + info = ( + f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," + f" priority: {priority}, fn: {module}" + ) + _check_external_call_hook_is_valid(module, info) + + context_hooks[module._storage_name][module._storage_key].append( + (command_target_type, nncf_node.node_name, priority, module, None) + ) + + # Collect all pre/post hooks commands + for ops, target_type in ( + (self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK), + (self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK), + ): + for op_address, hooks in ops.items(): + if isinstance(op_address, PreHookId): + input_port_id = op_address.input_port_id + op_address = op_address.op_address + else: + input_port_id = None + for priority, fn in enumerate(hooks.values()): + target_node_names = nncf_node_names_map[op_address] + # Operation address is unique for each module call + assert len(target_node_names) == 1 + target_node_name = target_node_names[0] + + if not isinstance(fn, ExternalOpCallHook): + command = create_pt_insertion_command( + fn, target_type, target_node_name, priority, input_port_id + ) + transformation_layout.register(command) + continue + + info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}" + _check_external_call_hook_is_valid(fn, info) + + context_hooks[fn._storage_name][fn._storage_key].append( + (target_type, target_node_name, priority, fn, input_port_id) + ) + + # Create shared fn insertion commands according to external hooks collected from + # pre/post layer, pre/post hooks and op with weights target points. 
+        for module_type_name, storage in context_hooks.items():
+            for storage_key, call_hook_list_info in storage.items():
+                compression_module = getattr(self, module_type_name)[storage_key]
+                target_points = []
+                for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info:
+                    target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id))
+
+                if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME:
+                    module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER
+                elif module_type_name == EXTERNAL_OP_STORAGE_NAME:
+                    module_type = ExtraCompressionModuleType.EXTERNAL_OP
+                else:
+                    raise RuntimeError(f"Module type {module_type_name} is not supported")
+
+                command = PTSharedFnInsertionCommand(
+                    target_points=target_points,
+                    fn=compression_module,
+                    op_unique_name=storage_key,
+                    compression_module_type=module_type,
+                    priority=priority,
+                )
+                transformation_layout.register(command)
+
+        return transformation_layout
+
     def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]:
         """
         Returns map of NNCFGraph node names vs DynamicGraph operation addresses.
@@ -796,6 +923,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]
             retval[nncf_node.node_name] = op_address
         return retval
 
+    def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, List[NNCFNodeName]]:
+        """
+        Returns map of DynamicGraph operation addresses vs lists of NNCFGraph node names.
+
+        :return: DynamicGraph operation addresses vs lists of NNCFGraph node names.
+        """
+        retval = defaultdict(list)
+        for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items():
+            retval[op_address].append(nncf_node_name)
+        return retval
+
     def set_compression_controller(self, ctrl: CompressionAlgorithmController):
         self.compression_controller = ctrl
diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py
index 3dfe3a3df7e..43b0b69c60b 100644
--- a/tests/torch/helpers.py
+++ b/tests/torch/helpers.py
@@ -8,6 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import contextlib
 import numbers
 from abc import ABC
@@ -38,6 +39,8 @@
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 from nncf.torch.dynamic_graph.scope import Scope
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.initialization import PTInitializingDataLoader
 from nncf.torch.initialization import register_default_init_args
 from nncf.torch.layers import NNCF_MODULES_MAP
@@ -172,6 +175,16 @@ def nz_bias_num(self):
 
 
 class TwoConvTestModel(nn.Module):
+    INPUT_SHAPE = [1, 1, 4, 4]
+    NNCF_CONV_NODES_NAMES = [
+        "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0",
+        "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0",
+    ]
+    CONV_NODES_NAMES = [
+        "TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0",
+        "TwoConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0",
+    ]
+
     def __init__(self):
         super().__init__()
         self.features = []
@@ -199,6 +212,30 @@ def nz_bias_num(self):
         return 2
 
 
+class TwoSharedConvTestModel(nn.Module):
+    INPUT_SHAPE = [1, 1, 4, 4]
+    NNCF_CONV_NODES_NAMES = [
+        "TwoSharedConvTestModel/NNCFConv2d[conv1]/conv2d_0",
+        "TwoSharedConvTestModel/NNCFConv2d[conv2]/conv2d_0",
+    ]
+    CONV_NODES_NAMES = [
+        "TwoSharedConvTestModel/Conv2d[conv1]/conv2d_0",
+        "TwoSharedConvTestModel/Conv2d[conv2]/conv2d_0",
+    ]
+
+    def __init__(self):
+        super().__init__()
+        self.features = []
+        self.conv1 = create_conv(1, 1, 1, -1, -2)
+        self.conv2 = create_conv(1, 1, 1, 0, 0)
+
+    def forward(self, x):
+        for _ in range(2):
+            x = self.conv1(x)
+            x = self.conv2(x)
+        return x
+
+
 class LeNet(nn.Module):
     INPUT_SIZE = 1, 32, 32
 
@@ -228,6 +265,72 @@ def num_flat_features(self, x):
         return num_features
 
 
+class DummyOpWithState(torch.nn.Module):
+    def __init__(self, state: str):
+        super().__init__()
+        self._state = state
+
+    def __call__(self, *args):
+        if len(args) == 1:
+            return args[0]
+        # To work correctly with
+        # TargetType.PRE_LAYER_OPERATION
+        # TargetType.POST_LAYER_OPERATION
+        return None
+
+    def get_state(self):
+        return self._state
+
+    @classmethod
+    def from_state(cls, state: str):
+        return cls(state)
+
+
+def commands_are_equal(
+    command_left: Union[PTInsertionCommand, PTSharedFnInsertionCommand],
+    command_right: Union[PTInsertionCommand, PTSharedFnInsertionCommand],
+    check_priority: bool = True,
+    check_hooks_group_name: bool = True,
+    check_fn_ref: bool = True,
+) -> bool:
+    """
+    Returns True if the given commands are equal and False otherwise.
+
+    :param command_left: The first command.
+    :param command_right: The second command.
+    :param check_priority: Whether to check insertion priority or not.
+    :param check_hooks_group_name: Whether to check hooks group name or not.
+    :param check_fn_ref: Whether to check fn by reference or not.
+    :returns: True if the given commands are equal and False otherwise.
+    """
+    if type(command_right) is not type(command_left):
+        return False
+
+    # Check that references to functions are equal.
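+    # fn objects are compared by identity: a recovered command is expected to
+    # reference the very same module instance that was originally inserted.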
+    if check_fn_ref and command_right.fn is not command_left.fn:
+        return False
+    if check_hooks_group_name and command_right.hooks_group_name != command_left.hooks_group_name:
+        return False
+    if check_priority and command_right.priority != command_left.priority:
+        return False
+
+    if isinstance(command_right, PTInsertionCommand):
+        if command_left.target_point != command_right.target_point:
+            return False
+    elif isinstance(command_right, PTSharedFnInsertionCommand):
+        if not all(a == b for a, b in zip(command_left.target_points, command_right.target_points)):
+            return False
+        if (
+            command_right.target_points != command_left.target_points
+            or command_right.op_name != command_left.op_name
+            or command_right.compression_module_type != command_left.compression_module_type
+        ):
+            return False
+    else:
+        raise RuntimeError(f"Unsupported command type: {type(command_right)}")
+    return True
+
+
 class SharedConv(nn.Module):
     INPUT_SIZE = [1, 1, 4, 4]
 
diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py
new file mode 100644
index 00000000000..06805cd59b7
--- /dev/null
+++ b/tests/torch/nncf_network/helpers.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import itertools
+from typing import Optional, Type
+
+import torch
+
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.commands import PTTransformationCommand
+from nncf.torch.layer_utils import COMPRESSION_MODULES
+from tests.torch.helpers import DummyOpWithState
+from tests.torch.helpers import TwoConvTestModel
+from tests.torch.helpers import TwoSharedConvTestModel
+
+
+class SimplestModel(torch.nn.Module):
+    INPUT_SIZE = [1, 1, 32, 32]
+
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(1, 1, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+AVAILABLE_TARGET_TYPES = (
+    TargetType.OPERATION_WITH_WEIGHTS,
+    TargetType.OPERATOR_PRE_HOOK,
+    TargetType.OPERATOR_POST_HOOK,
+    TargetType.PRE_LAYER_OPERATION,
+    TargetType.POST_LAYER_OPERATION,
+)
+
+
+class InsertionCommandBuilder:
+    """
+    Contains methods which allow to build all possible commands
+    for the given torch.nn.Module.
+    The target model class should define NNCF_CONV_NODES_NAMES and CONV_NODES_NAMES
+    attributes with the names of the nncf-wrapped and the original target model
+    convolutions, respectively.
+    """
+
+    AVAILABLE_MODELS = (TwoConvTestModel, TwoSharedConvTestModel)
+
+    def __init__(self, model_cls: Type[torch.nn.Module]):
+        self.model_cls = model_cls
+
+    TRACE_VS_NODE_NAMES = {True: "CONV_NODES_NAMES", False: "NNCF_CONV_NODES_NAMES"}
+
+    @staticmethod
+    def get_input_port_id(target_type: TargetType, trace_parameters: bool) -> Optional[int]:
+        if target_type is TargetType.OPERATOR_PRE_HOOK:
+            return 0
+        if trace_parameters and target_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]:
+            return 1
+        return None
+
+    def create_pt_insertion_command(
+        self,
+        target_type: TargetType,
+        priority: TransformationPriority,
+        trace_parameters: bool,
+        fn: Optional[torch.nn.Module] = None,
+        group: str = "default_group",
+        op_unique_name: Optional[str] = None,
+    ):
+        attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters]
+        target_point = PTTargetPoint(
+            target_type=target_type,
+            target_node_name=getattr(self.model_cls, attr_name)[0],
+            input_port_id=self.get_input_port_id(target_type, trace_parameters),
+        )
+        if fn is None:
+            fn = DummyOpWithState("DUMMY_STATE")
+        return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group)
+
+    def create_pt_shared_fn_insertion_command(
+        self,
+        target_type: TargetType,
+        priority: TransformationPriority,
+        trace_parameters: bool,
+        compression_module_type: ExtraCompressionModuleType,
+        fn: Optional[torch.nn.Module] = None,
+        group: str = "default_group",
+        op_unique_name: str = "UNIQUE_NAME",
+    ):
+        target_points = []
+        attr_name = self.TRACE_VS_NODE_NAMES[trace_parameters]
+        for node_name in getattr(self.model_cls, attr_name):
+            target_points.append(
+                PTTargetPoint(
+                    target_type=target_type,
+                    target_node_name=node_name,
+                    input_port_id=self.get_input_port_id(target_type, trace_parameters),
+                )
+            )
+        if fn is None:
+            fn = DummyOpWithState("DUMMY_STATE")
+        return PTSharedFnInsertionCommand(
+            target_points=target_points,
+            fn=fn,
+            compression_module_type=compression_module_type,
+            op_unique_name=op_unique_name,
+            priority=priority,
+            hooks_group_name=group,
+        )
+
+    def get_command_builders(self):
+        """
+        Returns all available command builders and their command types as a tuple of pairs.
+        """
+        return (
+            (self.create_pt_insertion_command, PTInsertionCommand),
+            (
+                functools.partial(
+                    self.create_pt_shared_fn_insertion_command,
+                    compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
+                ),
+                PTSharedFnInsertionCommand,
+            ),
+            (
+                functools.partial(
+                    self.create_pt_shared_fn_insertion_command,
+                    compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
+                ),
+                PTSharedFnInsertionCommand,
+            ),
+        )
+
+    # Check priority as an enum member and as an int
+    PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1)
+
+    def get_all_available_commands(
+        self, dummy_op_state, trace_parameters, skip_model_transformer_unsupported=False
+    ) -> TransformationLayout:
+        """
+        Returns all possible commands to insert:
+        all target types x all command classes x all compression module types x different priorities.
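+
+        :param dummy_op_state: State string for the DummyOpWithState modules inserted by the commands.
+        :param trace_parameters: Whether to use node names of a model wrapped with trace_parameters=True.
+        :param skip_model_transformer_unsupported: Whether to skip commands which are
+            not supported by the PTModelTransformer.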
+ """ + layout = TransformationLayout() + for idx, (target_type, (command_builder, command_type), priority) in enumerate( + itertools.product(AVAILABLE_TARGET_TYPES, self.get_command_builders(), self.PRIORITIES) + ): + if skip_model_transformer_unsupported and self.is_unsupported_by_transformer_command( + command_type, target_type + ): + continue + command = self._create_command( + command_builder, + target_type, + priority, + dummy_op_state, + op_unique_name=f"UNIQUE_NAME_{idx}", + trace_parameters=trace_parameters, + ) + + layout.register(command) + return layout + + @staticmethod + def is_unsupported_by_transformer_command(command_type: PTTransformationCommand, target_type: TargetType) -> bool: + """ + Returns True if insertion parameters don't supported by the PTModelTransformer otherwise False. + """ + return command_type is PTSharedFnInsertionCommand and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ] + + @staticmethod + def _create_command( + command_builder, + target_type, + priority, + dummy_op_state, + trace_parameters, + op_unique_name, + ): + """ + Creates command with specified parameters and dummy op. + """ + # Register dummy op name in the COMPRESSION_MODULES + if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: + registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) + else: + registered_dummy_op_cls = DummyOpWithState + dummy_op = registered_dummy_op_cls(dummy_op_state) + + # Build the command + group_name = "CUSTOM_HOOKS_GROUP_NAME" + return command_builder( + target_type, + priority, + fn=dummy_op, + group=group_name, + op_unique_name=op_unique_name, + trace_parameters=trace_parameters, + ) diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_get_applied_modifications.py new file mode 100644 index 00000000000..075625148eb --- /dev/null +++ b/tests/torch/nncf_network/test_get_applied_modifications.py @@ -0,0 +1,189 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import itertools
+from typing import List, Tuple, Type, Union
+
+import pytest
+import torch
+
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.torch import wrap_model
+from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from nncf.torch.graph.transformations.layout import PTTransformationLayout
+from nncf.torch.model_transformer import PTModelTransformer
+from tests.torch.helpers import commands_are_equal
+from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES
+from tests.torch.nncf_network.helpers import InsertionCommandBuilder
+
+TARGET_TYPE_VS_TARGET_TYPE_DICT_FOR_NOT_REPLACED_MODULES = {
+    TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK,
+    TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
+    TargetType.OPERATION_WITH_WEIGHTS: TargetType.OPERATOR_PRE_HOOK,
+    TargetType.OPERATOR_PRE_HOOK: TargetType.OPERATOR_PRE_HOOK,
+    TargetType.OPERATOR_POST_HOOK: TargetType.OPERATOR_POST_HOOK,
+}
+
+
+@pytest.fixture(name="trace_parameters", params=(True, False))
+def trace_parameters_fixture(request) -> bool:
+    return request.param
+
+
+def _get_trace_params_target_types_command_builders_and_models_cls() -> (
+    List[Tuple[bool, Type[torch.nn.Module], TargetType, callable]]
+):
+    """
+    Returns a list of all available combinations of the trace_parameters flag,
+    model class, target type and command builder.
+    """
+    retval = []
+    for (
+        trace_parameters,
+        model_cls,
+        target_type,
+    ) in itertools.product(
+        (True, False),
+        InsertionCommandBuilder.AVAILABLE_MODELS,
+        AVAILABLE_TARGET_TYPES,
+    ):
+        for command_builder, command_cls in InsertionCommandBuilder(model_cls).get_command_builders():
+            if not trace_parameters and InsertionCommandBuilder.is_unsupported_by_transformer_command(
+                command_cls, target_type
+            ):
+                print(f"PTSharedFnInsertionCommand does not support target type {target_type}")
+                continue
+            retval.append((trace_parameters, model_cls, target_type, command_builder))
+    return retval
+
+
+def _translate_target_types(trace_parameters: bool, command: Union[PTInsertionCommand, PTSharedFnInsertionCommand]):
+    """
+    Translates the command target types to the corresponding operator hook target types
+    in case trace_parameters is True.
+    """
+    if not trace_parameters:
+        return
+    if isinstance(command, PTInsertionCommand):
+        target_points = [command.target_point]
+    else:
+        target_points = command.target_points
+
+    for target_point in target_points:
+        new_target_type = TARGET_TYPE_VS_TARGET_TYPE_DICT_FOR_NOT_REPLACED_MODULES[target_point.type]
+        target_point._target_type = new_target_type
+        target_point.target_type = new_target_type
+
+
+@pytest.mark.parametrize(
+    "trace_parameters,model_cls,target_type,command_builder",
+    _get_trace_params_target_types_command_builders_and_models_cls(),
+)
+def test_get_applied_modification_commands(
+    model_cls: Type[torch.nn.Module], command_builder: callable, target_type: TargetType, trace_parameters: bool
+):
+    model = model_cls()
+    nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters)
+    model_transformer = PTModelTransformer(nncf_model)
+
+    layout = PTTransformationLayout()
+    command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY, trace_parameters=trace_parameters)
+    layout.register(command)
+    model_transformer.transform(layout)
+
+    applied_commands = nncf_model.nncf.transformation_layout()
+
+    assert len(applied_commands.transformations) == 1
+    applied_command = applied_commands.transformations[0]
+    _translate_target_types(trace_parameters, command)
+    assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False)
+
+
+@pytest.mark.parametrize(
+    "trace_parameters,model_cls,target_type,command_builder",
+    _get_trace_params_target_types_command_builders_and_models_cls(),
+)
+def test_priority_of_get_applied_modification_commands(
+    command_builder: callable, model_cls: Type[torch.nn.Module], target_type: TargetType, trace_parameters: bool
+):
+    layout = PTTransformationLayout()
+    commands = {}
+    for priority in (0, 2, 1):
+        command = command_builder(
+            target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}", trace_parameters=trace_parameters
+        )
+        layout.register(command)
+        commands[priority] = command
+
+    model = model_cls()
+    nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters)
+    model_transformer = PTModelTransformer(nncf_model)
+
+    model_transformer.transform(layout)
+
+    applied_commands = nncf_model.nncf.transformation_layout()
+    assert len(applied_commands.transformations) == len(commands)
+    for applied_command in applied_commands.transformations:
+        command = commands[applied_command.priority]
+        _translate_target_types(trace_parameters, command)
+        assert commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False)
+
+
+@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS)
+def test_all_possible_combinations_of_commands_for_get_applied_commands(
+    model_cls: Type[torch.nn.Module], trace_parameters: bool
+):
+    dummy_state = "DummyState"
+    commands = InsertionCommandBuilder(model_cls).get_all_available_commands(
+        dummy_state, skip_model_transformer_unsupported=not trace_parameters, trace_parameters=trace_parameters
+    )
+
+    model = model_cls()
+    nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters)
+    model_transformer = PTModelTransformer(nncf_model)
+
+    model_transformer.transform(commands)
+
+    applied_commands = nncf_model.nncf.transformation_layout()
+    assert len(applied_commands.transformations) == len(commands.transformations)
+    for command in commands.transformations:
+        _translate_target_types(trace_parameters, command)
+        eq_commands = (
+            commands_are_equal(command, applied_command, check_priority=False, check_hooks_group_name=False)
+            for applied_command in applied_commands.transformations
+        )
+        if sum(map(int, eq_commands)) != 1:
+            raise RuntimeError(f"Command {command} has no pair in recovered commands")
+
+
+@pytest.mark.parametrize("target_type", (TargetType.OPERATION_WITH_WEIGHTS, TargetType.OPERATOR_PRE_HOOK))
+@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS)
+def test_get_applied_modification_commands_broken_call_hook(
+    model_cls: Type[torch.nn.Module], target_type: TargetType, trace_parameters: bool
+):
+    model = model_cls()
+    nncf_model = wrap_model(model, torch.zeros(model_cls.INPUT_SHAPE), trace_parameters=trace_parameters)
+    model_transformer = PTModelTransformer(nncf_model)
+
+    layout = PTTransformationLayout()
+    command = InsertionCommandBuilder(model_cls).create_pt_shared_fn_insertion_command(
+        target_type=target_type,
+        priority=0,
+        compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
+        trace_parameters=trace_parameters,
+    )
+    layout.register(command)
+    model_transformer.transform(layout)
+
+    nncf_model.nncf.external_op.clear()
+    with pytest.raises(AssertionError):
+        nncf_model.nncf.transformation_layout()
diff --git a/tests/torch/nncf_network/test_hook_handlers.py b/tests/torch/nncf_network/test_hook_handlers.py
new file mode 100644
index 00000000000..06c22d24a84
--- /dev/null
+++ b/tests/torch/nncf_network/test_hook_handlers.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, List, Tuple
+
+import pytest
+import torch
+
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.hook_handle import HookHandle
+from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
+from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.nncf_network import PTInsertionPoint
+from tests.torch.helpers import HookChecker
+from tests.torch.nncf_network.helpers import SimplestModel
+
+
+@pytest.mark.parametrize(
+    "target_type, target_node_name, input_port_id",
+    [
+        (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0),
+        (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0),
+        (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0),
+        (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0),
+    ],
+)
+class TestHookHandles:
+    class TestHook(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self._p = torch.nn.Parameter(torch.zeros((1,)))
+
+        def forward(self, x):
+            return x + self._p
+
+    @staticmethod
+    def _prepare_hook_handles_test(
+        target_type: TargetType, target_node_name: str, input_port_id: int
+    ) -> Tuple[NNCFNetwork, PTInsertionPoint, Callable[[List[HookHandle]], None]]:
+        model = SimplestModel()
+        example_input = torch.ones(SimplestModel.INPUT_SIZE)
+        input_info = ExampleInputInfo.from_example_input(example_input)
+        nncf_model = NNCFNetwork(model, input_info)
+
+        node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping()
+        ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id)
+
+        checker = HookChecker(nncf_model, "conv")
+
+        def _check(ref_hooks_):
+            checker.clear()
+            checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id)
+            checker.check_with_reference()
+
+        return nncf_model, ip, _check
+
+    def test_temporary_insert_at_point_by_hook_group_name(
+        self, target_type: TargetType, target_node_name: str, input_port_id: int
+    ):
+        nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id)
+        permanent_hook = self.TestHook()
+        TEMPORARY_HOOK_GROUP_NAME = "tmp"
+        # Make temporary hook a ref to the permanent hook
+        # to check tmp hooks are not removed by their id()
+        temporary_hook = permanent_hook
+        nncf_model.nncf.insert_at_point(ip, permanent_hook)
+        ref_hooks = [permanent_hook]
+        _check(ref_hooks)
+
+        for _ in range(2):
+            temporary_hook = self.TestHook()
+            nncf_model.nncf.insert_at_point(ip, temporary_hook, TEMPORARY_HOOK_GROUP_NAME)
+            ref_hooks.append(temporary_hook)
+            _check(ref_hooks)
+
+        nncf_model.nncf.insert_at_point(ip, permanent_hook)
+        ref_hooks.append(permanent_hook)
+        _check(ref_hooks)
+
+        nncf_model.nncf.remove_hooks_group(TEMPORARY_HOOK_GROUP_NAME)
+        del ref_hooks[-2]
+        _check(ref_hooks)
+        assert not nncf_model.nncf._groups_vs_hooks_handlers[TEMPORARY_HOOK_GROUP_NAME]
+
+    def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node_name: str, input_port_id: int):
+        nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id)
+        permanent_hook = self.TestHook()
+        # Make temporary hook a ref to the permanent hook
+        # to check tmp hooks are not removed by their id()
+        temporary_hook = permanent_hook
+        tmp_hh = []
+        nncf_model.nncf.insert_at_point(ip, permanent_hook)
+
+        ref_hooks = [permanent_hook]
+        _check(ref_hooks)
+
+        for _ in range(2):
+            temporary_hook = self.TestHook()
+            tmp_hh.append(nncf_model.nncf.insert_at_point(ip, temporary_hook))
+            ref_hooks.append(temporary_hook)
+            _check(ref_hooks)
+
+        nncf_model.nncf.insert_at_point(ip, permanent_hook)
+        ref_hooks.append(permanent_hook)
+        _check(ref_hooks)
+
+        for hh in tmp_hh:
+            hh.remove()
+
+        del ref_hooks[-2]
+        _check(ref_hooks)
diff --git a/tests/torch/test_nncf_network.py b/tests/torch/nncf_network/test_nncf_network.py
similarity index 89%
rename from tests/torch/test_nncf_network.py
rename to tests/torch/nncf_network/test_nncf_network.py
index c4da4be8c82..5fff987ea44 100644
--- a/tests/torch/test_nncf_network.py
+++ b/tests/torch/nncf_network/test_nncf_network.py
@@ -14,7 +14,7 @@
 from abc import ABCMeta
 from abc import abstractmethod
 from copy import deepcopy
-from typing import Callable, List, Tuple, Type
+from typing import Callable, Type
 
 import pytest
 import torch
@@ -27,7 +27,6 @@
 from nncf.common.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import UnknownMetatype
 from nncf.common.graph.transformations.commands import TargetType
-from nncf.common.hook_handle import HookHandle
 from nncf.torch import register_module
 from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
 from nncf.torch.dynamic_graph.io_handling import FillerInputElement
@@ -50,11 +49,11 @@
 from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
 from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config
 from tests.torch.helpers import BasicConvTestModel
-from tests.torch.helpers import HookChecker
 from tests.torch.helpers import TwoConvTestModel
 from tests.torch.helpers import check_correct_nncf_modules_replacement
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 from tests.torch.helpers import register_bn_adaptation_init_args
+from tests.torch.nncf_network.helpers import SimplestModel
 from tests.torch.test_models.synthetic import ManyNonEvalModules
 
 
@@ -613,17 +612,6 @@ def test_can_work_with_sequential_models():
     _ = model.nncf.get_clean_shallow_copy()
 
 
-class SimplestModel(torch.nn.Module):
-    INPUT_SIZE = [1, 1, 32, 32]
-
-    def __init__(self):
-        super().__init__()
-        self.conv = torch.nn.Conv2d(1, 1, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 @pytest.fixture(name="simple_net")
 def simple_net_():
     model = NNCFNetwork(SimplestModel(), FillerInputInfo([FillerInputElement(SimplestModel.INPUT_SIZE)]))
@@ -928,99 +916,3 @@ def test_insert_hook_after_parameter():
     assert hook.forward_calls_counter == 1
     assert torch.sum(result.nonzero()) > 0
     assert torch.sum(result_with_hook.nonzero()) == 0
-
-
-@pytest.mark.parametrize(
"target_type, target_node_name, input_port_id", - [ - (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), - (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), - (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), - (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), - ], -) -class TestHookHandles: - class TestHook(torch.nn.Module): - def __init__(self): - super().__init__() - self._p = torch.nn.Parameter(torch.zeros((1,))) - - def forward(self, x): - return x + self._p - - @staticmethod - def _prepare_hook_handles_test( - target_type: TargetType, target_node_name: str, input_port_id: int - ) -> Tuple[NNCFNetwork, PTInsertionPoint, Callable[[List[HookHandle]], None]]: - model = SimplestModel() - example_input = torch.ones(SimplestModel.INPUT_SIZE) - input_info = ExampleInputInfo.from_example_input(example_input) - nncf_model = NNCFNetwork(model, input_info) - - node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() - ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) - - checker = HookChecker(nncf_model, "conv") - - def _check(ref_hooks_): - checker.clear() - checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) - checker.check_with_reference() - - return nncf_model, ip, _check - - def test_temporary_insert_at_point_by_hook_group_name( - self, target_type: TargetType, target_node_name: str, input_port_id: int - ): - nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) - permanent_hook = self.TestHook() - TEMPORARY_HOOK_GROUP_NAME = "tmp" - # Make temporary hook a ref to the permanent hook - # to check tmp hooks are not removed by their id() - temporary_hook = permanent_hook - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks = [permanent_hook] - _check(ref_hooks) - - for _ in range(2): - temporary_hook = self.TestHook() - nncf_model.nncf.insert_at_point(ip, temporary_hook, TEMPORARY_HOOK_GROUP_NAME) - ref_hooks.append(temporary_hook) - _check(ref_hooks) - - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks.append(permanent_hook) - _check(ref_hooks) - - nncf_model.nncf.remove_hooks_group(TEMPORARY_HOOK_GROUP_NAME) - del ref_hooks[-2] - _check(ref_hooks) - assert not nncf_model.nncf._groups_vs_hooks_handlers[TEMPORARY_HOOK_GROUP_NAME] - - def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node_name: str, input_port_id: int): - nncf_model, ip, _check = self._prepare_hook_handles_test(target_type, target_node_name, input_port_id) - permanent_hook = self.TestHook() - # Make temporary hook a ref to the permanent hook - # to check tmp hooks are not removed by their id() - temporary_hook = permanent_hook - tmp_hh = [] - nncf_model.nncf.insert_at_point(ip, permanent_hook) - - ref_hooks = [permanent_hook] - _check(ref_hooks) - - for _ in range(2): - temporary_hook = self.TestHook() - tmp_hh.append(nncf_model.nncf.insert_at_point(ip, temporary_hook)) - ref_hooks.append(temporary_hook) - _check(ref_hooks) - - nncf_model.nncf.insert_at_point(ip, permanent_hook) - ref_hooks.append(permanent_hook) - _check(ref_hooks) - - for hh in tmp_hh: - hh.remove() - - del ref_hooks[-2] - _check(ref_hooks) diff --git a/tests/torch/test_api_behavior.py b/tests/torch/test_api_behavior.py index 9e18479e94e..565eaa7e86b 100644 --- a/tests/torch/test_api_behavior.py +++ b/tests/torch/test_api_behavior.py @@ -24,7 +24,7 @@ from tests.torch.helpers import 
 from tests.torch.helpers import TwoConvTestModel
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
-from tests.torch.test_nncf_network import SimplestModel
+from tests.torch.nncf_network.helpers import SimplestModel
 
 INPUT_SAMPLE_SIZE = [1, 1, 4, 4]
 CONFIG_WITH_ALL_INIT_TYPES = {
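Taken together, these changes let a caller recover the insertion commands applied to a live NNCFNetwork. A minimal round-trip sketch distilled from the new tests (`DummyOp` is an illustrative stand-in for any hook module and is not part of this patch):

```python
import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.torch import wrap_model
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.model_transformer import PTModelTransformer
from tests.torch.helpers import TwoConvTestModel


class DummyOp(torch.nn.Module):
    # Identity hook standing in for a real compression operation
    def forward(self, x):
        return x


model = wrap_model(TwoConvTestModel(), torch.zeros(TwoConvTestModel.INPUT_SHAPE), trace_parameters=True)

# Apply a hook through an insertion command, as the new tests do
layout = PTTransformationLayout()
target_point = PTTargetPoint(
    target_type=TargetType.OPERATOR_POST_HOOK,
    target_node_name=TwoConvTestModel.CONV_NODES_NAMES[0],
)
layout.register(
    PTInsertionCommand(point=target_point, fn=DummyOp(), priority=TransformationPriority.DEFAULT_PRIORITY)
)
PTModelTransformer(model).transform(layout)

# Recover an equivalent layout from the modified model
recovered = model.nncf.transformation_layout()
assert len(recovered.transformations) == 1
assert recovered.transformations[0].target_point == target_point
```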