diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 11dd747ce2e..34a3e5a293f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.1.3 + rev: v0.3.7 hooks: - id: ruff diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/main.py b/examples/post_training_quantization/torch/ssd300_vgg16/main.py index fc359e750b3..1b586f4a995 100644 --- a/examples/post_training_quantization/torch/ssd300_vgg16/main.py +++ b/examples/post_training_quantization/torch/ssd300_vgg16/main.py @@ -70,9 +70,9 @@ def run_benchmark(model_path: str, shape=None, verbose: bool = True) -> float: class COCO128Dataset(torch.utils.data.Dataset): category_mapping = [ - 1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33, - 34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60, - 61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90 ] # fmt: skip def __init__(self, data_path: str, transform: Callable): diff --git a/examples/torch/common/models/classification/mobilenet_v3_tv_092.py b/examples/torch/common/models/classification/mobilenet_v3_tv_092.py index 5ce32170bd2..1ba5f4d5cb3 100644 --- a/examples/torch/common/models/classification/mobilenet_v3_tv_092.py +++ b/examples/torch/common/models/classification/mobilenet_v3_tv_092.py @@ -305,7 +305,7 @@ def _mobilenet_v3_model( ): model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) if pretrained: - if model_urls.get(arch, None) is None: + if model_urls.get(arch) is None: raise ValueError("No checkpoint is available for model type {}".format(arch)) state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py index b8acc2bc809..c23131f7c0a 100644 --- a/nncf/common/sparsity/schedulers.py +++ b/nncf/common/sparsity/schedulers.py @@ -133,7 +133,7 @@ def __init__(self, controller: SparsityController, params: Dict[str, Any]): self._update_per_optimizer_step = params.get( "update_per_optimizer_step", SPARSITY_SCHEDULER_UPDATE_PER_OPTIMIZER_STEP ) - self._steps_per_epoch = params.get("steps_per_epoch", None) + self._steps_per_epoch = params.get("steps_per_epoch") self._should_skip = False def step(self, next_step: Optional[int] = None) -> None: diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 1017877bf24..700cabb63a3 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -461,7 +461,9 @@ def __init__(self, tensor_collectors: List[TensorCollector]) -> None: self._aggregators[key] = unique_aggregator -##################################################Reducers################################################## +################################################## +# Reducers +################################################## class NoopReducer(TensorReducerBase): @@ -578,7 +580,9 @@ def __hash__(self) -> int: return hash((self.__class__.__name__, self.inplace, self._reduction_axes, self._channel_axis)) -##################################################Aggregators################################################## +################################################## +# Aggregators +################################################## class NoopAggregator(AggregatorBase): diff --git a/nncf/experimental/torch/nas/bootstrapNAS/search/search.py b/nncf/experimental/torch/nas/bootstrapNAS/search/search.py index d5ca9cd55e1..23800b019b5 100644 --- a/nncf/experimental/torch/nas/bootstrapNAS/search/search.py +++ b/nncf/experimental/torch/nas/bootstrapNAS/search/search.py @@ -710,9 +710,8 @@ def _evaluate(self, x: List[float], out: Dict[str, Any], *args, **kargs) -> NoRe result = [sample] - eval_idx = 0 bn_adaption_executed = False - for evaluator_handler in self._evaluator_handlers: + for eval_idx, evaluator_handler in enumerate(self._evaluator_handlers): in_cache, value = evaluator_handler.retrieve_from_cache(tuple(x_i)) if not in_cache: if not bn_adaption_executed and self._search.bn_adaptation is not None: @@ -720,7 +719,6 @@ def _evaluate(self, x: List[float], out: Dict[str, Any], *args, **kargs) -> NoRe bn_adaption_executed = True value = evaluator_handler.evaluate_and_add_to_cache_from_pymoo(tuple(x_i)) evaluators_arr[eval_idx].append(value) - eval_idx += 1 result.append(evaluator_handler.name) result.append(value) diff --git a/nncf/experimental/torch/nas/bootstrapNAS/training/scheduler.py b/nncf/experimental/torch/nas/bootstrapNAS/training/scheduler.py index 3579dbb6dff..9dafe0d0ee9 100644 --- a/nncf/experimental/torch/nas/bootstrapNAS/training/scheduler.py +++ b/nncf/experimental/torch/nas/bootstrapNAS/training/scheduler.py @@ -191,12 +191,10 @@ def get_current_stage_desc(self) -> Tuple[Optional[StageDescriptor], int]: :return: current stage descriptor and its index in the list of all descriptors """ partial_epochs = 0 - stage_desc_idx = 0 - for stage_desc in self.list_stage_descriptors: + for stage_desc_idx, stage_desc in enumerate(self.list_stage_descriptors): partial_epochs += stage_desc.epochs if self.current_epoch < partial_epochs: return stage_desc, stage_desc_idx - stage_desc_idx += 1 return None, -1 def get_total_training_epochs(self) -> int: diff --git a/nncf/experimental/torch/sparsity/movement/scheduler.py b/nncf/experimental/torch/sparsity/movement/scheduler.py index 4274e4206e5..33c8d9d51da 100644 --- a/nncf/experimental/torch/sparsity/movement/scheduler.py +++ b/nncf/experimental/torch/sparsity/movement/scheduler.py @@ -97,16 +97,16 @@ def from_dict(cls, params: Dict[str, Any]) -> "MovementSchedulerParams": :param params: A dict that specifies the parameters of movement sparsity scheduler. :return: A `MovementSchedulerParams` object that stores the parameters from `params`. """ - warmup_start_epoch: int = params.get("warmup_start_epoch", None) - warmup_end_epoch: int = params.get("warmup_end_epoch", None) - importance_regularization_factor: float = params.get("importance_regularization_factor", None) + warmup_start_epoch: int = params.get("warmup_start_epoch") + warmup_end_epoch: int = params.get("warmup_end_epoch") + importance_regularization_factor: float = params.get("importance_regularization_factor") enable_structured_masking: bool = params.get("enable_structured_masking", MOVEMENT_ENABLE_STRUCTURED_MASKING) - init_importance_threshold: Optional[float] = params.get("init_importance_threshold", None) + init_importance_threshold: Optional[float] = params.get("init_importance_threshold") final_importance_threshold: float = params.get( "final_importance_threshold", MOVEMENT_FINAL_IMPORTANCE_THRESHOLD ) power: float = params.get("power", MOVEMENT_POWER) - steps_per_epoch: Optional[int] = params.get("steps_per_epoch", None) + steps_per_epoch: Optional[int] = params.get("steps_per_epoch") if None in [warmup_start_epoch, warmup_end_epoch, importance_regularization_factor]: raise ValueError( diff --git a/nncf/tensorflow/graph/converter.py b/nncf/tensorflow/graph/converter.py index 83c5d9f25cd..d7eb09a2a21 100644 --- a/nncf/tensorflow/graph/converter.py +++ b/nncf/tensorflow/graph/converter.py @@ -553,8 +553,7 @@ def _collect_edge_information(self): node_name = layer_name input_shapes = self._node_info[node_name]["input_shapes"] - layer_instance_input_port_id = 0 - for inbound_node in inbound_nodes: + for layer_instance_input_port_id, inbound_node in enumerate(inbound_nodes): producer_layer_name, producer_layer_instance, producer_layer_instance_output_port, _ = inbound_node if self._is_layer_shared(producer_layer_name): @@ -573,7 +572,6 @@ def _collect_edge_information(self): "to_node_input_port_id": layer_instance_input_port_id, "from_node_output_port_id": producer_layer_instance_output_port, } - layer_instance_input_port_id += 1 def convert(self) -> NNCFGraph: nncf_graph = NNCFGraph() diff --git a/nncf/torch/nested_objects_traversal.py b/nncf/torch/nested_objects_traversal.py index f8b7f942e7d..1507b3ade12 100644 --- a/nncf/torch/nested_objects_traversal.py +++ b/nncf/torch/nested_objects_traversal.py @@ -28,7 +28,7 @@ def is_tuple(obj) -> bool: def is_named_tuple(obj) -> bool: - return is_tuple(obj) and (obj.__class__ != tuple) + return is_tuple(obj) and (obj.__class__ is not tuple) def maybe_get_iterator(obj): diff --git a/nncf/torch/quantization/algo.py b/nncf/torch/quantization/algo.py index 24ef0576e9b..1cbeeac4f69 100644 --- a/nncf/torch/quantization/algo.py +++ b/nncf/torch/quantization/algo.py @@ -534,7 +534,7 @@ def _parse_range_init_params(self) -> Optional[PTRangeInitParams]: return PTRangeInitParams(**range_init_params) if range_init_params is not None else None def _parse_precision_init_params(self, initializer_config: Dict) -> Tuple[str, BasePrecisionInitParams]: - init_precision_config = initializer_config.get("precision", None) + init_precision_config = initializer_config.get("precision") if not init_precision_config: return None, None precision_init_type = init_precision_config.get("type", "manual") @@ -934,7 +934,7 @@ def _build_insertion_commands_list_for_quantizer_setup( range_init_minmax_values = None if minmax_values_for_range_init: - minmax_stat = minmax_values_for_range_init[qp_id] if qp_id in minmax_values_for_range_init else None + minmax_stat = minmax_values_for_range_init.get(qp_id) if minmax_stat is not None: range_init_minmax_values = (minmax_stat.min_values, minmax_stat.max_values) @@ -1084,7 +1084,7 @@ def ip_str_repr_key_lambda(x): min_values = None max_values = None for qp_id in sorted_qp_ids: - minmax_stat = minmax_values_for_range_init[qp_id] if qp_id in minmax_values_for_range_init else None + minmax_stat = minmax_values_for_range_init.get(qp_id) if minmax_stat is None: continue diff --git a/nncf/torch/quantization/precision_init/hawq_init.py b/nncf/torch/quantization/precision_init/hawq_init.py index 743e458413a..a60b15ad563 100644 --- a/nncf/torch/quantization/precision_init/hawq_init.py +++ b/nncf/torch/quantization/precision_init/hawq_init.py @@ -95,7 +95,7 @@ def from_config( return cls( user_init_args=user_init_args, bitwidths=hawq_init_config_dict.get("bits", PRECISION_INIT_BITWIDTHS), - traces_per_layer_path=hawq_init_config_dict.get("traces_per_layer_path", None), + traces_per_layer_path=hawq_init_config_dict.get("traces_per_layer_path"), num_data_points=hawq_init_config_dict.get("num_data_points", HAWQ_NUM_DATA_POINTS), iter_number=hawq_init_config_dict.get("iter_number", HAWQ_ITER_NUMBER), tolerance=hawq_init_config_dict.get("tolerance", HAWQ_TOLERANCE), diff --git a/ruff.toml b/ruff.toml index 53940dc8dcb..cf3a51e0c36 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,10 +1,19 @@ line-length = 120 +exclude = ["nncf/tensorflow/__init__.py"] + +[lint] +preview = true ignore-init-module-imports = true ignore = [ + "E201", # whitespace-after-open-bracket + "E203", # whitespace-before-punctuation + "E231", # missing-whitespace + "E251", # unexpected-spaces-around-keyword-parameter-equals "E731", # lambda-assignment "SIM108", # if-else-block-instead-of-if-exp "SIM110", # reimplemented-builtin "SIM117", # multiple-with-statements + "SIM103", # needless-bool ] select = [ "E", # pycodestyle rules @@ -14,9 +23,8 @@ select = [ extend-select = [ "SIM", # https://pypi.org/project/flake8-simplify ] -exclude = ["nncf/tensorflow/__init__.py"] -[per-file-ignores] +[lint.per-file-ignores] "nncf/experimental/torch/nas/bootstrapNAS/__init__.py" = ["F401"] "nncf/torch/__init__.py" = ["F401", "E402"] "tests/**/*.py" = ["F403"] @@ -24,7 +32,7 @@ exclude = ["nncf/tensorflow/__init__.py"] "examples/**/*.py" = ["F403"] -[flake8-copyright] +[lint.flake8-copyright] notice-rgx = """\ # Copyright \\(c\\) (\\d{4}|\\d{4}-\\d{4}) Intel Corporation # Licensed under the Apache License, Version 2.0 \\(the "License"\\); diff --git a/tests/onnx/benchmarking/ac_wrapper.py b/tests/onnx/benchmarking/ac_wrapper.py index 160fa0a287d..9ae3d639859 100644 --- a/tests/onnx/benchmarking/ac_wrapper.py +++ b/tests/onnx/benchmarking/ac_wrapper.py @@ -35,7 +35,7 @@ def _read_image_annotation(image, annotations, label_id_to_label): @staticmethod def convert_to_voc(image_labels): - return [COCO_TO_VOC[label] if label in COCO_TO_VOC else 0 for label in image_labels] + return [COCO_TO_VOC.get(label, 0) for label in image_labels] if __name__ == "__main__": diff --git a/tests/openvino/tools/calibrate.py b/tests/openvino/tools/calibrate.py index 49a601ffc3d..4687c86ee03 100644 --- a/tests/openvino/tools/calibrate.py +++ b/tests/openvino/tools/calibrate.py @@ -1086,7 +1086,7 @@ def main(): "quantize_with_accuracy_control": quantize_model_with_accuracy_control, } for algo_name, algo_config in nncf_algorithms_config.items(): - algo_fn = algo_name_to_method_map.get(algo_name, None) + algo_fn = algo_name_to_method_map.get(algo_name) if algo_fn: quantize_model_arguments = { "xml_path": xml_path, diff --git a/tests/openvino/tools/config.py b/tests/openvino/tools/config.py index e05de9620e7..723dbb5f72a 100644 --- a/tests/openvino/tools/config.py +++ b/tests/openvino/tools/config.py @@ -283,7 +283,7 @@ def _configure_ac_params(self): ac_conf = ConfigReader.convert_paths(ac_conf) ConfigReader._filter_launchers(ac_conf, filtering_params, mode=mode) for req_num in ["stat_requests_number", "eval_requests_number"]: - ac_conf[req_num] = self.engine[req_num] if req_num in self.engine else None + ac_conf[req_num] = self.engine.get(req_num, None) self["engine"] = ac_conf diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 2005e6c03bc..382ea2cbef9 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -354,11 +354,9 @@ def test_getitem(self): def test_iter(self): arr = [0, 1, 2] nncf_tensor = Tensor(self.to_tensor(arr)) - i = 0 - for x in nncf_tensor: + for i, x in enumerate(nncf_tensor): assert x == arr[i] assert isinstance(x, Tensor) - i += 1 # Math diff --git a/tests/torch/nas/test_search.py b/tests/torch/nas/test_search.py index e717f3e4453..931dbe1ff79 100644 --- a/tests/torch/nas/test_search.py +++ b/tests/torch/nas/test_search.py @@ -273,7 +273,7 @@ def validate_model_fn(model, eval_datasets): ) max_subnetwork_acc = validate_model_fn(model, eval_datasets) - _, best_config, performance_metrics = search.run(validate_model_fn, eval_datasets, tmp_path) + _, _, performance_metrics = search.run(validate_model_fn, eval_datasets, tmp_path) assert max_subnetwork_acc == search_result_descriptors.expected_accuracy assert performance_metrics[1] == search_result_descriptors.subnet_expected_accuracy[search_algo_name] diff --git a/tests/torch/qat/helpers.py b/tests/torch/qat/helpers.py index 5d46cb9ad68..da6788a6a91 100644 --- a/tests/torch/qat/helpers.py +++ b/tests/torch/qat/helpers.py @@ -96,11 +96,9 @@ def get_quantization_preset(config_quantization_params: Dict[str, Any]) -> Optio def get_advanced_ptq_parameters(config_quantization_params: Dict[str, Any]) -> AdvancedQuantizationParameters: range_estimator_params = get_range_init_type(config_quantization_params) return AdvancedQuantizationParameters( - overflow_fix=convert_overflow_fix_param(config_quantization_params.get("overflow_fix", None)), - weights_quantization_params=convert_quantization_params(config_quantization_params.get("weights", None)), - activations_quantization_params=convert_quantization_params( - config_quantization_params.get("activations", None) - ), + overflow_fix=convert_overflow_fix_param(config_quantization_params.get("overflow_fix")), + weights_quantization_params=convert_quantization_params(config_quantization_params.get("weights")), + activations_quantization_params=convert_quantization_params(config_quantization_params.get("activations")), weights_range_estimator_params=range_estimator_params, activations_range_estimator_params=range_estimator_params, ) diff --git a/tests/torch/qat/test_qat_segmentation.py b/tests/torch/qat/test_qat_segmentation.py index 2f479faa7f7..f3dff7ef899 100644 --- a/tests/torch/qat/test_qat_segmentation.py +++ b/tests/torch/qat/test_qat_segmentation.py @@ -209,7 +209,7 @@ def train( datasets.train_data_loader.sampler.set_epoch(epoch) logger.info(">>>> [Epoch: {0:d}] Validation".format(epoch)) - loss, (iou, current_miou) = val_obj.run_epoch(config.print_step) + _, (_, current_miou) = val_obj.run_epoch(config.print_step) # best_metric = max(current_miou, best_metric) acc_drop = original_metric - current_miou best_miou = max(current_miou, best_miou) @@ -225,7 +225,7 @@ def train( return acc_drop logger.info(">>>> [Epoch: {0:d}] Training".format(epoch)) - epoch_loss, (iou, miou) = train_obj.run_epoch(config.print_step) + epoch_loss, (_, miou) = train_obj.run_epoch(config.print_step) logger.info(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".format(epoch, epoch_loss, miou)) diff --git a/tests/torch/quantization/quantization_helpers.py b/tests/torch/quantization/quantization_helpers.py index 16f2c83e7c7..45e0779e420 100644 --- a/tests/torch/quantization/quantization_helpers.py +++ b/tests/torch/quantization/quantization_helpers.py @@ -63,7 +63,7 @@ def get_squeezenet_quantization_config(image_size=32, batch_size=3): def distributed_init_test_default(gpu, ngpus_per_node, config): config.batch_size = 3 - config.workers = 0 # workaround for the pytorch multiprocessingdataloader issue/ + config.workers = 0 # workaround for the pytorch multiprocessingdataloader issue/ config.gpu = gpu config.ngpus_per_node = ngpus_per_node config.rank = gpu diff --git a/tests/torch/sparsity/magnitude/test_algo.py b/tests/torch/sparsity/magnitude/test_algo.py index 0dc2da03aa0..f6f16d979c3 100644 --- a/tests/torch/sparsity/magnitude/test_algo.py +++ b/tests/torch/sparsity/magnitude/test_algo.py @@ -46,18 +46,15 @@ def test_can_create_magnitude_sparse_algo__with_defaults(): _, sparse_model_conv = check_correct_nncf_modules_replacement(model, sparse_model) - i = 0 - nncf_stats = compression_ctrl.statistics() for layer_info in nncf_stats.magnitude_sparsity.thresholds: assert layer_info.threshold == approx(0.24, 0.1) assert isinstance(compression_ctrl._weight_importance_fn, type(normed_magnitude)) - for sparse_module in sparse_model_conv.values(): + for i, sparse_module in enumerate(sparse_model_conv.values()): store = [] ref_mask = torch.ones_like(sparse_module.weight) if i == 0 else ref_mask_2 - i += 1 for op in sparse_module.pre_ops.values(): if isinstance(op, UpdateWeight) and isinstance(op.operand, BinaryMask): assert torch.allclose(op.operand.binary_mask, ref_mask) diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py index 10dfc149c1c..a5c8a71788e 100644 --- a/tests/torch/test_statistics_aggregator.py +++ b/tests/torch/test_statistics_aggregator.py @@ -183,29 +183,29 @@ def test_successive_statistics_aggregation( if not is_standard_estimator and not is_backend_support_custom_estimators: pytest.skip("Custom estimators are not supported for this backend yet") - ### Register operations before statistic collection + # Register operations before statistic collection def fn(x): return x * 2 target_point = self.get_target_point(test_parameters.target_type) model = self.__add_fn_to_model(model, target_point, fn) - ### Check hook inserted correctly + # Check hook inserted correctly self.__check_successive_hooks(test_parameters, model, target_point, fn) - ### Register and collect statistics after inserted operations + # Register and collect statistics after inserted operations statistic_points = self.__get_statistic_points( test_parameters, model, quantizer_config, dataset_samples, inplace_statistics, mocker ) tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples) - ### Check values are changed because of the inserted operation + # Check values are changed because of the inserted operation self.__check_collector( test_parameters, tensor_collector, is_stat_in_shape_of_scale, ) - ### Check the inserted operation is inside the model + # Check the inserted operation is inside the model self.__check_successive_hooks(test_parameters, model, target_point, fn) @pytest.mark.parametrize( @@ -270,7 +270,7 @@ def test_nested_statistics_aggregation( if not is_standard_estimator and not is_backend_support_custom_estimators: pytest.skip("Custom estimators are not supported for this backend yet") - ### Register operations before statistic collection + # Register operations before statistic collection @register_operator() def fn(x): return x * 2 @@ -280,22 +280,22 @@ def fn(x): nested_target_point = PTMinMaxAlgoBackend.target_point(nested_target_type, nested_target_node_name, 0) model = self.__add_fn_to_model(model, nested_target_point, fn) - ### Check hook inserted correctly + # Check hook inserted correctly self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn) - ### Register and collect statistics after inserted operations + # Register and collect statistics after inserted operations statistic_points = self.__get_statistic_points( test_parameters, model, quantizer_config, dataset_samples, inplace_statistics, mocker ) tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples) - ### Check values are changed because of the inserted operation + # Check values are changed because of the inserted operation self.__check_collector( test_parameters, tensor_collector, is_stat_in_shape_of_scale, ) - ### Check the inserted operation is inside the model + # Check the inserted operation is inside the model self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn) @staticmethod diff --git a/tests/torch/test_tracing_context.py b/tests/torch/test_tracing_context.py index 75b127eb2d6..7c4a23a588a 100644 --- a/tests/torch/test_tracing_context.py +++ b/tests/torch/test_tracing_context.py @@ -110,10 +110,10 @@ def test_traced_tensors_are_stripped_on_context_exit(): assert isinstance(module.weight, TracedParameter) assert isinstance(module.conv2d.weight, TracedParameter) assert isinstance(result, TracedTensor) - assert type(module.cached_tensor) == torch.Tensor - assert type(result) == torch.Tensor - assert type(module.weight) == torch.nn.Parameter - assert type(module.conv2d.weight) == torch.nn.Parameter + assert isinstance(module.cached_tensor, torch.Tensor) + assert isinstance(result, torch.Tensor) + assert isinstance(module.weight, torch.nn.Parameter) + assert isinstance(module.conv2d.weight, torch.nn.Parameter) def test_no_cross_forward_run_dependency(): diff --git a/tools/extract_ov_subgraph.py b/tools/extract_ov_subgraph.py index 022ee777e8d..0739bfbf1eb 100644 --- a/tools/extract_ov_subgraph.py +++ b/tools/extract_ov_subgraph.py @@ -94,8 +94,8 @@ def get_nodes(xml_dict: Dict, edges: Dict): try: attributes = node["attributes"] data = node["data"]["attributes"] if "data" in node else None - inp = node["input"] if "input" in node else None - out = node["output"] if "output" in node else None + inp = node.get("input", None) + out = node.get("output", None) node_id = int(attributes["id"]) node_name = attributes["name"]