diff --git a/nvflare/app_common/widgets/metric_relay.py b/nvflare/app_common/widgets/metric_relay.py index deac9c1b34..702e5d3473 100644 --- a/nvflare/app_common/widgets/metric_relay.py +++ b/nvflare/app_common/widgets/metric_relay.py @@ -33,6 +33,7 @@ def __init__( pipe_id: str, read_interval=0.1, heartbeat_interval=5.0, + heartbeat_timeout=60.0, pipe_channel_name=PipeChannelName.METRIC, event_type: str = ANALYTIC_EVENT_TYPE, fed_event: bool = True, @@ -41,6 +42,7 @@ def __init__( self.pipe_id = pipe_id self._read_interval = read_interval self._heartbeat_interval = heartbeat_interval + self._heartbeat_timeout = heartbeat_timeout self.pipe_channel_name = pipe_channel_name self.pipe = None self.pipe_handler = None @@ -62,7 +64,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): pipe=self.pipe, read_interval=self._read_interval, heartbeat_interval=self._heartbeat_interval, - heartbeat_timeout=0, + heartbeat_timeout=self._heartbeat_timeout, ) self.pipe_handler.set_status_cb(self._pipe_status_cb) self.pipe_handler.set_message_cb(self._pipe_msg_cb) diff --git a/nvflare/client/api.py b/nvflare/client/api.py index f075d48d2e..5ec8679e29 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import logging import os from enum import Enum from typing import Any, Dict, Optional @@ -45,12 +47,14 @@ def init(rank: Optional[str] = None): api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value) api_type = ClientAPIType(api_type_name) global client_api - if api_type == ClientAPIType.IN_PROCESS_API: - client_api = data_bus.get_data(CLIENT_API_KEY) + if client_api is None: + if api_type == ClientAPIType.IN_PROCESS_API: + client_api = data_bus.get_data(CLIENT_API_KEY) + else: + client_api = ExProcessClientAPI() + client_api.init(rank=rank) else: - client_api = ExProcessClientAPI() - - client_api.init(rank=rank) + logging.warning("Warning: called init() more than once. The subsequence calls are ignored") def receive(timeout: Optional[float] = None) -> Optional[FLModel]: diff --git a/nvflare/client/config.py b/nvflare/client/config.py index bbe59201af..5558ece3da 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -44,6 +44,7 @@ class ConfigKey: TASK_NAME = "TASK_NAME" TASK_EXCHANGE = "TASK_EXCHANGE" METRICS_EXCHANGE = "METRICS_EXCHANGE" + HEARTBEAT_TIMEOUT = "HEARTBEAT_TIMEOUT" class ClientConfig: @@ -133,19 +134,19 @@ def get_pipe_class(self, section: str) -> str: return self.config[section][ConfigKey.PIPE][ConfigKey.CLASS_NAME] def get_exchange_format(self) -> str: - return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EXCHANGE_FORMAT] + return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.EXCHANGE_FORMAT, "") def get_transfer_type(self) -> str: return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRANSFER_TYPE, "FULL") def get_train_task(self): - return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.TRAIN_TASK_NAME] + return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRAIN_TASK_NAME, "") def get_eval_task(self): - return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EVAL_TASK_NAME] + return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.EVAL_TASK_NAME, "") def get_submit_model_task(self): - return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.SUBMIT_MODEL_TASK_NAME] + return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.SUBMIT_MODEL_TASK_NAME, "") def to_json(self, config_file: str): with open(config_file, "w") as f: diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index 912aaea03a..cc694a8f9b 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -28,6 +28,7 @@ from nvflare.client.model_registry import ModelRegistry from nvflare.fuel.utils import fobs from nvflare.fuel.utils.import_utils import optional_import +from nvflare.fuel.utils.obj_utils import get_logger from nvflare.fuel.utils.pipe.pipe import Pipe @@ -35,7 +36,7 @@ def _create_client_config(config: str) -> ClientConfig: if isinstance(config, str): client_config = from_file(config_file=config) else: - raise ValueError("config should be a string but got: {type(config)}") + raise ValueError(f"config should be a string but got: {type(config)}") return client_config @@ -62,6 +63,7 @@ def _register_tensor_decomposer(): class ExProcessClientAPI(APISpec): def __init__(self): self.process_model_registry = None + self.logger = get_logger(self) def get_model_registry(self) -> ModelRegistry: """Gets the ModelRegistry.""" @@ -81,10 +83,11 @@ def init(self, rank: Optional[str] = None): rank = os.environ.get("RANK", "0") if self.process_model_registry: - print("Warning: called init() more than once. The subsequence calls are ignored") + self.logger.warning("Warning: called init() more than once. The subsequence calls are ignored") return - client_config = _create_client_config(config=f"config/{CLIENT_API_CONFIG}") + config_file = f"config/{CLIENT_API_CONFIG}" + client_config = _create_client_config(config=config_file) flare_agent = None try: @@ -92,9 +95,11 @@ def init(self, rank: Optional[str] = None): if client_config.get_exchange_format() == ExchangeFormat.PYTORCH: _register_tensor_decomposer() - pipe, task_channel_name = _create_pipe_using_config( - client_config=client_config, section=ConfigKey.TASK_EXCHANGE - ) + pipe, task_channel_name = None, "" + if ConfigKey.TASK_EXCHANGE in client_config.config: + pipe, task_channel_name = _create_pipe_using_config( + client_config=client_config, section=ConfigKey.TASK_EXCHANGE + ) metric_pipe, metric_channel_name = None, "" if ConfigKey.METRICS_EXCHANGE in client_config.config: metric_pipe, metric_channel_name = _create_pipe_using_config( @@ -106,12 +111,13 @@ def init(self, rank: Optional[str] = None): task_channel_name=task_channel_name, metric_pipe=metric_pipe, metric_channel_name=metric_channel_name, + heartbeat_timeout=client_config.config.get(ConfigKey.HEARTBEAT_TIMEOUT, 60), ) flare_agent.start() self.process_model_registry = ModelRegistry(client_config, rank, flare_agent) except Exception as e: - print(f"flare.init failed: {e}") + self.logger.error(f"flare.init failed: {e}") raise e def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]: diff --git a/nvflare/client/flare_agent.py b/nvflare/client/flare_agent.py index 200616ef1f..23ca4e79e8 100644 --- a/nvflare/client/flare_agent.py +++ b/nvflare/client/flare_agent.py @@ -64,14 +64,14 @@ def __init__(self, task_id, task_name: str, msg_id): class FlareAgent: def __init__( self, - pipe: Pipe, + pipe: Optional[Pipe] = None, read_interval=0.1, heartbeat_interval=5.0, - heartbeat_timeout=30.0, + heartbeat_timeout=60.0, resend_interval=2.0, max_resends=None, - submit_result_timeout=30.0, - metric_pipe=None, + submit_result_timeout=60.0, + metric_pipe: Optional[Pipe] = None, task_channel_name: str = PipeChannelName.TASK, metric_channel_name: str = PipeChannelName.METRIC, close_pipe: bool = True, @@ -103,6 +103,10 @@ def __init__( Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True. decomposer_module (str): the module name which contains the external decomposers. """ + if pipe is None and metric_pipe is None: + raise RuntimeError( + "Please configure at least one pipe. Both the task pipe and the metric pipe are set to None." + ) flare_decomposers.register() common_decomposers.register() if decomposer_module: @@ -110,14 +114,16 @@ def __init__( self.logger = logging.getLogger(self.__class__.__name__) self.pipe = pipe - self.pipe_handler = PipeHandler( - pipe=self.pipe, - read_interval=read_interval, - heartbeat_interval=heartbeat_interval, - heartbeat_timeout=heartbeat_timeout, - resend_interval=resend_interval, - max_resends=max_resends, - ) + self.pipe_handler = None + if self.pipe: + self.pipe_handler = PipeHandler( + pipe=self.pipe, + read_interval=read_interval, + heartbeat_interval=heartbeat_interval, + heartbeat_timeout=heartbeat_timeout, + resend_interval=resend_interval, + max_resends=max_resends, + ) self.submit_result_timeout = submit_result_timeout self.task_channel_name = task_channel_name self.metric_channel_name = metric_channel_name @@ -148,14 +154,17 @@ def start(self): Returns: None """ - self.pipe.open(self.task_channel_name) - self.pipe_handler.set_status_cb(self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name) - self.pipe_handler.start() + if self.pipe: + self.pipe.open(self.task_channel_name) + self.pipe_handler.set_status_cb( + self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name + ) + self.pipe_handler.start() if self.metric_pipe: self.metric_pipe.open(self.metric_channel_name) self.metric_pipe_handler.set_status_cb( - self._status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name + self._metrics_status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name ) self.metric_pipe_handler.start() @@ -164,6 +173,11 @@ def _status_cb(self, msg: Message, pipe_handler: PipeHandler, channel): self.asked_to_stop = True pipe_handler.stop(self._close_pipe) + def _metrics_status_cb(self, msg: Message, pipe_handler: PipeHandler, channel): + self.logger.info(f"{channel} pipe status changed to {msg.topic}: {msg.data}") + self.asked_to_stop = True + pipe_handler.stop(self._close_metric_pipe) + def stop(self): """Stop the agent. @@ -172,9 +186,9 @@ def stop(self): Returns: None """ - self.logger.info("Calling flare agent stop") self.asked_to_stop = True - self.pipe_handler.stop(self._close_pipe) + if self.pipe_handler: + self.pipe_handler.stop(self._close_pipe) if self.metric_pipe_handler: self.metric_pipe_handler.stop(self._close_metric_pipe) @@ -226,6 +240,8 @@ def get_task(self, timeout: Optional[float] = None) -> Optional[Task]: has been submitted. """ + if not self.pipe_handler: + raise RuntimeError("task pipe is not available") start_time = time.time() while True: if self.asked_to_stop: @@ -278,6 +294,8 @@ def submit_result(self, result, rc=RC.OK) -> bool: made a single time regardless whether the submission is successful. """ + if not self.pipe_handler: + raise RuntimeError("task pipe is not available") with self.task_lock: current_task = self.current_task if not current_task: