Skip to content

Commit

Permalink
Merge branch 'main' into enhance_json_serialization_for_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Jan 14, 2025
2 parents d0791a6 + 895cffa commit 67eeed8
Show file tree
Hide file tree
Showing 39 changed files with 813 additions and 321 deletions.
11 changes: 11 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class ReservedKey(object):
CUSTOM_PROPS = "__custom_props__"
EXCEPTIONS = "__exceptions__"
PROCESS_TYPE = "__process_type__" # type of the current process (SP, CP, SJ, CJ)
JOB_PROCESS_ARGS = "__job_process_args__"


class FLContextKey(object):
Expand Down Expand Up @@ -187,6 +188,7 @@ class FLContextKey(object):
SERVER_CONFIG = "__server_config__"
SERVER_HOST_NAME = "__server_host_name__"
PROCESS_TYPE = ReservedKey.PROCESS_TYPE
JOB_PROCESS_ARGS = ReservedKey.JOB_PROCESS_ARGS


class ProcessType:
Expand Down Expand Up @@ -437,6 +439,7 @@ class SecureTrainConst:
SSL_ROOT_CERT = "ssl_root_cert"
SSL_CERT = "ssl_cert"
PRIVATE_KEY = "ssl_private_key"
CONNECTION_SECURITY = "connection_security"


class FLMetaKey:
Expand All @@ -454,6 +457,14 @@ class FLMetaKey:
SITE_NAME = "site_name"
PROCESS_RC_FILE = "_process_rc.txt"
SUBMIT_MODEL_NAME = "submit_model_name"
AUTH_TOKEN = "auth_token"
AUTH_TOKEN_SIGNATURE = "auth_token_signature"


class CellMessageAuthHeaderKey:
CLIENT_NAME = "client_name"
TOKEN = "__token__"
TOKEN_SIGNATURE = "__token_signature__"


class StreamCtxKey:
Expand Down
25 changes: 24 additions & 1 deletion nvflare/apis/job_launcher_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@
from nvflare.fuel.common.exit_codes import ProcessExitCode


class JobProcessArgs:

EXE_MODULE = "exe_module"
WORKSPACE = "workspace"
STARTUP_DIR = "startup_dir"
APP_ROOT = "app_root"
AUTH_TOKEN = "auth_token"
TOKEN_SIGNATURE = "auth_signature"
SSID = "ssid"
JOB_ID = "job_id"
CLIENT_NAME = "client_name"
ROOT_URL = "root_url"
PARENT_URL = "parent_url"
SERVICE_HOST = "service_host"
SERVICE_PORT = "service_port"
HA_MODE = "ha_mode"
TARGET = "target"
SCHEME = "scheme"
STARTUP_CONFIG_FILE = "startup_config_file"
RESTORE_SNAPSHOT = "restore_snapshot"
OPTIONS = "options"


class JobReturnCode(ProcessExitCode):
SUCCESS = 0
EXECUTION_ERROR = 1
Expand Down Expand Up @@ -67,7 +90,7 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
"""To launch a job run.
Args:
job_meta: job meta data
job_meta: job metadata
fl_ctx: FLContext
Returns: boolean to indicates the job launch success or fail.
Expand Down
15 changes: 13 additions & 2 deletions nvflare/app_common/executors/client_api_launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import os
from typing import Optional

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.launcher_executor import LauncherExecutor
from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.utils.attributes_exportable import ExportMode


Expand Down Expand Up @@ -125,12 +126,22 @@ def prepare_config_for_launch(self, fl_ctx: FLContext):
ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout,
}

site_name = fl_ctx.get_identity_name()
auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

config_data = {
ConfigKey.TASK_EXCHANGE: task_exchange_attributes,
FLMetaKey.SITE_NAME: fl_ctx.get_identity_name(),
FLMetaKey.SITE_NAME: site_name,
FLMetaKey.JOB_ID: fl_ctx.get_job_id(),
FLMetaKey.AUTH_TOKEN: auth_token,
FLMetaKey.AUTH_TOKEN_SIGNATURE: signature,
}

conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY)
if conn_sec:
config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec

config_file_path = self._get_external_config_file_path(fl_ctx)
write_config_to_file(config_data=config_data, config_file_path=config_file_path)

Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_common/job_launcher/client_process_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@

class ClientProcessJobLauncher(ProcessJobLauncher):
def get_command(self, job_meta, fl_ctx) -> str:
return generate_client_command(job_meta, fl_ctx)
return generate_client_command(fl_ctx)
2 changes: 1 addition & 1 deletion nvflare/app_common/job_launcher/server_process_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@

class ServerProcessJobLauncher(ProcessJobLauncher):
def get_command(self, job_meta, fl_ctx) -> str:
return generate_server_command(job_meta, fl_ctx)
return generate_server_command(fl_ctx)
5 changes: 2 additions & 3 deletions nvflare/app_opt/job_launcher/docker_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class DOCKER_STATE:


class DockerJobHandle(JobHandleSpec):

def __init__(self, container, timeout=None):
super().__init__()

Expand Down Expand Up @@ -182,14 +181,14 @@ def get_command(self, job_meta, fl_ctx) -> (str, str):
class ClientDockerJobLauncher(DockerJobLauncher):
def get_command(self, job_meta, fl_ctx) -> (str, str):
job_id = job_meta.get(JobConstants.JOB_ID)
command = generate_client_command(job_meta, fl_ctx)
command = generate_client_command(fl_ctx)

return f"client-{job_id}", command


class ServerDockerJobLauncher(DockerJobLauncher):
def get_command(self, job_meta, fl_ctx) -> (str, str):
job_id = job_meta.get(JobConstants.JOB_ID)
command = generate_server_command(job_meta, fl_ctx)
command = generate_server_command(fl_ctx)

return f"server-{job_id}", command
102 changes: 24 additions & 78 deletions nvflare/app_opt/job_launcher/k8s_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, JobConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobReturnCode, add_launcher
from nvflare.apis.workspace import Workspace
from nvflare.utils.job_launcher_utils import extract_job_image
from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher
from nvflare.utils.job_launcher_utils import extract_job_image, get_client_job_args, get_server_job_args


class JobState(Enum):
Expand Down Expand Up @@ -223,15 +222,21 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec:
args = fl_ctx.get_prop(FLContextKey.ARGS)
job_image = extract_job_image(job_meta, fl_ctx.get_identity_name())
self.logger.info(f"launch job use image: {job_image}")

job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
if not job_args:
raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")

_, job_cmd = job_args[JobProcessArgs.EXE_MODULE]
job_config = {
"name": job_id,
"image": job_image,
"container_name": f"container-{job_id}",
"command": self.get_command(),
"command": job_cmd,
"volume_mount_list": [{"name": self.workspace, "mountPath": self.mount_path}],
"volume_list": [{"name": self.workspace, "hostPath": {"path": self.root_hostpath, "type": "Directory"}}],
"module_args": self.get_module_args(job_id, fl_ctx),
"set_list": self.get_set_list(args, fl_ctx),
"set_list": args.set,
}

self.logger.info(f"launch job with k8s_launcher. Job_id:{job_id}")
Expand All @@ -255,15 +260,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
if job_image:
add_launcher(self, fl_ctx)

@abstractmethod
def get_command(self):
"""To get the run command of the launcher
Returns: the command for the launcher process
"""
pass

@abstractmethod
def get_module_args(self, job_id, fl_ctx: FLContext):
"""To get the args to run the launcher
Expand All @@ -277,78 +273,28 @@ def get_module_args(self, job_id, fl_ctx: FLContext):
"""
pass

@abstractmethod
def get_set_list(self, args, fl_ctx: FLContext):
"""To get the command set_list

Args:
args: command args
fl_ctx: FLContext
Returns: set_list command options
"""
pass
def _job_args_dict(job_args: dict, arg_names: list) -> dict:
result = {}
for name in arg_names:
n, v = job_args[name]
result[n] = v
return result


class ClientK8sJobLauncher(K8sJobLauncher):
def get_command(self):
return "nvflare.private.fed.app.client.worker_process"

def get_module_args(self, job_id, fl_ctx: FLContext):
workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
args = fl_ctx.get_prop(FLContextKey.ARGS)
client = fl_ctx.get_prop(FLContextKey.SITE_OBJ)
server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG)
if not server_config:
raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context")
service = server_config[0].get("service", {})
if not isinstance(service, dict):
raise RuntimeError(f"expect server config data to be dict but got {type(service)}")
self.logger.info(f"K8sJobLauncher start to launch job: {job_id} for client: {client.client_name}")

return {
"-m": args.workspace,
"-w": (workspace_obj.get_startup_kit_dir()),
"-t": client.token,
"-d": client.ssid,
"-n": job_id,
"-c": client.client_name,
"-p": str(client.cell.get_internal_listener_url()),
"-g": service.get("target"),
"-scheme": service.get("scheme", "grpc"),
"-s": "fed_client.json",
}
job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
if not job_args:
raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")

def get_set_list(self, args, fl_ctx: FLContext):
args.set.append("print_conf=True")
return args.set
return _job_args_dict(job_args, get_client_job_args(False, False))


class ServerK8sJobLauncher(K8sJobLauncher):
def get_command(self):
return "nvflare.private.fed.app.server.runner_process"

def get_module_args(self, job_id, fl_ctx: FLContext):
workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
args = fl_ctx.get_prop(FLContextKey.ARGS)
server = fl_ctx.get_prop(FLContextKey.SITE_OBJ)

return {
"-m": args.workspace,
"-s": "fed_server.json",
"-r": workspace_obj.get_app_dir(),
"-n": str(job_id),
"-p": str(server.cell.get_internal_listener_url()),
"-u": str(server.cell.get_root_url_for_child()),
"--host": str(server.server_state.host),
"--port": str(server.server_state.service_port),
"--ssid": str(server.server_state.ssid),
"--ha_mode": str(server.ha_mode),
}
job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS)
if not job_args:
raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext")

def get_set_list(self, args, fl_ctx: FLContext):
restore_snapshot = fl_ctx.get_prop(FLContextKey.SNAPSHOT, False)
args.set.append("print_conf=True")
args.set.append("restore_snapshot=" + str(restore_snapshot))
return args.set
return _job_args_dict(job_args, get_server_job_args(False, False))
13 changes: 13 additions & 0 deletions nvflare/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from typing import Dict, Optional

from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.fuel.utils.config_factory import ConfigFactory


Expand Down Expand Up @@ -155,6 +156,18 @@ def get_heartbeat_timeout(self):
self.config.get(ConfigKey.METRICS_EXCHANGE, {}).get(ConfigKey.HEARTBEAT_TIMEOUT, 60),
)

def get_connection_security(self):
return self.config.get(SecureTrainConst.CONNECTION_SECURITY)

def get_site_name(self):
return self.config.get(FLMetaKey.SITE_NAME)

def get_auth_token(self):
return self.config.get(FLMetaKey.AUTH_TOKEN)

def get_auth_token_signature(self):
return self.config.get(FLMetaKey.AUTH_TOKEN_SIGNATURE)

def to_json(self, config_file: str):
with open(config_file, "w") as f:
json.dump(self.config, f, indent=2)
Expand Down
15 changes: 14 additions & 1 deletion nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Dict, Optional, Tuple

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.fl_constant import FLMetaKey
from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.client.api_spec import APISpec
Expand All @@ -26,6 +26,7 @@
from nvflare.client.flare_agent import FlareAgentException
from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel
from nvflare.client.model_registry import ModelRegistry
from nvflare.fuel.data_event.utils import set_scope_property
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.fuel.utils.pipe.pipe import Pipe
Expand All @@ -36,6 +37,18 @@ def _create_client_config(config: str) -> ClientConfig:
client_config = from_file(config_file=config)
else:
raise ValueError(f"config should be a string but got: {type(config)}")

site_name = client_config.get_site_name()
conn_sec = client_config.get_connection_security()
if conn_sec:
set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec)

# get message auth info and put them into Databus for CellPipe to use
auth_token = client_config.get_auth_token()
signature = client_config.get_auth_token_signature()
set_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, value=auth_token)
set_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=signature)

return client_config


Expand Down
Loading

0 comments on commit 67eeed8

Please sign in to comment.