diff --git a/nvflare/fuel/utils/class_utils.py b/nvflare/fuel/utils/class_utils.py index cad767e601..ba00204cda 100644 --- a/nvflare/fuel/utils/class_utils.py +++ b/nvflare/fuel/utils/class_utils.py @@ -19,7 +19,6 @@ from typing import Dict, List, Optional from nvflare.security.logging import secure_format_exception -from nvflare.utils.components_utils import create_classes_table_static DEPRECATED_PACKAGES = ["nvflare.app_common.pt", "nvflare.app_common.homomorphic_encryption"] @@ -86,10 +85,10 @@ def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=T self.exclude_libs = exclude_libs self._logger = logging.getLogger(self.__class__.__name__) - self._class_table = create_classes_table_static() + self._class_table: Dict[str, str] = {} + self._create_classes_table() def _create_classes_table(self): - class_table: Dict[str, str] = {} scan_result_table = {} for base in self.base_pkgs: package = importlib.import_module(base) @@ -112,21 +111,20 @@ def _create_classes_table(self): # same class name exists in multiple modules if name in scan_result_table: scan_result = scan_result_table[name] - if name in class_table: - class_table.pop(name) - class_table[f"{scan_result.module_name}.{name}"] = module_name - class_table[f"{module_name}.{name}"] = module_name + if name in self._class_table: + self._class_table.pop(name) + self._class_table[f"{scan_result.module_name}.{name}"] = module_name + self._class_table[f"{module_name}.{name}"] = module_name else: scan_result = _ModuleScanResult(class_name=name, module_name=module_name) scan_result_table[name] = scan_result - class_table[name] = module_name + self._class_table[name] = module_name except (ModuleNotFoundError, RuntimeError) as e: self._logger.debug( f"Try to import module {module_name}, but failed: {secure_format_exception(e)}. " f"Can't use name in config to refer to classes in module: {module_name}." ) pass - return class_table def get_module_name(self, class_name) -> Optional[str]: """Gets the name of the module that contains this class. diff --git a/nvflare/utils/components_utils.py b/nvflare/utils/components_utils.py deleted file mode 100644 index 82519b7216..0000000000 --- a/nvflare/utils/components_utils.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def create_classes_table_static(): - from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator - from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator - from nvflare.app_common.aggregators.dxo_aggregator import DXOAggregator - from nvflare.app_common.ccwf import ( - CrossSiteEvalClientController, - CrossSiteEvalServerController, - CyclicClientController, - SwarmClientController, - SwarmServerController, - ) - from nvflare.app_common.ccwf.swarm_client_ctl import Gatherer - from nvflare.app_common.response_processors.global_weights_initializer import GlobalWeightsInitializer - from nvflare.app_common.shareablegenerators import FullModelShareableGenerator - from nvflare.app_common.workflows.cross_site_eval import CrossSiteEval - from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval - from nvflare.app_common.workflows.cyclic_ctl import CyclicController - from nvflare.app_common.workflows.global_model_eval import GlobalModelEval - from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather - from nvflare.app_common.workflows.scatter_and_gather_scaffold import ScatterAndGatherScaffold - from nvflare.app_opt.pt import PTFileModelLocator, PTFileModelPersistor - - classes = { - ScatterAndGather, - ScatterAndGatherScaffold, - CollectAndAssembleAggregator, - CrossSiteEval, - CrossSiteEvalClientController, - CrossSiteEvalServerController, - CrossSiteModelEval, - CyclicClientController, - CyclicController, - DXOAggregator, - GlobalModelEval, - GlobalWeightsInitializer, - Gatherer, - SwarmClientController, - SwarmServerController, - FullModelShareableGenerator, - InTimeAccumulateWeightedAggregator, - PTFileModelPersistor, - PTFileModelLocator, - } - - class_table = {} - for item in classes: - class_table[item.__name__] = item.__module__ - return class_table