[Torch] NNCFNetwork.get_applied_transformation_layout
daniil-lyakhov committed Apr 2, 2024
1 parent f7a5660 commit 297c1d2
Showing 3 changed files with 313 additions and 0 deletions.
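
In short, the new NNCFNetwork.nncf.get_applied_transformation_layout() recovers an equivalent PTTransformationLayout from the hooks currently applied to a model. Below is a minimal round-trip sketch following the tests added in this commit (the torch.nn.Identity hook is an illustrative stand-in; the node name and input shape are taken from the test fixtures, and TwoConvTestModel lives in the repo's test helpers):

from copy import deepcopy

import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.torch.dynamic_graph.io_handling import FillerInputElement
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from tests.torch.helpers import TwoConvTestModel

model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])]))

# Apply a transformation: insert an (illustrative) identity hook after the first conv.
layout = PTTransformationLayout()
layout.register(
    PTInsertionCommand(
        point=PTTargetPoint(
            target_type=TargetType.OPERATOR_POST_HOOK,
            target_node_name="TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0",
        ),
        fn=torch.nn.Identity(),
        priority=TransformationPriority.DEFAULT_PRIORITY,
    )
)
PTModelTransformer(nncf_model).transform(layout)

# Recover an equivalent layout from the hooks now present on the model.
recovered_layout = nncf_model.nncf.get_applied_transformation_layout()
assert len(recovered_layout.transformations) == 1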
7 changes: 7 additions & 0 deletions nncf/torch/graph/graph.py
@@ -60,6 +60,13 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
matching_graph_op_nodes.extend(nodes_in_module)
return matching_graph_op_nodes

def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]:
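        """
        Returns the operation nodes of the module whose scope exactly matches the given scope.

        :param scope: Scope of the target module.
        :return: Operation nodes of the matching module, or an empty list if none matches.
        """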
for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items():
module_scope = Scope.from_str(scope_str)
if module_scope == scope:
return nodes_in_module
return []

def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
matches = []
for node_id, scope_str in self._node_ids_vs_layer_names.items():
127 changes: 127 additions & 0 deletions nncf/torch/nncf_network.py
@@ -41,6 +41,7 @@
from nncf.common.utils.debug import is_debug
from nncf.torch.debug import CombinedDebugInterface
from nncf.torch.debug import debuggable_forward
from nncf.torch.dynamic_graph.context import PreHookId
from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.graph import DynamicGraph
from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator
@@ -60,16 +61,21 @@
from nncf.torch.dynamic_graph.wrappers import wrap_module_call
from nncf.torch.dynamic_graph.wrappers import wrap_parameters
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph_builder import GraphBuilder
from nncf.torch.graph.graph_builder import GraphConverter
from nncf.torch.graph.operator_metatypes import OPERATORS_WITH_WEIGHTS_METATYPES
from nncf.torch.graph.operator_metatypes import PTSplitMetatype
from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
from nncf.torch.utils import compute_FLOPs_hook
@@ -778,6 +784,116 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable)
result.append(scope_in_model)
return result

def get_applied_transformation_layout(self) -> PTTransformationLayout:
"""
Collects all hooks applied to the NNCFNetwork, converts them to insertion commands
and returns in PTTransformationLayout format. Default hooks group name is used in
recovered commands, so hooks group names specified diring the model modification
become outdated.
:return: Transformation layout with all commands applied to the NNCFNetwork.
"""

def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id):
target_point = PTTargetPoint(
target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id
)
return PTInsertionCommand(point=target_point, fn=module, priority=priority)

def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
assert hasattr(
self, hook._storage_name
), f"Storage name {hook._storage_name} is not registered. Info: {info}"
assert hook._storage_key in getattr(
self, hook._storage_name
), f"Storage key {hook._storage_key} is not registered. Info: {info}"

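        # Nested mapping: storage module name -> storage key -> list of
        # (target_type, target_node_name, priority, fn, input_port_id) tuples.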
context_hooks = defaultdict(lambda: defaultdict(list))
transformation_layout = PTTransformationLayout()
nncf_graph = self.get_graph()
nncf_node_names_map = self.get_op_address_to_op_name_map()

# Collect pre/post layer and op with weights insertion commands
for nncf_module, module_scope in self.get_nncf_modules().items():
for ops, target_type in (
(nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION),
(nncf_module.post_ops, TargetType.POST_LAYER_OPERATION),
):
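                # Hook containers preserve execution order; the enumeration
                # index recovers the command priority.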
for priority, module in enumerate(ops.values()):
nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope)
assert len(nodes_in_scope) == 1
nncf_node = nodes_in_scope[0]
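                    # UpdateWeight wraps the actual op applied to the module weight;
                    # unwrap it and retarget the command to the weight.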
                    # Rebinding the loop variable target_type here would leak into the
                    # remaining iterations, so use a local name for the command's type.
                    command_target_type = target_type
                    if isinstance(module, UpdateWeight):
                        command_target_type = TargetType.OPERATION_WITH_WEIGHTS
                        module = module.op
                    if not isinstance(module, ExternalOpCallHook):
                        command = _create_pt_insert_command(
                            module, command_target_type, nncf_node.node_name, priority, None
                        )
                        transformation_layout.register(command)
                        continue

                    info = (
                        f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name},"
                        f" priority: {priority}, fn: {module}"
                    )
                    _check_external_call_hook_is_valid(module, info)

                    context_hooks[module._storage_name][module._storage_key].append(
                        (command_target_type, nncf_node.node_name, priority, module, None)
                    )

        # Collect all pre/post hook commands
for ops, target_type in (
(self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK),
(self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK),
):
for op_address, hooks in ops.items():
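                # Pre-hook addresses additionally carry the input port the hook
                # is attached to; post-hooks target the operator output.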
if isinstance(op_address, PreHookId):
input_port_id = op_address.input_port_id
op_address = op_address.op_address
else:
input_port_id = None
for priority, fn in enumerate(hooks.values()):
target_node_names = nncf_node_names_map[op_address]
assert len(target_node_names) == 1
target_node_name = target_node_names[0]

if not isinstance(fn, ExternalOpCallHook):
command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id)
transformation_layout.register(command)
continue

info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}"
_check_external_call_hook_is_valid(fn, info)

context_hooks[fn._storage_name][fn._storage_key].append(
(target_type, target_node_name, priority, fn, input_port_id)
)

        # Create shared-fn insertion commands from the external hooks collected at
        # pre/post-layer, pre/post-hook, and operation-with-weights target points.
for module_type_name, storage in context_hooks.items():
for storage_key, call_hook_list_info in storage.items():
compression_module = getattr(self, module_type_name)[storage_key]
target_points = []
for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info:
target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id))

if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME:
module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER
elif module_type_name == EXTERNAL_OP_STORAGE_NAME:
module_type = ExtraCompressionModuleType.EXTERNAL_OP
else:
raise RuntimeError(f"Module type {module_type_name} is not supported")

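                # Each storage key refers to a single shared module; the priority
                # of the last collected entry is reused for the recovered command.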
command = PTSharedFnInsertionCommand(
target_points=target_points,
fn=compression_module,
op_unique_name=storage_key,
compression_module_type=module_type,
priority=priority,
)
transformation_layout.register(command)

return transformation_layout

def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]:
"""
Returns map of NNCFGraph node names vs DynamicGraph operation addresses.
@@ -796,6 +912,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]
retval[nncf_node.node_name] = op_address
return retval

    def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, List[NNCFNodeName]]:
        """
        Returns a map of DynamicGraph operation addresses to NNCFGraph node names.

        :return: DynamicGraph operation addresses mapped to lists of NNCFGraph node names.
        """
retval = defaultdict(list)
for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items():
retval[op_address].append(nncf_node_name)
return retval

def set_compression_controller(self, ctrl: CompressionAlgorithmController):
self.compression_controller = ctrl

179 changes: 179 additions & 0 deletions tests/torch/test_nncf_network.py
@@ -27,6 +27,7 @@
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.hook_handle import HookHandle
from nncf.torch import register_module
from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
@@ -40,9 +41,14 @@
from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.layers import NNCFConv2d
from nncf.torch.model_creation import wrap_model
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
@@ -1024,3 +1030,176 @@ def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node

del ref_hooks[-2]
_check(ref_hooks)


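# Minimal stateful op stub; get_state()/from_state() follow the state-serialization
# protocol assumed for compression modules in these tests.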
class DummyOpWithState(torch.nn.Module):
def __init__(self, state: str):
super().__init__()
self._state = state

    def __call__(self, *args):
        if len(args) == 1:
            return args[0]
        # Return None to work correctly with
        # TargetType.PRE_LAYER_OPERATION and
        # TargetType.POST_LAYER_OPERATION target points.
        return None

    def get_state(self):
        return self._state

@classmethod
def from_state(cls, state: str):
return cls(state)


TWO_CONV_MODEL_NODES_NAMES = [
"TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0",
"TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0",
]


def _create_pt_insertion_command(
target_type: TargetType, priority: TransformationPriority, group: str = "default_group"
):
target_point = PTTargetPoint(
target_type=target_type, target_node_name=TWO_CONV_MODEL_NODES_NAMES[0], input_port_id=0
)
fn = DummyOpWithState("DUMMY_STATE")
return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group)


def _create_pt_shared_fn_insertion_command(
target_type: TargetType,
priority: TransformationPriority,
compression_module_type: ExtraCompressionModuleType,
group: str = "default_group",
op_unique_name: str = "UNIQUE_NAME",
):
target_points = []

for node_name in TWO_CONV_MODEL_NODES_NAMES:
target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0))
fn = DummyOpWithState("DUMMY_STATE")
return PTSharedFnInsertionCommand(
target_points=target_points,
fn=fn,
compression_module_type=compression_module_type,
op_unique_name=op_unique_name,
priority=priority,
hooks_group_name=group,
)


@pytest.mark.parametrize(
"target_type",
(
TargetType.OPERATION_WITH_WEIGHTS,
TargetType.OPERATOR_PRE_HOOK,
TargetType.OPERATOR_POST_HOOK,
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
),
)
@pytest.mark.parametrize(
"command_builder,command_type",
(
(_create_pt_insertion_command, PTInsertionCommand),
(
functools.partial(
_create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP
),
PTSharedFnInsertionCommand,
),
(
functools.partial(
_create_pt_shared_fn_insertion_command,
compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
),
PTSharedFnInsertionCommand,
),
),
)
class TestGetAppliedModificationCommands:
def test_get_applied_modification_commands(self, command_builder, target_type, command_type):
command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY)
if isinstance(command, PTSharedFnInsertionCommand) and target_type in [
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
]:
pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}")

model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])]))
        model_transformer = PTModelTransformer(nncf_model)

layout = PTTransformationLayout()
layout.register(command)
        model_transformer.transform(layout)

applied_commands = nncf_model.nncf.get_applied_transformation_layout()

assert len(applied_commands.transformations) == 1
applied_command = applied_commands.transformations[0]
self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command)

def test_priority_of_get_applied_modification_commands(self, command_builder, target_type, command_type):
layout = PTTransformationLayout()
commands = dict()
for priority in (0, 3, 2, 4, 1):
if command_type is PTSharedFnInsertionCommand:
command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}")
else:
command = command_builder(target_type, priority)
layout.register(command)
commands[priority] = command
        if isinstance(command, PTSharedFnInsertionCommand) and target_type in [
            TargetType.PRE_LAYER_OPERATION,
            TargetType.POST_LAYER_OPERATION,
        ]:
            pytest.skip(f"PTSharedFnInsertionCommand does not support target type {target_type}")

model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])]))
        model_transformer = PTModelTransformer(nncf_model)

        model_transformer.transform(layout)

applied_commands = nncf_model.nncf.get_applied_transformation_layout()
assert len(applied_commands.transformations) == len(commands)
for applied_command in applied_commands.transformations:
command = commands[applied_command.priority]
self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command)

@staticmethod
def _target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint):
if tp_original != tp_recovered:
return False
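        # input_port_id is only meaningful for operator pre-hooks, and PTTargetPoint
        # equality may not compare it, so check it explicitly here.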
if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK:
return tp_original.input_port_id == tp_recovered.input_port_id
return True

@staticmethod
def _check_commands_are_equal_except_priority_and_hooks_group(command, applied_command):
assert type(applied_command) is type(command)
        # Check that the function references are equal.
        # This is important for the priority check.
assert applied_command.fn is command.fn
### TODO: map hooks group name
# assert applied_command.hooks_group_name == command.hooks_group_name

if isinstance(applied_command, PTInsertionCommand):
assert TestGetAppliedModificationCommands._target_points_are_equal(
command.target_point, applied_command.target_point
)
elif isinstance(applied_command, PTSharedFnInsertionCommand):
            assert all(
                TestGetAppliedModificationCommands._target_points_are_equal(a, b)
                for a, b in zip(command.target_points, applied_command.target_points)
            )
assert applied_command.target_points == command.target_points
assert applied_command.op_name == command.op_name
assert applied_command.compression_module_type == command.compression_module_type
else:
            raise RuntimeError(f"Unknown command type: {type(applied_command)}")
