From c884d2f68b7a7158466c3d5cd32b1230e3974505 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Thu, 19 Dec 2024 10:04:12 -0500 Subject: [PATCH 1/3] Testing --- examples/hello-world/hello-pt/fedavg_script_runner_pt.py | 4 ++-- examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py | 3 ++- nvflare/app_common/launchers/subprocess_launcher.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/hello-world/hello-pt/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt/fedavg_script_runner_pt.py index 8c635aae6a..1b5c5f46a1 100644 --- a/examples/hello-world/hello-pt/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt/fedavg_script_runner_pt.py @@ -23,13 +23,13 @@ train_script = "src/hello-pt_cifar10_fl.py" job = FedAvgJob( - name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork() + name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork(), ) # Add clients for i in range(n_clients): executor = ScriptRunner( - script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + script=train_script, script_args="", launch_external_process=True # f"--batch_size 32 --data_path /tmp/data/site-{i}" ) job.to(executor, f"site-{i + 1}") diff --git a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py index 19dfb3c969..88c4bb0e15 100644 --- a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py @@ -77,7 +77,8 @@ def main(): print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") global_step = input_model.current_round * steps + epoch * len(train_loader) + i summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) - running_loss = 0.0 + running_loss = xyz + print("Finished Training") diff --git a/nvflare/app_common/launchers/subprocess_launcher.py b/nvflare/app_common/launchers/subprocess_launcher.py index faf6151f38..f4a9e81279 100644 --- a/nvflare/app_common/launchers/subprocess_launcher.py +++ b/nvflare/app_common/launchers/subprocess_launcher.py @@ -82,7 +82,7 @@ def _start_external_process(self, fl_ctx: FLContext): self._process = subprocess.Popen( command_seq, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=self._app_dir, env=env ) - self._log_thread = Thread(target=log_subprocess_output, args=(self._process, self.logger)) + self._log_thread = Thread(target=log_subprocess_output, args=(self._process, self.logger), daemon=True) self._log_thread.start() def _stop_external_process(self): From 0f9d9232a0c1e70a436a8263f586636b6811ec06 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Fri, 20 Dec 2024 08:51:20 -0500 Subject: [PATCH 2/3] Added death watch --- nvflare/client/api.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/nvflare/client/api.py b/nvflare/client/api.py index a536b5812c..2c824100fd 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -14,7 +14,11 @@ import logging import os +import signal +import threading +import time from enum import Enum +from threading import Thread from typing import Any, Dict, Optional from nvflare.apis.analytix import AnalyticsDataType @@ -24,6 +28,7 @@ from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec from .ex_process.api import ExProcessClientAPI +logger = logging.getLogger(__name__) class ClientAPIType(Enum): IN_PROCESS_API = "IN_PROCESS_API" @@ -33,6 +38,21 @@ class ClientAPIType(Enum): client_api: Optional[APISpec] = None data_bus = DataBus() +def death_watch(): + """ + Python's main thread doesn't die if there are running thread pools. + This function kills the process when the main thread is in the shutdown process + """ + try: + while True: + if threading._SHUTTING_DOWN: + os.kill(os.getpid(), signal.SIGKILL) + # Just in case kill doesn't work + logger.error(f"Process {os.getpid()} is killed but still running") + break + time.sleep(1) + except Exception as ex: + logger.warning(f"Death watch failed with error: {ex}") def init(rank: Optional[str] = None): """Initializes NVFlare Client API environment. @@ -44,6 +64,9 @@ def init(rank: Optional[str] = None): Returns: None """ + + Thread(target=death_watch, name="death_watch", daemon=False).start() + api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value) api_type = ClientAPIType(api_type_name) global client_api @@ -54,7 +77,7 @@ def init(rank: Optional[str] = None): client_api = ExProcessClientAPI() client_api.init(rank=rank) else: - logging.warning("Warning: called init() more than once. The subsequence calls are ignored") + logger.warning("Warning: called init() more than once. The subsequence calls are ignored") def receive(timeout: Optional[float] = None) -> Optional[FLModel]: From 82207a7624bdf4968f468a460770d0b48470962e Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Fri, 20 Dec 2024 08:53:55 -0500 Subject: [PATCH 3/3] Format change --- examples/hello-world/hello-pt/fedavg_script_runner_pt.py | 4 ++-- examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py | 3 +-- nvflare/client/api.py | 3 +++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/hello-world/hello-pt/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt/fedavg_script_runner_pt.py index 1b5c5f46a1..8c635aae6a 100644 --- a/examples/hello-world/hello-pt/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt/fedavg_script_runner_pt.py @@ -23,13 +23,13 @@ train_script = "src/hello-pt_cifar10_fl.py" job = FedAvgJob( - name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork(), + name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork() ) # Add clients for i in range(n_clients): executor = ScriptRunner( - script=train_script, script_args="", launch_external_process=True # f"--batch_size 32 --data_path /tmp/data/site-{i}" + script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" ) job.to(executor, f"site-{i + 1}") diff --git a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py index 88c4bb0e15..19dfb3c969 100644 --- a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py @@ -77,8 +77,7 @@ def main(): print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") global_step = input_model.current_round * steps + epoch * len(train_loader) + i summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) - running_loss = xyz - + running_loss = 0.0 print("Finished Training") diff --git a/nvflare/client/api.py b/nvflare/client/api.py index 2c824100fd..1324a5fa6c 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -30,6 +30,7 @@ logger = logging.getLogger(__name__) + class ClientAPIType(Enum): IN_PROCESS_API = "IN_PROCESS_API" EX_PROCESS_API = "EX_PROCESS_API" @@ -38,6 +39,7 @@ class ClientAPIType(Enum): client_api: Optional[APISpec] = None data_bus = DataBus() + def death_watch(): """ Python's main thread doesn't die if there are running thread pools. @@ -54,6 +56,7 @@ def death_watch(): except Exception as ex: logger.warning(f"Death watch failed with error: {ex}") + def init(rank: Optional[str] = None): """Initializes NVFlare Client API environment.