Skip to content

Commit

Permalink
[Torch] NNCFNetwork.get_applied_transformation_layout
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Mar 22, 2024
1 parent d3c03b0 commit ada7683
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 0 deletions.
7 changes: 7 additions & 0 deletions nncf/torch/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
matching_graph_op_nodes.extend(nodes_in_module)
return matching_graph_op_nodes

def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]:
for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items():
module_scope = Scope.from_str(scope_str)
if module_scope == scope:
return nodes_in_module
return None

def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
matches = []
for node_id, scope_str in self._node_ids_vs_layer_names.items():
Expand Down
125 changes: 125 additions & 0 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from nncf.common.utils.debug import is_debug
from nncf.torch.debug import CombinedDebugInterface
from nncf.torch.debug import debuggable_forward
from nncf.torch.dynamic_graph.context import PreHookId
from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.graph import DynamicGraph
from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator
Expand All @@ -60,16 +61,21 @@
from nncf.torch.dynamic_graph.wrappers import wrap_module_call
from nncf.torch.dynamic_graph.wrappers import wrap_parameters
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph_builder import GraphBuilder
from nncf.torch.graph.graph_builder import GraphConverter
from nncf.torch.graph.operator_metatypes import OPERATORS_WITH_WEIGHTS_METATYPES
from nncf.torch.graph.operator_metatypes import PTSplitMetatype
from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
from nncf.torch.utils import compute_FLOPs_hook
Expand Down Expand Up @@ -778,6 +784,114 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable)
result.append(scope_in_model)
return result

def get_applied_transformation_layout(self) -> PTTransformationLayout:
"""
Collects all hooks applied to the NNCFNetwork, converts them to insertion commands
and returns in PTTransformationLayout format.
:return: Transformation layout with all commands applied to the NNCFNetwork.
"""

def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id):
target_point = PTTargetPoint(
target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id
)
return PTInsertionCommand(point=target_point, fn=module, priority=priority)

def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
assert hasattr(
self, hook._storage_name
), f"Storage name {hook._storage_name} is not registered. Info: {info}"
assert hook._storage_key in getattr(
self, hook._storage_name
), f"Storage key {hook._storage_key} is not registered. Info: {info}"

context_hooks = defaultdict(lambda: defaultdict(list))
transformation_layout = PTTransformationLayout
nncf_graph = self.get_graph()
nncf_node_names_map = self.get_op_address_to_op_name_map()

# Collect pre/post layer and op with weights insertion commands
for nncf_module, module_scope in self.get_nncf_modules().items():
for ops, target_type in (
(nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION),
(nncf_module.post_ops, TargetType.POST_LAYER_OPERATION),
):
for priority, module in enumerate(ops.values()):
nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope)
assert len(nodes_in_scope) == 1
nncf_node = nodes_in_scope[0]
if isinstance(module, UpdateWeight):
target_type = TargetType.OPERATION_WITH_WEIGHTS
module = module.op
if not isinstance(module, ExternalOpCallHook):
command = _create_pt_insert_command(module, target_type, nncf_node.node_name, priority, None)
transformation_layout.register(command)
continue

info = f"TargetType: {target_type}, nncf node name: {nncf_node.node_name},"
f" priority: {priority}, fn: {module}"
_check_external_call_hook_is_valid(module, info)

context_hooks[module._storage_name][module._storage_key].append(
(target_type, nncf_node.node_name, priority, module, None)
)

# Collect all pre/post hooks commands
for ops, target_type in (
(self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK),
(self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK),
):
for op_address, hooks in ops.items():
if isinstance(op_address, PreHookId):
input_port_id = op_address.input_port_id
op_address = op_address.op_address
else:
input_port_id = None
for priority, fn in enumerate(hooks.values()):
target_node_names = nncf_node_names_map[op_address]
assert len(target_node_names) == 1
target_node_name = target_node_names[0]

if not isinstance(fn, ExternalOpCallHook):
command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id)
transformation_layout.register(command)
continue

info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}"
_check_external_call_hook_is_valid(fn, info)

context_hooks[fn._storage_name][fn._storage_key].append(
(target_type, target_node_name, priority, fn, input_port_id)
)

# Create shared fn insertion commands according to external hooks collected from
# pre/post layer, pre/post hooks and op with weights target points.
for module_type_name, storage in context_hooks.items():
for storage_key, call_hook_list_info in storage.items():
compression_module = getattr(self, module_type_name)[storage_key]
target_points = []
for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info:
target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id))

if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME:
module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER
elif module_type_name == EXTERNAL_OP_STORAGE_NAME:
module_type = ExtraCompressionModuleType.EXTERNAL_OP
else:
raise RuntimeError(f"Module type {module_type_name} is not supported")

command = PTSharedFnInsertionCommand(
target_points=target_points,
fn=compression_module,
op_unique_name=storage_key,
compression_module_type=module_type,
priority=priority,
)
transformation_layout.register(command)

return transformation_layout

def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]:
"""
Returns map of NNCFGraph node names vs DynamicGraph operation addresses.
Expand All @@ -796,6 +910,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]
retval[nncf_node.node_name] = op_address
return retval

def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]:
"""
Returns map of DynamicGraph operation addresses vs NNCFGraph node names.
:return: DynamicGraph operation addresses vs NNCFGraph node names.
"""
retval = defaultdict(list)
for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items():
retval[op_address].append(nncf_node_name)
return retval

def set_compression_controller(self, ctrl: CompressionAlgorithmController):
self.compression_controller = ctrl

Expand Down
179 changes: 179 additions & 0 deletions tests/torch/test_nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.hook_handle import HookHandle
from nncf.torch import register_module
from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
Expand All @@ -40,9 +41,14 @@
from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.layers import NNCFConv2d
from nncf.torch.model_creation import wrap_model
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
Expand Down Expand Up @@ -1024,3 +1030,176 @@ def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node

del ref_hooks[-2]
_check(ref_hooks)


class DummyOpWithState(torch.nn.Module):
def __init__(self, state: str):
super().__init__()
self._state = state

def __call__(self, *args):
if len(args) == 1:
return args[0]
# To work correctly with
# TargetType.PRE_LAYER_OPERATION
# TargetType.POST_LAYER_OPERATION
return None

def get_state(self):
return self._state.copy()

@classmethod
def from_state(cls, state: str):
return cls(state)


TWO_CONV_MODEL_NODES_NAMES = [
"TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0",
"TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0",
]


def _create_pt_insertion_command(
target_type: TargetType, priority: TransformationPriority, group: str = "default_group"
):
target_point = PTTargetPoint(
target_type=target_type, target_node_name=TWO_CONV_MODEL_NODES_NAMES[0], input_port_id=0
)
fn = DummyOpWithState("DUMMY_STATE")
return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group)


def _create_pt_shared_fn_insertion_command(
target_type: TargetType,
priority: TransformationPriority,
compression_module_type: ExtraCompressionModuleType,
group: str = "default_group",
op_unique_name: str = "UNIQUE_NAME",
):
target_points = []

for node_name in TWO_CONV_MODEL_NODES_NAMES:
target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0))
fn = DummyOpWithState("DUMMY_STATE")
return PTSharedFnInsertionCommand(
target_points=target_points,
fn=fn,
compression_module_type=compression_module_type,
op_unique_name=op_unique_name,
priority=priority,
hooks_group_name=group,
)


@pytest.mark.parametrize(
"target_type",
(
TargetType.OPERATION_WITH_WEIGHTS,
TargetType.OPERATOR_PRE_HOOK,
TargetType.OPERATOR_POST_HOOK,
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
),
)
@pytest.mark.parametrize(
"command_builder,command_type",
(
(_create_pt_insertion_command, PTInsertionCommand),
(
functools.partial(
_create_pt_shared_fn_insertion_command, compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP
),
PTSharedFnInsertionCommand,
),
(
functools.partial(
_create_pt_shared_fn_insertion_command,
compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
),
PTSharedFnInsertionCommand,
),
),
)
class TestGetAppliedModificationCommands:
def test_get_applied_modification_commands(self, command_builder, target_type, command_type):
command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY)
if isinstance(command, PTSharedFnInsertionCommand) and target_type in [
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
]:
pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}")

model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])]))
model_tranformer = PTModelTransformer(nncf_model)

layout = PTTransformationLayout()
layout.register(command)
model_tranformer.transform(layout)

applied_commands = nncf_model.nncf.get_applied_transformation_layout()

assert len(applied_commands) == 1
applied_command = applied_commands[0]
self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command)

def test_priority_of_get_applied_modification_commands(self, command_builder, target_type, command_type):
layout = PTTransformationLayout()
commands = dict()
for priority in (0, 3, 2, 4, 1):
if command_type is PTSharedFnInsertionCommand:
command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}")
else:
command = command_builder(target_type, priority)
layout.register(command)
commands[priority] = command
else:
if isinstance(command, PTSharedFnInsertionCommand) and target_type in [
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
]:
pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}")

model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])]))
model_tranformer = PTModelTransformer(nncf_model)

model_tranformer.transform(layout)

applied_commands = nncf_model.nncf.get_applied_transformation_layout()
assert len(applied_commands) == len(commands)
for applied_command in applied_commands:
command = commands[applied_command.priority]
self._check_commands_are_equal_except_priority_and_hooks_group(command, applied_command)

@staticmethod
def _target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint):
if tp_original != tp_recovered:
return False
if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK:
return tp_original.input_port_id == tp_recovered.input_port_id
return True

@staticmethod
def _check_commands_are_equal_except_priority_and_hooks_group(command, applied_command):
assert type(applied_command) is type(command)
# Check reference to functions are equal.
# Important for the priority check
assert applied_command.fn is command.fn
### TODO: map hooks group name
# assert applied_command.hooks_group_name == command.hooks_group_name

if isinstance(applied_command, PTInsertionCommand):
assert TestGetAppliedModificationCommands._target_points_are_equal(
command.target_point, applied_command.target_point
)
elif isinstance(applied_command, PTSharedFnInsertionCommand):
all(
TestGetAppliedModificationCommands._target_points_are_equal(a, b)
for a, b in zip(command.target_points, applied_command.target_points)
)
assert applied_command.target_points == command.target_points
assert applied_command.op_name == command.op_name
assert applied_command.compression_module_type == command.compression_module_type
else:
raise RuntimeError()

0 comments on commit ada7683

Please sign in to comment.