test_shared_fn_insertion_command_several_module_types
daniil-lyakhov committed Mar 20, 2024
1 parent 317f263 commit d3c03b0
Showing 1 changed file with 98 additions and 15 deletions.
113 changes: 98 additions & 15 deletions tests/torch/test_model_transformer.py
@@ -679,6 +679,23 @@ def test_quantizer_insertion_transformations(
    assert command.compression_module_type is ExtraCompressionModuleType.EXTERNAL_QUANTIZER


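# Target points reused by the shared-fn insertion tests below: a post-hook on the model input,
# a pre-hook on input port 0 of the linear op, and a weights hook on the first convolution.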
SHARED_FN_TARGET_POINTS = [
    PTTargetPoint(
        TargetType.OPERATOR_POST_HOOK,
        "/nncf_model_input_0",
    ),
    PTTargetPoint(
        TargetType.OPERATOR_PRE_HOOK,
        "InsertionPointTestModel/linear_0",
        input_port_id=0,
    ),
    PTTargetPoint(
        TargetType.OPERATION_WITH_WEIGHTS,
        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
    ),
]


@pytest.mark.parametrize("compression_module_type", ExtraCompressionModuleType)
@pytest.mark.parametrize(
"priority", [TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, TransformationPriority.DEFAULT_PRIORITY]
@@ -691,21 +708,7 @@ def test_shared_fn_insertion_point(
    if not torch.cuda.is_available() and multidevice_model:
        pytest.skip("Could not test multidevice case without cuda")

    tps = [
        PTTargetPoint(
            TargetType.OPERATOR_POST_HOOK,
            "/nncf_model_input_0",
        ),
        PTTargetPoint(
            TargetType.OPERATOR_PRE_HOOK,
            "InsertionPointTestModel/linear_0",
            input_port_id=0,
        ),
        PTTargetPoint(
            TargetType.OPERATION_WITH_WEIGHTS,
            "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
        ),
    ]
    tps = SHARED_FN_TARGET_POINTS
    OP_UNIQUE_NAME = "UNIQUE_NAME"
    HOOK_GROUP_NAME = "shared_commands_hooks_group"
    STORAGE_NAME = compression_module_type_to_attr_name(compression_module_type)
@@ -762,6 +765,7 @@ def _insert_external_op_mocked():
    for command in commands:
        assert command.target_point in tps
        assert command.hooks_group_name == HOOK_GROUP_NAME
        assert command.priority == priority
        fn = command.fn
        assert isinstance(fn, ExternalOpCallHook)
        assert fn._storage_name == STORAGE_NAME
@@ -783,6 +787,85 @@ def _insert_external_op_mocked():
    transformed_model.load_state_dict(state_dict)


@pytest.mark.parametrize(
"priority", [TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, TransformationPriority.DEFAULT_PRIORITY]
)
@pytest.mark.parametrize("compression_module_registered", [False, True])
@pytest.mark.parametrize("multidevice_model", (False, True))
def test_shared_fn_insertion_command_several_module_types(
    priority, compression_module_registered, multidevice_model, mocker
):
    if not torch.cuda.is_available() and multidevice_model:
        pytest.skip("Could not test multidevice case without cuda")

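    # A single shared hook instance is registered under every available compression module type.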
    tps = SHARED_FN_TARGET_POINTS
    OP_UNIQUE_NAME = "UNIQUE_NAME"
    HOOK_GROUP_NAME = "shared_commands_hooks_group"
    MODULE_TYPES = [t for t in ExtraCompressionModuleType]
    hook_instance = Hook()

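    # Builds the test model, registers one shared-fn insertion command per compression module
    # type, and runs the model transformer with the insertion helper mocked out.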
    def _insert_external_op_mocked():
        model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
        model = model.cpu()
        if multidevice_model:
            model.conv1.to(torch.device("cpu"))
            model.conv2.to(torch.device("cuda"))

        transformation_layout = PTTransformationLayout()
        for compression_module_type in MODULE_TYPES:
            if compression_module_registered:
                model.nncf.register_compression_module_type(compression_module_type)
            unique_name = f"{OP_UNIQUE_NAME}[{';'.join([tp.target_node_name for tp in tps])}]"
            command = PTSharedFnInsertionCommand(
                target_points=tps,
                fn=hook_instance,
                op_unique_name=unique_name,
                compression_module_type=compression_module_type,
                priority=priority,
                hooks_group_name=HOOK_GROUP_NAME,
            )
            transformation_layout.register(command)

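        # Mock the shared node insertion helper so its call arguments can be inspected below.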
        mocker.patch(
            "nncf.torch.model_transformer.PTModelTransformer._apply_shared_node_insertion_with_compression_type",
            return_value=mocker.MagicMock(),
        )
        model_transformer = PTModelTransformer(model)
        model_transformer.transform(transformation_layout=transformation_layout)
        return model

    transformed_model = _insert_external_op_mocked()

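    # The insertion helper is expected to be called exactly once per compression module type.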
    mock = PTModelTransformer._apply_shared_node_insertion_with_compression_type
    assert len(mock.call_args_list) == len(MODULE_TYPES)

    REF_STORAGE_KEY = (
        "UNIQUE_NAME[/nncf_model_input_0;InsertionPointTestModel/linear_0;"
        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0]"
    )

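    # Each call must forward a single command that keeps the shared hook, target points,
    # priority and hooks group for its compression module type.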
    module_types_set = set(MODULE_TYPES)
    for (_, commands, device, compression_module_type), _ in mock.call_args_list:
        module_types_set -= set((compression_module_type,))
        assert len(commands) == 1
        command = commands[0]
        assert isinstance(command, PTSharedFnInsertionCommand)
        assert command.fn is hook_instance
        assert command.target_points is tps
        assert command.compression_module_type == compression_module_type
        assert command.op_name == REF_STORAGE_KEY
        assert command.priority == priority
        assert command.hooks_group_name == HOOK_GROUP_NAME

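        # For a multidevice model no single device is passed to the helper.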
        if multidevice_model:
            assert device is None
        else:
            assert device == get_model_device(transformed_model)

    assert not module_types_set


INSERTION_POINT_TEST_MODEL_TARGET_POINTS = (
    (
        TargetType.OPERATOR_POST_HOOK,
