Fix target_type detection bug

daniil-lyakhov committed Apr 2, 2024
1 parent 297c1d2 commit 0eca703
Showing 3 changed files with 249 additions and 157 deletions.
11 changes: 7 additions & 4 deletions nncf/torch/nncf_network.py
@@ -823,20 +823,23 @@ def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str):
         nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope)
         assert len(nodes_in_scope) == 1
         nncf_node = nodes_in_scope[0]
+        command_target_type = target_type
         if isinstance(module, UpdateWeight):
-            target_type = TargetType.OPERATION_WITH_WEIGHTS
+            command_target_type = TargetType.OPERATION_WITH_WEIGHTS
             module = module.op
         if not isinstance(module, ExternalOpCallHook):
-            command = _create_pt_insert_command(module, target_type, nncf_node.node_name, priority, None)
+            command = _create_pt_insert_command(
+                module, command_target_type, nncf_node.node_name, priority, None
+            )
             transformation_layout.register(command)
             continue

-        info = f"TargetType: {target_type}, nncf node name: {nncf_node.node_name},"
+        info = f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name},"
         f" priority: {priority}, fn: {module}"
         _check_external_call_hook_is_valid(module, info)

         context_hooks[module._storage_name][module._storage_key].append(
-            (target_type, nncf_node.node_name, priority, module, None)
+            (command_target_type, nncf_node.node_name, priority, module, None)
         )

    # Collect all pre/post hooks commands
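The hunk above presumably fixes the bug named in the commit title: target_type was reassigned in place whenever an UpdateWeight wrapper was encountered, clobbering the original value for whatever used it afterwards, whereas the new command_target_type keeps the override local to the command being built. Below is a minimal, self-contained sketch of the same pattern; it does not use NNCF APIs, and every name in it (StandInTargetType, collect_buggy, collect_fixed, the boolean flag list) is a hypothetical stand-in.

from enum import Enum, auto


class StandInTargetType(Enum):
    # Stand-ins for the NNCF target types that appear in the hunk above.
    PRE_LAYER_OPERATION = auto()
    OPERATION_WITH_WEIGHTS = auto()


def collect_buggy(is_update_weight_flags, target_type):
    # Buggy pattern: reassigning the shared `target_type` inside the loop
    # leaks the UpdateWeight-specific override into later iterations.
    collected = []
    for is_update_weight in is_update_weight_flags:
        if is_update_weight:
            target_type = StandInTargetType.OPERATION_WITH_WEIGHTS
        collected.append(target_type)
    return collected


def collect_fixed(is_update_weight_flags, target_type):
    # Fixed pattern: a per-command copy leaves the original `target_type` intact.
    collected = []
    for is_update_weight in is_update_weight_flags:
        command_target_type = target_type
        if is_update_weight:
            command_target_type = StandInTargetType.OPERATION_WITH_WEIGHTS
        collected.append(command_target_type)
    return collected


flags = [True, False]  # an UpdateWeight wrapper followed by a plain hook
print(collect_buggy(flags, StandInTargetType.PRE_LAYER_OPERATION))
# Second entry wrongly carries OPERATION_WITH_WEIGHTS over from the first iteration.
print(collect_fixed(flags, StandInTargetType.PRE_LAYER_OPERATION))
# Second entry keeps the original PRE_LAYER_OPERATION.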
183 changes: 183 additions & 0 deletions tests/torch/helpers.py
@@ -8,7 +8,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import itertools
import numbers
from abc import ABC
from abc import abstractmethod
@@ -29,6 +32,8 @@

import nncf
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.config import NNCFConfig
from nncf.config.extractors import extract_algorithm_names
from nncf.config.structures import BNAdaptationInitArgs
@@ -38,8 +43,13 @@
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.dynamic_graph.operation_address import OperationAddress
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.initialization import register_default_init_args
from nncf.torch.layer_utils import COMPRESSION_MODULES
from nncf.torch.layers import NNCF_MODULES_MAP
from nncf.torch.model_creation import create_compressed_model
from nncf.torch.module_operations import UpdateWeight
@@ -172,6 +182,12 @@ def nz_bias_num(self):


class TwoConvTestModel(nn.Module):
    INPUT_SHAPE = [1, 1, 4, 4]
    NNCF_CONV_NODES_NAMES = [
        "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0",
        "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0",
    ]

    def __init__(self):
        super().__init__()
        self.features = []
@@ -198,6 +214,113 @@ def nz_weights_num(self):
    def nz_bias_num(self):
        return 2

    @staticmethod
    def create_pt_insertion_command(
        target_type: TargetType, priority: TransformationPriority, fn=None, group: str = "default_group"
    ):
        target_point = PTTargetPoint(
            target_type=target_type, target_node_name=TwoConvTestModel.NNCF_CONV_NODES_NAMES[0], input_port_id=0
        )
        if fn is None:
            fn = DummyOpWithState("DUMMY_STATE")
        return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group)

    @staticmethod
    def create_pt_shared_fn_insertion_command(
        target_type: TargetType,
        priority: TransformationPriority,
        compression_module_type: ExtraCompressionModuleType,
        fn=None,
        group: str = "default_group",
        op_unique_name: str = "UNIQUE_NAME",
    ):
        target_points = []

        for node_name in TwoConvTestModel.NNCF_CONV_NODES_NAMES:
            target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0))
        if fn is None:
            fn = DummyOpWithState("DUMMY_STATE")
        return PTSharedFnInsertionCommand(
            target_points=target_points,
            fn=fn,
            compression_module_type=compression_module_type,
            op_unique_name=op_unique_name,
            priority=priority,
            hooks_group_name=group,
        )

    AVAILABLE_TARGET_TYPES = (
        TargetType.OPERATION_WITH_WEIGHTS,
        TargetType.OPERATOR_PRE_HOOK,
        TargetType.OPERATOR_POST_HOOK,
        TargetType.PRE_LAYER_OPERATION,
        TargetType.POST_LAYER_OPERATION,
    )

    @staticmethod
    def get_command_builders():
        return (
            TwoConvTestModel.create_pt_insertion_command,
            functools.partial(
                TwoConvTestModel.create_pt_shared_fn_insertion_command,
                compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
            ),
            functools.partial(
                TwoConvTestModel.create_pt_shared_fn_insertion_command,
                compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
            ),
        )

    COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand]
    PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1)

    @classmethod
    def get_all_available_commands(
        cls, dummy_op_state, skip_model_transformer_unsupported=False
    ) -> TransformationLayout:
        """
        Returns all possible commands to insert:
        all target types x all command class x all compression module types x different priorities.
        """
        layout = TransformationLayout()
        for idx, (target_type, (command_builder, command_type), priority) in enumerate(
            itertools.product(
                cls.AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES
            )
        ):
            if command_type is PTSharedFnInsertionCommand:
                if skip_model_transformer_unsupported and target_type in [
                    TargetType.PRE_LAYER_OPERATION,
                    TargetType.POST_LAYER_OPERATION,
                ]:
                    continue
                command = cls._create_command(
                    command_builder, target_type, priority, dummy_op_state, op_unique_name=f"UNIQUE_NAME_{idx}"
                )
            else:
                command = cls._create_command(command_builder, target_type, priority, dummy_op_state)

            layout.register(command)
        return layout

    @staticmethod
    def _create_command(command_builder, target_type, priority, dummy_op_state, op_unique_name=None):
        group_name = "CUSTOM_HOOKS_GROUP_NAME"

        if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict:
            registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState)
        else:
            registered_dummy_op_cls = DummyOpWithState
        dummy_op = registered_dummy_op_cls(dummy_op_state)
        if op_unique_name is None:
            command = command_builder(target_type, priority, fn=dummy_op, group=group_name)
        else:
            command = command_builder(
                target_type, priority, fn=dummy_op, group=group_name, op_unique_name=op_unique_name
            )

        return command
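
A hypothetical usage sketch (not part of this commit) for the helper added above: it builds every available command via TwoConvTestModel.get_all_available_commands and counts them by command class, assuming TransformationLayout exposes its registered commands through the transformations property.

from collections import Counter

layout = TwoConvTestModel.get_all_available_commands(
    dummy_op_state="DUMMY_STATE", skip_model_transformer_unsupported=True
)
# Count how many PTInsertionCommand / PTSharedFnInsertionCommand instances were built.
print(Counter(type(command).__name__ for command in layout.transformations))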


class LeNet(nn.Module):
    INPUT_SIZE = 1, 32, 32
@@ -228,6 +351,66 @@ def num_flat_features(self, x):
        return num_features


class DummyOpWithState(torch.nn.Module):
    def __init__(self, state: str):
        super().__init__()
        self._state = state

    def __call__(self, *args):
        if len(args) == 1:
            return args[0]
        # To work correctly with
        # TargetType.PRE_LAYER_OPERATION
        # TargetType.POST_LAYER_OPERATION
        return None

    def get_state(self):
        return self._state

    @classmethod
    def from_state(cls, state: str):
        return cls(state)


def target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint):
    if tp_original != tp_recovered:
        return False
    if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK:
        return tp_original.input_port_id == tp_recovered.input_port_id
    return True


def are_commands_equal(
    command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True
):
    if type(applied_command) is not type(command):
        return False

    # Check reference to functions are equal.
    if check_fn_ref and applied_command.fn is not command.fn:
        return False
    if check_hooks_group_name and applied_command.hooks_group_name != command.hooks_group_name:
        return False
    if check_priority and applied_command.priority != command.priority:
        return False

    if isinstance(applied_command, PTInsertionCommand):
        if not target_points_are_equal(command.target_point, applied_command.target_point):
            return False
    elif isinstance(applied_command, PTSharedFnInsertionCommand):
        if not all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)):
            return False
        if (
            applied_command.target_points != command.target_points
            or applied_command.op_name != command.op_name
            or applied_command.compression_module_type != command.compression_module_type
        ):
            return False
    else:
        raise RuntimeError()
    return True
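
A hypothetical sketch (not part of this commit) of how are_commands_equal could be exercised with the builders above, assuming PTTargetPoint compares by value: two commands created with identical parameters hold distinct DummyOpWithState instances, so they compare equal only once the function-reference check is disabled.

original = TwoConvTestModel.create_pt_insertion_command(
    TargetType.OPERATOR_POST_HOOK, TransformationPriority.QUANTIZATION_PRIORITY
)
recovered = TwoConvTestModel.create_pt_insertion_command(
    TargetType.OPERATOR_POST_HOOK, TransformationPriority.QUANTIZATION_PRIORITY
)
# Each call builds its own DummyOpWithState, so the `fn` references differ.
assert not are_commands_equal(original, recovered)
assert are_commands_equal(original, recovered, check_fn_ref=False)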


class SharedConv(nn.Module):
    INPUT_SIZE = [1, 1, 4, 4]

