Skip to content

Commit

Permalink
enhance the rc handling for MPM. (#1985)
Browse files Browse the repository at this point in the history
* enhance the rc handling for MPM.

* consolidate the version_ckeck() to a single function.

* fixed a variable name error.

* Added rc enahncement for client_exxecutor.

* extract the common function for get_return_code().

* Removed the no used import.

* codestyle fix.

* codestyle fixes, and addressed a simulator_runner run_processs() exception handling.

* Added enahcements for get_return_code().

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
yhwen and YuanTingHsieh authored Nov 9, 2023
1 parent 319d505 commit a060c60
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 114 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ class FLMetaKey:
TOTAL_ROUNDS = "total_rounds"
JOB_ID = "job_id"
SITE_NAME = "site_name"
PROCESS_RC_FILE = "_process_rc.txt"
SUBMIT_MODEL_NAME = "submit_model_name"


Expand Down
16 changes: 13 additions & 3 deletions nvflare/fuel/f3/mpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import threading
import time

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.fuel.common.excepts import ComponentNotAuthorized, ConfigError
from nvflare.fuel.common.exit_codes import ProcessExitCode
from nvflare.fuel.f3.drivers.aio_context import AioContext
Expand Down Expand Up @@ -128,7 +129,7 @@ def _do_cleanup(cls, waiter: threading.Event):
waiter.set()

@classmethod
def run(cls, main_func, shutdown_grace_time=1.5, cleanup_grace_time=1.5):
def run(cls, main_func, run_dir=None, shutdown_grace_time=1.5, cleanup_grace_time=1.5, **kwargs):
if not callable(main_func):
raise ValueError("main_func must be runnable")

Expand All @@ -139,11 +140,18 @@ def run(cls, main_func, shutdown_grace_time=1.5, cleanup_grace_time=1.5):
f"{cls.name}: the mpm.run() method is called from {t.name}: it must be called from the MainThread"
)

if not run_dir:
run_dir = os.getcwd()
rc_file = os.path.join(run_dir, FLMetaKey.PROCESS_RC_FILE)

# call and wait for the main_func to complete
logger = cls.logger()
logger.debug(f"=========== {cls.name}: started to run forever")
try:
rc = main_func()
if os.path.exists(rc_file):
os.remove(rc_file)

rc = main_func(**kwargs)
except ConfigError as ex:
# already handled
rc = ProcessExitCode.CONFIG_ERROR
Expand All @@ -169,10 +177,12 @@ def run(cls, main_func, shutdown_grace_time=1.5, cleanup_grace_time=1.5):
if thread.name != "MainThread" and not thread.daemon:
logger.warning(f"#### {cls.name}: still running thread {thread.name}")
num_active_threads += 1

logger.info(f"{cls.name}: Good Bye!")
if num_active_threads > 0:
try:
with open(rc_file, "w") as outfile:
outfile.write(f"{rc}")

os.kill(os.getpid(), signal.SIGKILL)
except Exception as ex:
logger.debug(f"Failed to kill process {os.getpid()}: {secure_format_exception(ex)}")
Expand Down
29 changes: 15 additions & 14 deletions nvflare/private/fed/app/client/client_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger, create_privacy_manager
from nvflare.private.fed.app.utils import version_check
from nvflare.private.fed.client.admin import FedAdminAgent
from nvflare.private.fed.client.client_engine import ClientEngine
from nvflare.private.fed.client.client_status import ClientStatus
Expand All @@ -36,19 +37,7 @@
from nvflare.security.logging import secure_format_exception


def main():
if sys.version_info >= (3, 11):
raise RuntimeError("Python versions 3.11 and above are not yet supported. Please use Python 3.8, 3.9 or 3.10.")
if sys.version_info < (3, 8):
raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10")

parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--fed_client", "-s", type=str, help="client config json file", required=True)
parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
parser.add_argument("--local_rank", type=int, default=0)

args = parser.parse_args()
def main(args):
kv_list = parse_vars(args.set)

config_folder = kv_list.get("config_folder", "")
Expand Down Expand Up @@ -143,6 +132,16 @@ def main():
print(f"ConfigError: {secure_format_exception(e)}")


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--fed_client", "-s", type=str, help="client config json file", required=True)
parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()
return args


def create_admin_agent(req_processors, federated_client: FederatedClient, client_engine: ClientEngine):
"""Creates an admin agent.
Expand Down Expand Up @@ -182,5 +181,7 @@ def create_admin_agent(req_processors, federated_client: FederatedClient, client
# multiprocessing.set_start_method('spawn')

# main()
rc = mpm.run(main_func=main)
version_check()
args = parse_arguments()
rc = mpm.run(main_func=main, run_dir=args.workspace, args=args)
sys.exit(rc)
33 changes: 19 additions & 14 deletions nvflare/private/fed/app/client/sub_worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,7 @@ def stop(self):
self.done = True


def main():
"""Sub_worker process program."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--num_processes", type=str, help="Listen ports", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True)
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
parser.add_argument("--simulator_engine", "-s", type=str, help="simulator engine", required=True)
parser.add_argument("--parent_pid", type=int, help="parent process pid", required=True)
parser.add_argument("--root_url", type=str, help="root cell url", required=True)
parser.add_argument("--parent_url", type=str, help="parent cell url", required=True)

args = parser.parse_args()
def main(args):
workspace = Workspace(args.workspace, args.client_name)
app_custom_folder = workspace.get_client_custom_dir()
if os.path.isdir(app_custom_folder):
Expand Down Expand Up @@ -367,9 +355,26 @@ def main():
logger.warning(err)


def parse_arguments():
"""Sub_worker process program."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--num_processes", type=str, help="Listen ports", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True)
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
parser.add_argument("--simulator_engine", "-s", type=str, help="simulator engine", required=True)
parser.add_argument("--parent_pid", type=int, help="parent process pid", required=True)
parser.add_argument("--root_url", type=str, help="root cell url", required=True)
parser.add_argument("--parent_url", type=str, help="parent cell url", required=True)
args = parser.parse_args()
return args


if __name__ == "__main__":
"""
This is the program for running rank processes in multi-process mode.
"""
# main()
mpm.run(main_func=main)
args = parse_arguments()
run_dir = os.path.join(args.workspace, args.job_id)
mpm.run(main_func=main, run_dir=run_dir, args=args)
48 changes: 25 additions & 23 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,7 @@
from nvflare.security.logging import secure_format_exception


def main():
"""Worker process start program."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--startup", "-w", type=str, help="startup folder", required=True)
parser.add_argument("--token", "-t", type=str, help="token", required=True)
parser.add_argument("--ssid", "-d", type=str, help="ssid", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True)
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
# parser.add_argument("--listen_port", "-p", type=str, help="listen port", required=True)
parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True)
parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True)

parser.add_argument(
"--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True
)

parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")

parser.add_argument("--local_rank", type=int, default=0)

args = parser.parse_args()
def main(args):
kv_list = parse_vars(args.set)

# get parent process id
Expand Down Expand Up @@ -159,6 +138,27 @@ def main():
logger.warning(err)


def parse_arguments():
"""Worker process start program."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument("--startup", "-w", type=str, help="startup folder", required=True)
parser.add_argument("--token", "-t", type=str, help="token", required=True)
parser.add_argument("--ssid", "-d", type=str, help="ssid", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True)
parser.add_argument("--client_name", "-c", type=str, help="client name", required=True)
# parser.add_argument("--listen_port", "-p", type=str, help="listen port", required=True)
parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True)
parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True)
parser.add_argument(
"--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True
)
parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()
return args


def _create_sp(args):
sp = SP()
target = args.sp_target.split(":")
Expand Down Expand Up @@ -190,5 +190,7 @@ def remove_restart_file(workspace: Workspace):
"""

# main()
rc = mpm.run(main_func=main)
args = parse_arguments()
run_dir = os.path.join(args.workspace, args.job_id)
rc = mpm.run(main_func=main, run_dir=run_dir, args=args)
sys.exit(rc)
44 changes: 24 additions & 20 deletions nvflare/private/fed/app/server/runner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,7 @@
from nvflare.security.logging import secure_format_exception, secure_log_traceback


def main():
"""FL Server program starting point."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument(
"--fed_server", "-s", type=str, help="an aggregation server specification json file", required=True
)
parser.add_argument("--app_root", "-r", type=str, help="App Root", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job id", required=True)
parser.add_argument("--root_url", "-u", type=str, help="root_url", required=True)
parser.add_argument("--host", "-host", type=str, help="server host", required=True)
parser.add_argument("--port", "-port", type=str, help="service port", required=True)
parser.add_argument("--ssid", "-id", type=str, help="SSID", required=True)
parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True)
parser.add_argument("--ha_mode", "-ha_mode", type=str, help="HA mode", required=True)

parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")

args = parser.parse_args()
def main(args):
kv_list = parse_vars(args.set)

config_folder = kv_list.get("config_folder", "")
Expand Down Expand Up @@ -153,10 +135,32 @@ def main():
raise e


def parse_arguments():
"""FL Server program starting point."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument(
"--fed_server", "-s", type=str, help="an aggregation server specification json file", required=True
)
parser.add_argument("--app_root", "-r", type=str, help="App Root", required=True)
parser.add_argument("--job_id", "-n", type=str, help="job id", required=True)
parser.add_argument("--root_url", "-u", type=str, help="root_url", required=True)
parser.add_argument("--host", "-host", type=str, help="server host", required=True)
parser.add_argument("--port", "-port", type=str, help="service port", required=True)
parser.add_argument("--ssid", "-id", type=str, help="SSID", required=True)
parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True)
parser.add_argument("--ha_mode", "-ha_mode", type=str, help="HA mode", required=True)
parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
args = parser.parse_args()
return args


if __name__ == "__main__":
"""
This is the program when starting the child process for running the NVIDIA FLARE server runner.
"""
# main()
rc = mpm.run(main_func=main)
args = parse_arguments()
run_dir = os.path.join(args.workspace, args.job_id)
rc = mpm.run(main_func=main, run_dir=run_dir, args=args)
sys.exit(rc)
32 changes: 16 additions & 16 deletions nvflare/private/fed/app/server/server_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,14 @@
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger, create_privacy_manager
from nvflare.private.fed.app.utils import create_admin_server
from nvflare.private.fed.app.utils import create_admin_server, version_check
from nvflare.private.fed.server.server_status import ServerStatus
from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, security_init
from nvflare.private.privacy_manager import PrivacyService
from nvflare.security.logging import secure_format_exception


def main():
if sys.version_info >= (3, 11):
raise RuntimeError("Python versions 3.11 and above are not yet supported. Please use Python 3.8, 3.9 or 3.10.")
if sys.version_info < (3, 8):
raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10")

parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument(
"--fed_server", "-s", type=str, help="an aggregation server specification json file", required=True
)
parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")

args = parser.parse_args()
def main(args):
kv_list = parse_vars(args.set)

config_folder = kv_list.get("config_folder", "")
Expand Down Expand Up @@ -153,10 +140,23 @@ def main():
raise e


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
parser.add_argument(
"--fed_server", "-s", type=str, help="an aggregation server specification json file", required=True
)
parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
args = parser.parse_args()
return args


if __name__ == "__main__":
"""
This is the main program when starting the NVIDIA FLARE server process.
"""

rc = mpm.run(main_func=main)
version_check()
args = parse_arguments()
rc = mpm.run(main_func=main, run_dir=args.workspace, args=args)
sys.exit(rc)
4 changes: 2 additions & 2 deletions nvflare/private/fed/app/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sys import platform

from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner
from nvflare.private.fed.app.utils import version_check


def define_simulator_parser(simulator_parser):
Expand Down Expand Up @@ -59,8 +60,7 @@ def run_simulator(simulator_args):

multiprocessing.set_start_method("spawn")

if sys.version_info < (3, 8):
raise RuntimeError("Please use Python 3.8 or above.")
version_check()

parser = argparse.ArgumentParser()
define_simulator_parser(parser)
Expand Down
9 changes: 8 additions & 1 deletion nvflare/private/fed/app/simulator/simulator_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from nvflare.apis.job_def import ALL_SITES, JobMetaKey
from nvflare.apis.utils.job_utils import convert_legacy_zipped_app_to_job
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.exit_codes import ProcessExitCode
from nvflare.fuel.common.multi_process_executor_constants import CommunicationMetaData
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.f3.stats_pool import StatsPoolManager
Expand Down Expand Up @@ -348,7 +349,13 @@ def run(self):

def run_processs(self, return_dict):
# run_status = self.simulator_run_main()
run_status = mpm.run(main_func=self.simulator_run_main, shutdown_grace_time=3, cleanup_grace_time=6)
try:
run_status = mpm.run(
main_func=self.simulator_run_main, run_dir=self.workspace, shutdown_grace_time=3, cleanup_grace_time=6
)
except Exception as e:
self.logger.error(f"Simulator main run with exception: {secure_format_exception(e)}")
run_status = ProcessExitCode.EXCEPTION

return_dict["run_status"] = run_status

Expand Down
Loading

0 comments on commit a060c60

Please sign in to comment.