diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 1a651f9a363..63637f759b0 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -60,6 +60,13 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: matching_graph_op_nodes.extend(nodes_in_module) return matching_graph_op_nodes + def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]: + for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items(): + module_scope = Scope.from_str(scope_str) + if module_scope == scope: + return nodes_in_module + return [] + def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: matches = [] for node_id, scope_str in self._node_ids_vs_layer_names.items(): diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index a27d338a77a..c06fd82b0a7 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -41,6 +41,7 @@ from nncf.common.utils.debug import is_debug from nncf.torch.debug import CombinedDebugInterface from nncf.torch.debug import debuggable_forward +from nncf.torch.dynamic_graph.context import PreHookId from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator @@ -60,6 +61,7 @@ from nncf.torch.dynamic_graph.wrappers import wrap_module_call from nncf.torch.dynamic_graph.wrappers import wrap_parameters from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME +from nncf.torch.external_hook import ExternalOpCallHook from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.graph_builder import GraphBuilder from nncf.torch.graph.graph_builder import GraphConverter @@ -67,9 +69,13 @@ from nncf.torch.graph.operator_metatypes import PTSplitMetatype from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler from nncf.torch.layer_utils import _NNCFModuleMixin +from nncf.torch.module_operations import UpdateWeight from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from nncf.torch.utils import compute_FLOPs_hook @@ -778,6 +784,116 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable) result.append(scope_in_model) return result + def get_applied_transformation_layout(self) -> PTTransformationLayout: + """ + Collects all hooks applied to the NNCFNetwork, converts them to insertion commands + and returns in PTTransformationLayout format. Default hooks group name is used in + recovered commands, so hooks group names specified diring the model modification + become outdated. + + :return: Transformation layout with all commands applied to the NNCFNetwork. + """ + + def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id + ) + return PTInsertionCommand(point=target_point, fn=module, priority=priority) + + def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): + assert hasattr( + self, hook._storage_name + ), f"Storage name {hook._storage_name} is not registered. Info: {info}" + assert hook._storage_key in getattr( + self, hook._storage_name + ), f"Storage key {hook._storage_key} is not registered. Info: {info}" + + context_hooks = defaultdict(lambda: defaultdict(list)) + transformation_layout = PTTransformationLayout() + nncf_graph = self.get_graph() + nncf_node_names_map = self.get_op_address_to_op_name_map() + + # Collect pre/post layer and op with weights insertion commands + for nncf_module, module_scope in self.get_nncf_modules().items(): + for ops, target_type in ( + (nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION), + (nncf_module.post_ops, TargetType.POST_LAYER_OPERATION), + ): + for priority, module in enumerate(ops.values()): + nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope) + assert len(nodes_in_scope) == 1 + nncf_node = nodes_in_scope[0] + if isinstance(module, UpdateWeight): + target_type = TargetType.OPERATION_WITH_WEIGHTS + module = module.op + if not isinstance(module, ExternalOpCallHook): + command = _create_pt_insert_command(module, target_type, nncf_node.node_name, priority, None) + transformation_layout.register(command) + continue + + info = f"TargetType: {target_type}, nncf node name: {nncf_node.node_name}," + f" priority: {priority}, fn: {module}" + _check_external_call_hook_is_valid(module, info) + + context_hooks[module._storage_name][module._storage_key].append( + (target_type, nncf_node.node_name, priority, module, None) + ) + + # Collect all pre/post hooks commands + for ops, target_type in ( + (self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK), + (self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK), + ): + for op_address, hooks in ops.items(): + if isinstance(op_address, PreHookId): + input_port_id = op_address.input_port_id + op_address = op_address.op_address + else: + input_port_id = None + for priority, fn in enumerate(hooks.values()): + target_node_names = nncf_node_names_map[op_address] + assert len(target_node_names) == 1 + target_node_name = target_node_names[0] + + if not isinstance(fn, ExternalOpCallHook): + command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id) + transformation_layout.register(command) + continue + + info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}" + _check_external_call_hook_is_valid(fn, info) + + context_hooks[fn._storage_name][fn._storage_key].append( + (target_type, target_node_name, priority, fn, input_port_id) + ) + + # Create shared fn insertion commands according to external hooks collected from + # pre/post layer, pre/post hooks and op with weights target points. + for module_type_name, storage in context_hooks.items(): + for storage_key, call_hook_list_info in storage.items(): + compression_module = getattr(self, module_type_name)[storage_key] + target_points = [] + for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info: + target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id)) + + if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME: + module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER + elif module_type_name == EXTERNAL_OP_STORAGE_NAME: + module_type = ExtraCompressionModuleType.EXTERNAL_OP + else: + raise RuntimeError(f"Module type {module_type_name} is not supported") + + command = PTSharedFnInsertionCommand( + target_points=target_points, + fn=compression_module, + op_unique_name=storage_key, + compression_module_type=module_type, + priority=priority, + ) + transformation_layout.register(command) + + return transformation_layout + def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]: """ Returns map of NNCFGraph node names vs DynamicGraph operation addresses. @@ -796,6 +912,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress] retval[nncf_node.node_name] = op_address return retval + def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]: + """ + Returns map of DynamicGraph operation addresses vs NNCFGraph node names. + + :return: DynamicGraph operation addresses vs NNCFGraph node names. + """ + retval = defaultdict(list) + for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items(): + retval[op_address].append(nncf_node_name) + return retval + def set_compression_controller(self, ctrl: CompressionAlgorithmController): self.compression_controller = ctrl diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index c4da4be8c82..f282c95887a 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -27,6 +27,7 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority from nncf.common.hook_handle import HookHandle from nncf.torch import register_module from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo @@ -40,9 +41,14 @@ from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import NNCFConv2d from nncf.torch.model_creation import wrap_model +from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint @@ -1024,3 +1030,176 @@ def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node del ref_hooks[-2] _check(ref_hooks) + + +class DummyOpWithState(torch.nn.Module): + def __init__(self, state: str): + super().__init__() + self._state = state + + def __call__(self, *args): + if len(args) == 1: + return args[0] + # To work correctly with + # TargetType.PRE_LAYER_OPERATION + # TargetType.POST_LAYER_OPERATION + return None + + def get_state(self): + return self._state.copy() + + @classmethod + def from_state(cls, state: str): + return cls(state) + + +TWO_CONV_MODEL_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", +] + + +def _create_pt_insertion_command( + target_type: TargetType, priority: TransformationPriority, group: str = "default_group" +): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=TWO_CONV_MODEL_NODES_NAMES[0], input_port_id=0 + ) + fn = DummyOpWithState("DUMMY_STATE") + return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) + + +def _create_pt_shared_fn_insertion_command( + target_type: TargetType, + priority: TransformationPriority, + compression_module_type: ExtraCompressionModuleType, + group: str = "default_group", + op_unique_name: str = "UNIQUE_NAME", +): + target_points = [] + + for node_name in TWO_CONV_MODEL_NODES_NAMES: + target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + fn = DummyOpWithState("DUMMY_STATE") + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + compression_module_type=compression_module_type, + op_unique_name=op_unique_name, + priority=priority, + hooks_group_name=group, + ) + + +@pytest.mark.parametrize( + "target_type", + ( + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ), +) +@pytest.mark.parametrize( + "command_builder,command_type", + ( + (_create_pt_insertion_command, PTInsertionCommand), + ( + functools.partial( + _create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP + ), + PTSharedFnInsertionCommand, + ), + ( + functools.partial( + _create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + PTSharedFnInsertionCommand, + ), + ), +) +class TestGetAppliedModificationCommands: + def test_get_applied_modification_commands(self, command_builder, target_type, command_type): + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + layout.register(command) + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + + assert len(applied_commands.transformations) == 1 + applied_command = applied_commands.transformations[0] + self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command) + + def test_priority_of_get_applied_modification_commands(self, command_builder, target_type, command_type): + layout = PTTransformationLayout() + commands = dict() + for priority in (0, 3, 2, 4, 1): + if command_type is PTSharedFnInsertionCommand: + command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") + else: + command = command_builder(target_type, priority) + layout.register(command) + commands[priority] = command + else: + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands) + for applied_command in applied_commands.transformations: + command = commands[applied_command.priority] + self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command) + + @staticmethod + def _target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): + if tp_original != tp_recovered: + return False + if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: + return tp_original.input_port_id == tp_recovered.input_port_id + return True + + @staticmethod + def _check_commands_are_equal_except_priority_and_hooks_group(command, applied_command): + assert type(applied_command) is type(command) + # Check reference to functions are equal. + # Important for the priority check + assert applied_command.fn is command.fn + ### TODO: map hooks group name + # assert applied_command.hooks_group_name == command.hooks_group_name + + if isinstance(applied_command, PTInsertionCommand): + assert TestGetAppliedModificationCommands._target_points_are_equal( + command.target_point, applied_command.target_point + ) + elif isinstance(applied_command, PTSharedFnInsertionCommand): + all( + TestGetAppliedModificationCommands._target_points_are_equal(a, b) + for a, b in zip(command.target_points, applied_command.target_points) + ) + assert applied_command.target_points == command.target_points + assert applied_command.op_name == command.op_name + assert applied_command.compression_module_type == command.compression_module_type + else: + raise RuntimeError()