diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index d3688c0bfe..a8c9542663 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -26,7 +26,6 @@ from nvflare.client.flare_agent import FlareAgentException from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel from nvflare.client.model_registry import ModelRegistry -from nvflare.fuel.utils import fobs from nvflare.fuel.utils.import_utils import optional_import from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.pipe.pipe import Pipe @@ -52,14 +51,6 @@ def _create_pipe_using_config(client_config: ClientConfig, section: str) -> Tupl return pipe, pipe_channel_name -def _register_tensor_decomposer(): - tensor_decomposer, ok = optional_import(module="nvflare.app_opt.pt.decomposers", name="TensorDecomposer") - if ok: - fobs.register(tensor_decomposer) - else: - raise RuntimeError(f"Can't import TensorDecomposer for format: {ExchangeFormat.PYTORCH}") - - class ExProcessClientAPI(APISpec): def __init__(self): self.process_model_registry = None @@ -93,8 +84,12 @@ def init(self, rank: Optional[str] = None): flare_agent = None try: if rank == "0": - if client_config.get_exchange_format() == ExchangeFormat.PYTORCH: - _register_tensor_decomposer() + if client_config.get_exchange_format() in [ExchangeFormat.PYTORCH, ExchangeFormat.NUMPY]: + # both numpy and pytorch exchange format can need tensor decomposer + # import here, and register later when needed + _, ok = optional_import(module="nvflare.app_opt.pt.decomposers", name="TensorDecomposer") + if not ok: + raise RuntimeError("Can't import TensorDecomposer") pipe, task_channel_name = None, "" if ConfigKey.TASK_EXCHANGE in client_config.config: