diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 0c5dd94819..f6a27e26f8 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -106,6 +106,7 @@ class ReservedKey(object): CUSTOM_PROPS = "__custom_props__" EXCEPTIONS = "__exceptions__" PROCESS_TYPE = "__process_type__" # type of the current process (SP, CP, SJ, CJ) + JOB_PROCESS_ARGS = "__job_process_args__" class FLContextKey(object): @@ -187,6 +188,7 @@ class FLContextKey(object): SERVER_CONFIG = "__server_config__" SERVER_HOST_NAME = "__server_host_name__" PROCESS_TYPE = ReservedKey.PROCESS_TYPE + JOB_PROCESS_ARGS = ReservedKey.JOB_PROCESS_ARGS class ProcessType: @@ -437,6 +439,7 @@ class SecureTrainConst: SSL_ROOT_CERT = "ssl_root_cert" SSL_CERT = "ssl_cert" PRIVATE_KEY = "ssl_private_key" + CONNECTION_SECURITY = "connection_security" class FLMetaKey: @@ -454,6 +457,14 @@ class FLMetaKey: SITE_NAME = "site_name" PROCESS_RC_FILE = "_process_rc.txt" SUBMIT_MODEL_NAME = "submit_model_name" + AUTH_TOKEN = "auth_token" + AUTH_TOKEN_SIGNATURE = "auth_token_signature" + + +class CellMessageAuthHeaderKey: + CLIENT_NAME = "client_name" + TOKEN = "__token__" + TOKEN_SIGNATURE = "__token_signature__" class StreamCtxKey: diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index fcc0e04c94..36edc62380 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -19,6 +19,29 @@ from nvflare.fuel.common.exit_codes import ProcessExitCode +class JobProcessArgs: + + EXE_MODULE = "exe_module" + WORKSPACE = "workspace" + STARTUP_DIR = "startup_dir" + APP_ROOT = "app_root" + AUTH_TOKEN = "auth_token" + TOKEN_SIGNATURE = "auth_signature" + SSID = "ssid" + JOB_ID = "job_id" + CLIENT_NAME = "client_name" + ROOT_URL = "root_url" + PARENT_URL = "parent_url" + SERVICE_HOST = "service_host" + SERVICE_PORT = "service_port" + HA_MODE = "ha_mode" + TARGET = "target" + SCHEME = "scheme" + STARTUP_CONFIG_FILE = "startup_config_file" + RESTORE_SNAPSHOT = "restore_snapshot" + OPTIONS = "options" + + class JobReturnCode(ProcessExitCode): SUCCESS = 0 EXECUTION_ERROR = 1 @@ -67,7 +90,7 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: """To launch a job run. Args: - job_meta: job meta data + job_meta: job metadata fl_ctx: FLContext Returns: boolean to indicates the job launch success or fail. diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 8f6660a5ae..5c10d7b390 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -15,12 +15,13 @@ import os from typing import Optional -from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst from nvflare.apis.fl_context import FLContext from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.launcher_executor import LauncherExecutor from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG +from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.attributes_exportable import ExportMode @@ -125,12 +126,22 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout, } + site_name = fl_ctx.get_identity_name() + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + config_data = { ConfigKey.TASK_EXCHANGE: task_exchange_attributes, - FLMetaKey.SITE_NAME: fl_ctx.get_identity_name(), + FLMetaKey.SITE_NAME: site_name, FLMetaKey.JOB_ID: fl_ctx.get_job_id(), + FLMetaKey.AUTH_TOKEN: auth_token, + FLMetaKey.AUTH_TOKEN_SIGNATURE: signature, } + conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) + if conn_sec: + config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec + config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=config_data, config_file_path=config_file_path) diff --git a/nvflare/app_common/job_launcher/client_process_launcher.py b/nvflare/app_common/job_launcher/client_process_launcher.py index 306b84b094..e967cb4f46 100644 --- a/nvflare/app_common/job_launcher/client_process_launcher.py +++ b/nvflare/app_common/job_launcher/client_process_launcher.py @@ -18,4 +18,4 @@ class ClientProcessJobLauncher(ProcessJobLauncher): def get_command(self, job_meta, fl_ctx) -> str: - return generate_client_command(job_meta, fl_ctx) + return generate_client_command(fl_ctx) diff --git a/nvflare/app_common/job_launcher/server_process_launcher.py b/nvflare/app_common/job_launcher/server_process_launcher.py index 0adcf4c2f1..3d4b830301 100644 --- a/nvflare/app_common/job_launcher/server_process_launcher.py +++ b/nvflare/app_common/job_launcher/server_process_launcher.py @@ -18,4 +18,4 @@ class ServerProcessJobLauncher(ProcessJobLauncher): def get_command(self, job_meta, fl_ctx) -> str: - return generate_server_command(job_meta, fl_ctx) + return generate_server_command(fl_ctx) diff --git a/nvflare/app_opt/job_launcher/docker_launcher.py b/nvflare/app_opt/job_launcher/docker_launcher.py index 48627db81a..46ffd402dd 100644 --- a/nvflare/app_opt/job_launcher/docker_launcher.py +++ b/nvflare/app_opt/job_launcher/docker_launcher.py @@ -45,7 +45,6 @@ class DOCKER_STATE: class DockerJobHandle(JobHandleSpec): - def __init__(self, container, timeout=None): super().__init__() @@ -182,7 +181,7 @@ def get_command(self, job_meta, fl_ctx) -> (str, str): class ClientDockerJobLauncher(DockerJobLauncher): def get_command(self, job_meta, fl_ctx) -> (str, str): job_id = job_meta.get(JobConstants.JOB_ID) - command = generate_client_command(job_meta, fl_ctx) + command = generate_client_command(fl_ctx) return f"client-{job_id}", command @@ -190,6 +189,6 @@ def get_command(self, job_meta, fl_ctx) -> (str, str): class ServerDockerJobLauncher(DockerJobLauncher): def get_command(self, job_meta, fl_ctx) -> (str, str): job_id = job_meta.get(JobConstants.JOB_ID) - command = generate_server_command(job_meta, fl_ctx) + command = generate_server_command(fl_ctx) return f"server-{job_id}", command diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 5f1a1193e5..39af3675ed 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -24,9 +24,8 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, JobConstants from nvflare.apis.fl_context import FLContext -from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobReturnCode, add_launcher -from nvflare.apis.workspace import Workspace -from nvflare.utils.job_launcher_utils import extract_job_image +from nvflare.apis.job_launcher_spec import JobHandleSpec, JobLauncherSpec, JobProcessArgs, JobReturnCode, add_launcher +from nvflare.utils.job_launcher_utils import extract_job_image, get_client_job_args, get_server_job_args class JobState(Enum): @@ -223,15 +222,21 @@ def launch_job(self, job_meta: dict, fl_ctx: FLContext) -> JobHandleSpec: args = fl_ctx.get_prop(FLContextKey.ARGS) job_image = extract_job_image(job_meta, fl_ctx.get_identity_name()) self.logger.info(f"launch job use image: {job_image}") + + job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) + if not job_args: + raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") + + _, job_cmd = job_args[JobProcessArgs.EXE_MODULE] job_config = { "name": job_id, "image": job_image, "container_name": f"container-{job_id}", - "command": self.get_command(), + "command": job_cmd, "volume_mount_list": [{"name": self.workspace, "mountPath": self.mount_path}], "volume_list": [{"name": self.workspace, "hostPath": {"path": self.root_hostpath, "type": "Directory"}}], "module_args": self.get_module_args(job_id, fl_ctx), - "set_list": self.get_set_list(args, fl_ctx), + "set_list": args.set, } self.logger.info(f"launch job with k8s_launcher. Job_id:{job_id}") @@ -255,15 +260,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if job_image: add_launcher(self, fl_ctx) - @abstractmethod - def get_command(self): - """To get the run command of the launcher - - Returns: the command for the launcher process - - """ - pass - @abstractmethod def get_module_args(self, job_id, fl_ctx: FLContext): """To get the args to run the launcher @@ -277,78 +273,28 @@ def get_module_args(self, job_id, fl_ctx: FLContext): """ pass - @abstractmethod - def get_set_list(self, args, fl_ctx: FLContext): - """To get the command set_list - Args: - args: command args - fl_ctx: FLContext - - Returns: set_list command options - - """ - pass +def _job_args_dict(job_args: dict, arg_names: list) -> dict: + result = {} + for name in arg_names: + n, v = job_args[name] + result[n] = v + return result class ClientK8sJobLauncher(K8sJobLauncher): - def get_command(self): - return "nvflare.private.fed.app.client.worker_process" - def get_module_args(self, job_id, fl_ctx: FLContext): - workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - args = fl_ctx.get_prop(FLContextKey.ARGS) - client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) - server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) - if not server_config: - raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") - service = server_config[0].get("service", {}) - if not isinstance(service, dict): - raise RuntimeError(f"expect server config data to be dict but got {type(service)}") - self.logger.info(f"K8sJobLauncher start to launch job: {job_id} for client: {client.client_name}") - - return { - "-m": args.workspace, - "-w": (workspace_obj.get_startup_kit_dir()), - "-t": client.token, - "-d": client.ssid, - "-n": job_id, - "-c": client.client_name, - "-p": str(client.cell.get_internal_listener_url()), - "-g": service.get("target"), - "-scheme": service.get("scheme", "grpc"), - "-s": "fed_client.json", - } + job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) + if not job_args: + raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") - def get_set_list(self, args, fl_ctx: FLContext): - args.set.append("print_conf=True") - return args.set + return _job_args_dict(job_args, get_client_job_args(False, False)) class ServerK8sJobLauncher(K8sJobLauncher): - def get_command(self): - return "nvflare.private.fed.app.server.runner_process" - def get_module_args(self, job_id, fl_ctx: FLContext): - workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - args = fl_ctx.get_prop(FLContextKey.ARGS) - server = fl_ctx.get_prop(FLContextKey.SITE_OBJ) - - return { - "-m": args.workspace, - "-s": "fed_server.json", - "-r": workspace_obj.get_app_dir(), - "-n": str(job_id), - "-p": str(server.cell.get_internal_listener_url()), - "-u": str(server.cell.get_root_url_for_child()), - "--host": str(server.server_state.host), - "--port": str(server.server_state.service_port), - "--ssid": str(server.server_state.ssid), - "--ha_mode": str(server.ha_mode), - } + job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) + if not job_args: + raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") - def get_set_list(self, args, fl_ctx: FLContext): - restore_snapshot = fl_ctx.get_prop(FLContextKey.SNAPSHOT, False) - args.set.append("print_conf=True") - args.set.append("restore_snapshot=" + str(restore_snapshot)) - return args.set + return _job_args_dict(job_args, get_server_job_args(False, False)) diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 892a149288..477b132dc3 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -16,6 +16,7 @@ import os from typing import Dict, Optional +from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst from nvflare.fuel.utils.config_factory import ConfigFactory @@ -155,6 +156,18 @@ def get_heartbeat_timeout(self): self.config.get(ConfigKey.METRICS_EXCHANGE, {}).get(ConfigKey.HEARTBEAT_TIMEOUT, 60), ) + def get_connection_security(self): + return self.config.get(SecureTrainConst.CONNECTION_SECURITY) + + def get_site_name(self): + return self.config.get(FLMetaKey.SITE_NAME) + + def get_auth_token(self): + return self.config.get(FLMetaKey.AUTH_TOKEN) + + def get_auth_token_signature(self): + return self.config.get(FLMetaKey.AUTH_TOKEN_SIGNATURE) + def to_json(self, config_file: str): with open(config_file, "w") as f: json.dump(self.config, f, indent=2) diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index a8c9542663..bb94b70939 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple from nvflare.apis.analytix import AnalyticsDataType -from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst from nvflare.apis.utils.analytix_utils import create_analytic_dxo from nvflare.app_common.abstract.fl_model import FLModel from nvflare.client.api_spec import APISpec @@ -26,6 +26,7 @@ from nvflare.client.flare_agent import FlareAgentException from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel from nvflare.client.model_registry import ModelRegistry +from nvflare.fuel.data_event.utils import set_scope_property from nvflare.fuel.utils.import_utils import optional_import from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.pipe.pipe import Pipe @@ -36,6 +37,18 @@ def _create_client_config(config: str) -> ClientConfig: client_config = from_file(config_file=config) else: raise ValueError(f"config should be a string but got: {type(config)}") + + site_name = client_config.get_site_name() + conn_sec = client_config.get_connection_security() + if conn_sec: + set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec) + + # get message auth info and put them into Databus for CellPipe to use + auth_token = client_config.get_auth_token() + signature = client_config.get_auth_token_signature() + set_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, value=auth_token) + set_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=signature) + return client_config diff --git a/nvflare/fuel/data_event/utils.py b/nvflare/fuel/data_event/utils.py new file mode 100644 index 0000000000..f08aab295d --- /dev/null +++ b/nvflare/fuel/data_event/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from nvflare.fuel.utils.validation_utils import check_str + +from .data_bus import DataBus + + +def _scope_prop_key(scope_name: str, key: str): + return f"{scope_name}::{key}" + + +def set_scope_property(scope_name: str, key: str, value: Any): + """Save the specified property of the specified scope (globally). + Args: + scope_name: name of the scope + key: key of the property to be saved + value: value of property + Returns: None + """ + check_str("scope_name", scope_name) + check_str("key", key) + data_bus = DataBus() + data_bus.put_data(_scope_prop_key(scope_name, key), value) + + +def get_scope_property(scope_name: str, key: str, default=None) -> Any: + """Get the value of a specified property from the specified scope. + Args: + scope_name: name of the scope + key: key of the scope + default: value to return if property is not found + Returns: + """ + check_str("scope_name", scope_name) + check_str("key", key) + data_bus = DataBus() + result = data_bus.get_data(_scope_prop_key(scope_name, key)) + if result is None: + result = default + return result diff --git a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index d44fb530ad..5c0e65a48f 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -39,10 +39,11 @@ class _Defaults: class ConnectorData: - def __init__(self, handle, connect_url: str, active: bool): + def __init__(self, handle, connect_url: str, active: bool, params: dict): self.handle = handle self.connect_url = connect_url self.active = active + self.params = params def get_connection_url(self): return self.connect_url @@ -192,19 +193,19 @@ def _get_connector( try: if active: - handle = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required) connect_url = url elif url: - handle = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required) connect_url = url else: self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}") - handle, connect_url = self.communicator.start_listener(scheme, reqs) + handle, connect_url, conn_params = self.communicator.start_listener(scheme, reqs) self.logger.debug(f"{os.getpid()}: ############ dynamic listener at {connect_url}") # Kludge: to wait for listener ready and avoid race time.sleep(0.5) - return ConnectorData(handle, connect_url, active) + return ConnectorData(handle, connect_url, active, conn_params) except CommError as ex: self.logger.error(f"Failed to get connector: {secure_format_exception(ex)}") return None diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index b60c4b0fa3..4e8b68fa8b 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -43,7 +43,7 @@ from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import Communicator, MessageReceiver from nvflare.fuel.f3.connection import Connection -from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState from nvflare.fuel.f3.message import Message @@ -327,6 +327,16 @@ def __init__( if err: raise ValueError(f"Invalid FQCN '{fqcn}': {err}") + # Determine the value of 'secure' based on configured connection_security in credentials. + # If configured, use it; otherwise keep the original value of 'secure'. + conn_security = credentials.get(DriverParams.CONNECTION_SECURITY.value) + if conn_security: + if conn_security == ConnectionSecurity.INSECURE: + secure = False + else: + secure = True + + self.logger.debug(f"connection secure: {secure}") self.my_info = FqcnInfo(FQCN.normalize(fqcn)) self.secure = secure self.logger.debug(f"{self.my_info.fqcn}: max_msg_size={self.max_msg_size}") @@ -396,6 +406,7 @@ def __init__( self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self) self.communicator.register_monitor(monitor=self) self.req_reg = Registry() + self.in_filter_reg = Registry() # for any incoming messages self.in_req_filter_reg = Registry() # for request received self.out_reply_filter_reg = Registry() # for reply going out self.out_req_filter_reg = Registry() # for request sent @@ -981,6 +992,11 @@ def decrypt_payload(self, message: Message): if len(message.payload) != payload_len: raise RuntimeError(f"Payload size changed after decryption {len(message.payload)} <> {payload_len}") + def add_incoming_filter(self, channel: str, topic: str, cb, *args, **kwargs): + if not callable(cb): + raise ValueError(f"specified incoming_filter {type(cb)} is not callable") + self.in_filter_reg.append(channel, topic, Callback(cb, args, kwargs)) + def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs): if not callable(cb): raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable") @@ -1846,6 +1862,19 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess category=self._stats_category(message), counter_name=_CounterName.RECEIVED ) + # invoke incoming filters + channel = message.get_header(MessageHeaderKey.CHANNEL, "") + topic = message.get_header(MessageHeaderKey.TOPIC, "") + in_filters = self.in_filter_reg.find(channel, topic) + if in_filters: + self.logger.debug(f"{self.my_info.fqcn}: invoking incoming filters") + assert isinstance(in_filters, list) + for f in in_filters: + assert isinstance(f, Callback) + reply = self._try_cb(message, f.cb, *f.args, **f.kwargs) + if reply: + return reply + if msg_type == MessageType.REQ and self.message_interceptor is not None: reply = self._try_cb( message, self.message_interceptor, *self.message_interceptor_args, **self.message_interceptor_kwargs diff --git a/nvflare/fuel/f3/cellnet/net_agent.py b/nvflare/fuel/f3/cellnet/net_agent.py index 6b8285a163..bc2f8d22bc 100644 --- a/nvflare/fuel/f3/cellnet/net_agent.py +++ b/nvflare/fuel/f3/cellnet/net_agent.py @@ -406,7 +406,12 @@ def get_peers(self, target_fqcn: str) -> (Union[None, dict], List[str]): @staticmethod def _connector_info(info: ConnectorData) -> dict: - return {"url": info.connect_url, "handle": info.handle, "type": "connector" if info.active else "listener"} + return { + "url": info.connect_url, + "handle": info.handle, + "type": "connector" if info.active else "listener", + "params": info.params, + } def _get_connectors(self) -> dict: cell = self.cell @@ -548,9 +553,6 @@ def stop(self): self.close() def stop_cell(self, target: str) -> str: - # if self.cell.get_fqcn() == target: - # self.stop() - # return ReturnCode.OK reply = self.cell.send_request( channel=_CHANNEL, topic=_TOPIC_STOP_CELL, request=Message(), target=target, timeout=1.0 ) diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index 04d7dc7217..02714fd84e 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -154,7 +154,7 @@ def register_message_receiver(self, app_id: int, receiver: MessageReceiver): self.conn_manager.register_message_receiver(app_id, receiver) - def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str: + def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dict): """Load a connector. The driver is selected based on the URL Args: @@ -163,7 +163,7 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str: secure: True if SSL is required. Returns: - A handle that can be used to delete connector + A tuple of (A handle that can be used to delete connector, connector params) Raises: CommError: If any errors @@ -177,9 +177,9 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str: raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}") params = parse_url(url) - return self.add_connector_advanced(driver_class(), mode, params, secure, False) + return self.add_connector_advanced(driver_class(), mode, params, secure, False), params - def start_listener(self, scheme: str, resources: dict) -> (str, str): + def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): """Add and start a connector in passive mode on an address selected by the driver. Args: @@ -187,7 +187,7 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str): resources: User specified resources like host and port ranges Returns: - A tuple with connector handle and connect url + A tuple with connector handle and connect url, and connection params Raises: CommError: If any errors like invalid host or port not available @@ -205,7 +205,7 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str): handle = self.add_connector_advanced(driver_class(), Mode.PASSIVE, params, False, True) - return handle, connect_url + return handle, connect_url, params def add_connector_advanced( self, driver: Driver, mode: Mode, params: dict, secure: bool, start: bool = False @@ -229,9 +229,7 @@ def add_connector_advanced( if self.local_endpoint.conn_props: params.update(self.local_endpoint.conn_props) - if secure: - params[DriverParams.SECURE] = secure - + params[DriverParams.SECURE] = secure handle = self.conn_manager.add_connector(driver, params, mode) if not start: diff --git a/nvflare/fuel/f3/drivers/driver_params.py b/nvflare/fuel/f3/drivers/driver_params.py index fab0367897..54118855e3 100644 --- a/nvflare/fuel/f3/drivers/driver_params.py +++ b/nvflare/fuel/f3/drivers/driver_params.py @@ -33,12 +33,21 @@ class DriverParams(str, Enum): SERVER_KEY = "server_key" CLIENT_CERT = "client_cert" CLIENT_KEY = "client_key" + CONNECTION_SECURITY = "connection_security" + CUSTOM_CA_CERT = "custom_ca_cert" SECURE = "secure" PORTS = "ports" SOCKET = "socket" LOCAL_ADDR = "local_addr" PEER_ADDR = "peer_addr" PEER_CN = "peer_cn" + IMPLEMENTED_CONN_SEC = "implemented_conn_sec" + + +class ConnectionSecurity: + INSECURE = "insecure" + TLS = "tls" + MTLS = "mtls" class DriverCap(str, Enum): diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py index 50450ca7f5..bbad1a522f 100644 --- a/nvflare/fuel/f3/drivers/grpc/utils.py +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -14,7 +14,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams def use_aio_grpc(): @@ -23,12 +23,32 @@ def use_aio_grpc(): def get_grpc_client_credentials(params: dict): - root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT)) - private_key = _read_file(params.get(DriverParams.CLIENT_KEY)) - return grpc.ssl_channel_credentials( - certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert - ) + conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS) + if conn_security == ConnectionSecurity.TLS: + # One-way SSL + # For one-way SSL, only CA cert is needed, and no need for client cert and key. + # We try to use custom CA cert if it's provided. This is because the client may connect to ALB or proxy + # that provides its CA cert to the client. + # If the custom CA cert is not provided, we'll use Flare provisioned CA cert. + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Custom CA Cert used" + root_cert_file = params.get(DriverParams.CUSTOM_CA_CERT) + if not root_cert_file: + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Flare CA Cert used" + root_cert_file = params.get(DriverParams.CA_CERT.value) + if not root_cert_file: + raise RuntimeError(f"cannot get CA cert for one-way SSL: {params}") + root_cert = _read_file(root_cert_file) + return grpc.ssl_channel_credentials(root_certificates=root_cert) + else: + # For two-way SSL, we always use our own provisioned certs. + # In the future, we may change to also support other ways to get cert and key. + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client mTLS: Flare credentials used" + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT)) + private_key = _read_file(params.get(DriverParams.CLIENT_KEY)) + return grpc.ssl_channel_credentials( + certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert + ) def get_grpc_server_credentials(params: dict): @@ -36,10 +56,18 @@ def get_grpc_server_credentials(params: dict): cert_chain = _read_file(params.get(DriverParams.SERVER_CERT)) private_key = _read_file(params.get(DriverParams.SERVER_KEY)) + conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS) + require_client_auth = False if conn_security == ConnectionSecurity.TLS else True + + if require_client_auth: + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Server mTLS: client auth required" + else: + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Server TLS: client auth not required" + return grpc.ssl_server_credentials( [(private_key, cert_chain)], root_certificates=root_cert, - require_client_auth=True, + require_client_auth=require_client_auth, ) diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 25f931d880..8aa4e1f60f 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -21,7 +21,7 @@ from urllib.parse import parse_qsl, urlencode, urlparse from nvflare.fuel.f3.comm_error import CommError -from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams from nvflare.fuel.utils.argument_utils import str2bool from nvflare.security.logging import secure_format_exception @@ -44,6 +44,7 @@ SSL_CLIENT_PRIVATE_KEY = "client.key" SSL_CLIENT_CERT = "client.crt" SSL_ROOT_CERT = "rootCA.pem" +CUSTOM_ROOT_CERT = "customRootCA.pem" def ssl_required(params: dict) -> bool: @@ -54,31 +55,52 @@ def ssl_required(params: dict) -> bool: def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: if not ssl_required(params): + params[DriverParams.IMPLEMENTED_CONN_SEC.value] = "clear" return None - ca_path = params.get(DriverParams.CA_CERT.value) + conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS) if ssl_server: + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ca_path = params.get(DriverParams.CA_CERT.value) cert_path = params.get(DriverParams.SERVER_CERT.value) key_path = params.get(DriverParams.SERVER_KEY.value) + if conn_security == ConnectionSecurity.TLS: + # do not require client auth + ctx.verify_mode = ssl.CERT_NONE + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Server TLS: client auth not required" + else: + ctx.verify_mode = ssl.CERT_REQUIRED + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Server mTLS: client auth required" else: - cert_path = params.get(DriverParams.CLIENT_CERT.value) - key_path = params.get(DriverParams.CLIENT_KEY.value) + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ctx.verify_mode = ssl.CERT_REQUIRED + if conn_security == ConnectionSecurity.TLS: + # one-way SSL: use custom CA cert if provided + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Custom CA Cert used" + ca_path = params.get(DriverParams.CUSTOM_CA_CERT) + if not ca_path: + # no custom CA cert: use provisioned CA cert + ca_path = params.get(DriverParams.CA_CERT.value) + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Flare CA Cert used" + cert_path = None + key_path = None + else: + # two-way SSL: use provisioned cert + ca_path = params.get(DriverParams.CA_CERT.value) + cert_path = params.get(DriverParams.CLIENT_CERT.value) + key_path = params.get(DriverParams.CLIENT_KEY.value) + params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client mTLS: Flare credentials used" - if not all([ca_path, cert_path, key_path]): + if not ca_path: scheme = params.get(DriverParams.SCHEME.value, "Unknown") role = "Server" if ssl_server else "Client" raise CommError(CommError.BAD_CONFIG, f"{role} certificate parameters are missing for scheme {scheme}") - if ssl_server: - ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - else: - ctx = ssl.create_default_context() - ctx.minimum_version = ssl.TLSVersion.TLSv1_2 - ctx.verify_mode = ssl.CERT_REQUIRED ctx.check_hostname = False ctx.load_verify_locations(ca_path) - ctx.load_cert_chain(certfile=cert_path, keyfile=key_path) + if cert_path: + ctx.load_cert_chain(certfile=cert_path, keyfile=key_path) return ctx @@ -269,38 +291,60 @@ def get_tcp_urls(scheme: str, resources: dict) -> (str, str): def enhance_credential_info(params: dict): - # must have CA + """Enhance the params by loading additional cert and key from the folder that contains the CA cert. + + This is necessary because the params initially only contains basic credentials: + - for server, only CA cert, and the server's cert and key; + - for client, only CA cert, the client's cert and key. + + However, a client could also behave like a server for other processes, and could have a server cert as well. + This function loads all certs and keys, regardless the role of the process. + + Args: + params: the dict that contains initial credentials + + Returns: None + """ + + # Must have CA since all other certs/keys are assumed to be in the same folder as the CA cert. ca_path = params.get(DriverParams.CA_CERT.value) if not ca_path: - return params + return # assume all SSL credential files are in the same folder with CA cert cred_folder = os.path.dirname(ca_path) client_cert_path = params.get(DriverParams.CLIENT_CERT.value) if not client_cert_path: - # see whether the file client cert file exists + # see whether the client cert file exists client_cert_path = os.path.join(cred_folder, SSL_CLIENT_CERT) if os.path.exists(client_cert_path): params[DriverParams.CLIENT_CERT.value] = client_cert_path client_key_path = params.get(DriverParams.CLIENT_KEY.value) if not client_key_path: - # see whether the file client key file exists + # see whether the client key file exists client_key_path = os.path.join(cred_folder, SSL_CLIENT_PRIVATE_KEY) if os.path.exists(client_key_path): params[DriverParams.CLIENT_KEY.value] = client_key_path server_cert_path = params.get(DriverParams.SERVER_CERT.value) if not server_cert_path: - # see whether the file client cert file exists + # see whether the server cert file exists server_cert_path = os.path.join(cred_folder, SSL_SERVER_CERT) if os.path.exists(server_cert_path): params[DriverParams.SERVER_CERT.value] = server_cert_path server_key_path = params.get(DriverParams.SERVER_KEY.value) if not server_key_path: - # see whether the file client key file exists + # see whether the server key file exists server_key_path = os.path.join(cred_folder, SSL_SERVER_PRIVATE_KEY) if os.path.exists(server_key_path): params[DriverParams.SERVER_KEY.value] = server_key_path + + custom_ca_cert_path = params.get(DriverParams.CUSTOM_CA_CERT.value) + if not custom_ca_cert_path: + # see whether the custom CA cert file exists + custom_ca_cert_path = os.path.join(cred_folder, CUSTOM_ROOT_CERT) + if os.path.exists(custom_ca_cert_path): + params[DriverParams.CUSTOM_CA_CERT.value] = custom_ca_cert_path diff --git a/nvflare/fuel/sec/authn.py b/nvflare/fuel/sec/authn.py new file mode 100644 index 0000000000..4bbe2f66b3 --- /dev/null +++ b/nvflare/fuel/sec/authn.py @@ -0,0 +1,34 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_constant import CellMessageAuthHeaderKey +from nvflare.fuel.f3.message import Message + + +def add_authentication_headers(msg: Message, client_name: str, auth_token, token_signature): + """Add authentication headers to the specified message. + + Args: + msg: the message that the headers are added to + client_name: name of the client + auth_token: authentication token + token_signature: token signature + + Returns: + + """ + if client_name: + msg.set_header(CellMessageAuthHeaderKey.CLIENT_NAME, client_name) + + msg.set_header(CellMessageAuthHeaderKey.TOKEN, auth_token if auth_token else "NA") + msg.set_header(CellMessageAuthHeaderKey.TOKEN_SIGNATURE, token_signature if token_signature else "NA") diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index 9b62c8aeb6..1554a0dd44 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -17,13 +17,15 @@ import time from typing import Tuple, Union -from nvflare.apis.fl_constant import SystemVarName +from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst, SystemVarName +from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.cell import Message as CellMessage from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.utils import make_reply from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.sec.authn import add_authentication_headers from nvflare.fuel.utils.attributes_exportable import ExportMode from nvflare.fuel.utils.config_service import search_file from nvflare.fuel.utils.constants import Mode @@ -112,6 +114,9 @@ class CellPipe(Pipe): _lock = threading.Lock() _cells_info = {} # (root_url, site_name, token) => _CellInfo + _auth_token = None + _token_signature = None + _site_name = None @classmethod def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir): @@ -131,6 +136,7 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di """ with cls._lock: + cls._site_name = site_name cell_key = f"{root_url}.{site_name}.{token}" ci = cls._cells_info.get(cell_key) if not ci: @@ -144,6 +150,10 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di DriverParams.CA_CERT.value: root_cert_path, } + conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) + if conn_sec: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + cell = Cell( fqcn=_cell_fqcn(mode, site_name, token), root_url=root_url, @@ -151,11 +161,24 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di credentials=credentials, create_internal_listener=False, ) + + # set filter to add additional auth headers + cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=cls._add_auth_headers) + cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=cls._add_auth_headers) + net_agent = NetAgent(cell) ci = _CellInfo(cell, net_agent) cls._cells_info[cell_key] = ci return ci + @classmethod + def _add_auth_headers(cls, message: CellMessage): + if not cls._auth_token: + cls._auth_token = get_scope_property(scope_name=cls._site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + cls._token_signature = get_scope_property(cls._site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + + add_authentication_headers(message, cls._site_name, cls._auth_token, cls._token_signature) + def __init__( self, mode: Mode, diff --git a/nvflare/ha/dummy_overseer_agent.py b/nvflare/ha/dummy_overseer_agent.py index d9afb3f0b2..90d0170cb0 100644 --- a/nvflare/ha/dummy_overseer_agent.py +++ b/nvflare/ha/dummy_overseer_agent.py @@ -26,7 +26,7 @@ class DummyOverseerAgent(OverseerAgent): SSID = "ebc6125d-0a56-4688-9b08-355fe9e4d61a" - def __init__(self, sp_end_point, heartbeat_interval=5): + def __init__(self, sp_end_point, heartbeat_interval=0.5): super().__init__() self._base_init(sp_end_point) diff --git a/nvflare/lighter/constants.py b/nvflare/lighter/constants.py index a6341d3abb..2ed5d4b63d 100644 --- a/nvflare/lighter/constants.py +++ b/nvflare/lighter/constants.py @@ -43,6 +43,8 @@ class PropKey: OVERSEER_END_POINT = "overseer_end_point" ADMIN_PORT = "admin_port" FED_LEARN_PORT = "fed_learn_port" + CONN_SECURITY = "connection_security" + CUSTOM_CA_CERT = "custom_ca_cert" class CtxKey(WorkDir, PropKey): @@ -61,6 +63,13 @@ class ProvisionMode: NORMAL = "normal" +class ConnSecurity: + CLEAR = "clear" + INSECURE = "insecure" + TLS = "tls" + MTLS = "mtls" + + class AdminRole: PROJECT_ADMIN = "project_admin" ORG_ADMIN = "org_admin" @@ -136,6 +145,7 @@ class ProvFileName: CHART_YAML = "Chart.yaml" VALUES_YAML = "values.yaml" HELM_CHART_TEMPLATES_DIR = "templates" + CUSTOM_CA_CERT_FILE_NAME = "customRootCA.pem" class CertFileBasename: diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index 742f2eaac2..76c678e600 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -15,11 +15,20 @@ import copy import json import os +import shutil import yaml from nvflare.lighter import utils -from nvflare.lighter.constants import CtxKey, OverseerRole, PropKey, ProvFileName, ProvisionMode, TemplateSectionKey +from nvflare.lighter.constants import ( + ConnSecurity, + CtxKey, + OverseerRole, + PropKey, + ProvFileName, + ProvisionMode, + TemplateSectionKey, +) from nvflare.lighter.entity import Participant from nvflare.lighter.spec import Builder, Project, ProvisionContext @@ -104,6 +113,25 @@ def _build_overseer(self, overseer: Participant, ctx: ProvisionContext): else: ctx[PropKey.OVERSEER_END_POINT] = f"{protocol}://{overseer.name}{api_root}" + @staticmethod + def _build_conn_properties(site: Participant, ctx: ProvisionContext, site_config: dict): + valid_values = [ConnSecurity.CLEAR, ConnSecurity.INSECURE, ConnSecurity.TLS, ConnSecurity.MTLS] + conn_security = site.get_prop_fb(PropKey.CONN_SECURITY) + if conn_security: + assert isinstance(conn_security, str) + conn_security = conn_security.lower() + + if conn_security not in valid_values: + raise ValueError(f"invalid connection_security '{conn_security}': must be in {valid_values}") + + if conn_security in [ConnSecurity.CLEAR, ConnSecurity.INSECURE]: + conn_security = ConnSecurity.INSECURE + site_config["connection_security"] = conn_security + + custom_ca_cert = site.get_prop_fb(PropKey.CUSTOM_CA_CERT) + if custom_ca_cert: + shutil.copyfile(custom_ca_cert, os.path.join(ctx.get_kit_dir(site), ProvFileName.CUSTOM_CA_CERT_FILE_NAME)) + def _build_server(self, server: Participant, ctx: ProvisionContext): project = ctx.get_project() config = ctx.json_load_template_section(TemplateSectionKey.FED_SERVER) @@ -118,6 +146,10 @@ def _build_server(self, server: Participant, ctx: ProvisionContext): server_0["admin_port"] = admin_port self._prepare_overseer_agent(server, config, OverseerRole.SERVER, ctx) + + # set up connection props + self._build_conn_properties(server, ctx, server_0) + utils.write(os.path.join(dest_dir, ProvFileName.FED_SERVER_JSON), json.dumps(config, indent=2), "t") replacement_dict = { @@ -193,6 +225,10 @@ def _build_client(self, client, ctx): self._prepare_overseer_agent(client, config, OverseerRole.CLIENT, ctx) + # set connection properties + client_conf = config["client"] + self._build_conn_properties(client, ctx, client_conf) + utils.write(os.path.join(dest_dir, ProvFileName.FED_CLIENT_JSON), json.dumps(config, indent=2), "t") if self.docker_image: diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index bb0f98683d..1723bcd49c 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -15,6 +15,8 @@ import time import uuid +from nvflare.apis.fl_constant import CellMessageAuthHeaderKey + # this import is to let existing scripts import from nvflare.private.defs from nvflare.fuel.f3.cellnet.defs import CellChannel, CellChannelTopic, SSLConstants # noqa: F401 from nvflare.fuel.f3.message import Message @@ -34,8 +36,8 @@ class TaskConstant(object): class EngineConstant(object): FEDERATE_CLIENT = "federate_client" - FL_TOKEN = "fl_token" - CLIENT_TOKEN_FILE = "client_token.txt" + AUTH_TOKEN = "auth_token" + AUTH_TOKEN_SIGNATURE = "auth_token_signature" ENGINE_TASK_NAME = "engine_task_name" @@ -138,10 +140,11 @@ class AppFolderConstants: class CellMessageHeaderKeys: - CLIENT_NAME = "client_name" + CLIENT_NAME = CellMessageAuthHeaderKey.CLIENT_NAME + TOKEN = CellMessageAuthHeaderKey.TOKEN + TOKEN_SIGNATURE = CellMessageAuthHeaderKey.TOKEN_SIGNATURE CLIENT_IP = "client_ip" PROJECT_NAME = "project_name" - TOKEN = "token" SSID = "ssid" UNAUTHENTICATED = "unauthenticated" JOB_ID = "job_id" @@ -150,6 +153,9 @@ class CellMessageHeaderKeys: ABORT_JOBS = "abort_jobs" +AUTH_CLIENT_NAME_FOR_SJ = "server_job" + + class JobFailureMsgKey: JOB_ID = "job_id" diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index 235631c978..f3677853c5 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -111,8 +111,6 @@ def main(args): client_engine.initialize_comm(federated_client.cell) with client_engine.new_context() as fl_ctx: - client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) - fl_ctx.set_prop( key=FLContextKey.CLIENT_CONFIG, value=deployer.client_config, @@ -136,6 +134,8 @@ def main(args): fl_ctx.set_prop(FLContextKey.ARGS, args, private=True, sticky=True) fl_ctx.set_prop(FLContextKey.SITE_OBJ, federated_client, private=True, sticky=True) + client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + component_security_check(fl_ctx) client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index 2f4da3db9e..b43133abe5 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -26,7 +26,6 @@ from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import configure_logging, get_script_logger -from nvflare.private.defs import EngineConstant from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger from nvflare.private.fed.app.utils import monitor_parent_process from nvflare.private.fed.client.client_app_runner import ClientAppRunner @@ -106,11 +105,12 @@ def main(args): federated_client = deployer.create_fed_client(args) federated_client.status = ClientStatus.STARTING + federated_client.communicator.set_auth(args.client_name, args.token, args.token_signature, args.ssid) federated_client.token = args.token + federated_client.token_signature = args.token_signature federated_client.ssid = args.ssid federated_client.client_name = args.client_name federated_client.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False) - federated_client.fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False) federated_client.fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True) client_app_runner = ClientAppRunner(time_out=kv_list.get("app_runner_timeout", 60.0)) @@ -145,7 +145,8 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True) parser.add_argument("--startup", "-w", type=str, help="startup folder", required=True) - parser.add_argument("--token", "-t", type=str, help="token", required=True) + parser.add_argument("--token", "-t", type=str, help="auth token", required=True) + parser.add_argument("--token_signature", "-ts", type=str, help="auth token signature", required=True) parser.add_argument("--ssid", "-d", type=str, help="ssid", required=True) parser.add_argument("--job_id", "-n", type=str, help="job_id", required=True) parser.add_argument("--client_name", "-c", type=str, help="client name", required=True) diff --git a/nvflare/private/fed/app/deployer/simulator_deployer.py b/nvflare/private/fed/app/deployer/simulator_deployer.py index cae1858bb3..bb8c615af4 100644 --- a/nvflare/private/fed/app/deployer/simulator_deployer.py +++ b/nvflare/private/fed/app/deployer/simulator_deployer.py @@ -100,7 +100,7 @@ def _create_client_cell(self, client_config, client_name, federated_client): ) cell.start() federated_client.cell = cell - federated_client.communicator.cell = cell + federated_client.communicator.set_cell(cell) # if self.engine: # self.engine.admin_agent.register_cell_cb() diff --git a/nvflare/private/fed/app/server/runner_process.py b/nvflare/private/fed/app/server/runner_process.py index 0777152be2..c0ac5aa576 100644 --- a/nvflare/private/fed/app/server/runner_process.py +++ b/nvflare/private/fed/app/server/runner_process.py @@ -23,11 +23,13 @@ from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SiteType, SystemConfigs from nvflare.apis.workspace import Workspace from nvflare.fuel.common.excepts import ConfigError +from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm +from nvflare.fuel.sec.authn import add_authentication_headers from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import configure_logging, get_script_logger -from nvflare.private.defs import AppFolderConstants +from nvflare.private.defs import AUTH_CLIENT_NAME_FOR_SJ, AppFolderConstants, CellMessageHeaderKeys from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger from nvflare.private.fed.app.utils import monitor_parent_process from nvflare.private.fed.server.server_app_runner import ServerAppRunner @@ -108,6 +110,11 @@ def main(args): server.cell = server.create_job_cell( args.job_id, args.root_url, args.parent_url, secure_train, server_config ) + + # set filter to add additional auth headers + server.cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=_add_auth_headers, config=args) + server.cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=_add_auth_headers, config=args) + server.server_state = HotState(host=args.host, port=args.port, ssid=args.ssid) snapshot = None @@ -138,15 +145,24 @@ def main(args): raise e +def _add_auth_headers(message: CellMessage, config): + message.set_header(CellMessageHeaderKeys.SSID, config.ssid) + add_authentication_headers( + message, + client_name=AUTH_CLIENT_NAME_FOR_SJ, + auth_token=config.job_id, + token_signature=config.token_signature, + ) + + def parse_arguments(): """FL Server program starting point.""" parser = argparse.ArgumentParser() parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True) - parser.add_argument( - "--fed_server", "-s", type=str, help="an aggregation server specification json file", required=True - ) + parser.add_argument("--fed_server", "-s", type=str, help="server config json file", required=True) parser.add_argument("--app_root", "-r", type=str, help="App Root", required=True) parser.add_argument("--job_id", "-n", type=str, help="job id", required=True) + parser.add_argument("--token_signature", "-ts", type=str, help="auth token signature", required=True) parser.add_argument("--root_url", "-u", type=str, help="root_url", required=True) parser.add_argument("--host", "-host", type=str, help="server host", required=True) parser.add_argument("--port", "-port", type=str, help="service port", required=True) diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index baae171ce9..33a663d51e 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -210,7 +210,13 @@ def _create_client_cell(self, federated_client, root_url, parent_url): cell.start() mpm.add_cleanup_cb(cell.stop) federated_client.cell = cell - federated_client.communicator.cell = cell + federated_client.communicator.set_cell(cell) + federated_client.communicator.set_auth( + client_name=federated_client.client_name, + token=federated_client.token, + token_signature="NA", + ssid="NA", + ) start = time.time() while not cell.is_cell_connected(FQCN.ROOT_SERVER): diff --git a/nvflare/private/fed/client/client_app_runner.py b/nvflare/private/fed/client/client_app_runner.py index 5dca67bf14..6fda61bc7e 100644 --- a/nvflare/private/fed/client/client_app_runner.py +++ b/nvflare/private/fed/client/client_app_runner.py @@ -20,7 +20,7 @@ from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import get_module_logger from nvflare.private.admin_defs import Message -from nvflare.private.defs import CellChannel, EngineConstant, RequestHeader, TrainingTopic, new_cell_message +from nvflare.private.defs import CellChannel, RequestHeader, TrainingTopic, new_cell_message from nvflare.private.fed.app.fl_conf import create_privacy_manager from nvflare.private.fed.client.client_json_config import ClientJsonConfigurator from nvflare.private.fed.client.client_run_manager import ClientRunManager @@ -71,7 +71,6 @@ def start_run(self, app_root, args, config_folder, federated_client, secure_trai @staticmethod def _set_fl_context(fl_ctx: FLContext, app_root, args, workspace, secure_train): fl_ctx.set_prop(FLContextKey.CLIENT_NAME, args.client_name, private=False) - fl_ctx.set_prop(EngineConstant.FL_TOKEN, args.token, private=False) fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True) fl_ctx.set_prop(FLContextKey.ARGS, args, sticky=True) fl_ctx.set_prop(FLContextKey.APP_ROOT, app_root, private=True, sticky=True) diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index bde557f7c8..eeb9044634 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -32,7 +32,7 @@ from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner -from nvflare.private.defs import ERROR_MSG_PREFIX, ClientStatusKey, EngineConstant, new_cell_message +from nvflare.private.defs import ERROR_MSG_PREFIX, ClientStatusKey, new_cell_message from nvflare.private.event import fire_event from nvflare.private.fed.server.job_meta_validator import JobMetaValidator from nvflare.private.fed.utils.app_deployer import AppDeployer @@ -371,23 +371,6 @@ def notify_job_status(self, job_id: str, job_status): def get_client_name(self): return self.client.client_name - def _write_token_file(self, job_id, open_port): - token_file = os.path.join(self.args.workspace, EngineConstant.CLIENT_TOKEN_FILE) - if os.path.exists(token_file): - os.remove(token_file) - with open(token_file, "wt") as f: - f.write( - "%s\n%s\n%s\n%s\n%s\n%s\n" - % ( - self.client.token, - self.client.ssid, - job_id, - self.client.client_name, - open_port, - list(self.client.servers.values())[0]["target"], - ) - ) - def abort_app(self, job_id: str) -> str: status = self.client_executor.get_status(job_id) if status == ClientStatus.STOPPED: diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index 246e456cb2..b1f980f622 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -19,8 +19,9 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs from nvflare.apis.fl_context import FLContext -from nvflare.apis.job_launcher_spec import JobLauncherSpec +from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs from nvflare.apis.resource_manager_spec import ResourceManagerSpec +from nvflare.apis.workspace import Workspace from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode @@ -153,7 +154,7 @@ def start_app( Args: client: the FL client object job_id: the job_id - job_meta: job meta data + job_meta: job metadata args: admin command arguments for starting the worker process allocated_resource: allocated resources token: token from resource manager @@ -162,6 +163,45 @@ def start_app( """ job_launcher: JobLauncherSpec = get_job_launcher(job_meta, fl_ctx) + + # prepare command args for the job process + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) + if not server_config: + raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") + service = server_config[0].get("service", {}) + if not isinstance(service, dict): + raise RuntimeError(f"expect server config data to be dict but got {type(service)}") + command_options = "" + for t in args.set: + command_options += " " + t + command_options += " print_conf=True" + args.set.append("print_conf=True") + + # Job process args are the same for all job launchers! Letting each job launcher compute the job + # args would be error-prone and would require access to internal server components (e.g. cell). + # We prepare job process args here and save the prepared result in the fl_ctx. + # This way, the job launcher won't need to compute these args again. + # The job launcher will only need to use the args properly to launch the job process! + # + # Each arg is a tuple of (arg_option, arg_value). + # Note that the arg_option is fixed for each arg, and is not launcher specific! + job_args = { + JobProcessArgs.EXE_MODULE: ("-m", "nvflare.private.fed.app.client.worker_process"), + JobProcessArgs.JOB_ID: ("-n", job_id), + JobProcessArgs.CLIENT_NAME: ("-c", client.client_name), + JobProcessArgs.AUTH_TOKEN: ("-t", client.token), + JobProcessArgs.TOKEN_SIGNATURE: ("-ts", client.token_signature), + JobProcessArgs.SSID: ("-d", client.ssid), + JobProcessArgs.WORKSPACE: ("-m", args.workspace), + JobProcessArgs.STARTUP_DIR: ("-w", workspace_obj.get_startup_kit_dir()), + JobProcessArgs.PARENT_URL: ("-p", str(client.cell.get_internal_listener_url())), + JobProcessArgs.SCHEME: ("-scheme", service.get("scheme", "grpc")), + JobProcessArgs.TARGET: ("-g", service.get("target")), + JobProcessArgs.STARTUP_CONFIG_FILE: ("-s", "fed_client.json"), + JobProcessArgs.OPTIONS: ("--set", command_options), + } + fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) job_handle = job_launcher.launch_job(job_meta, fl_ctx) self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 6c31f31b90..50d86cf30e 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -20,21 +20,23 @@ from nvflare.apis.event_type import EventType from nvflare.apis.filter import Filter -from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_constant import FLContextKey, FLMetaKey from nvflare.apis.fl_constant import ReturnCode as ShareableRC from nvflare.apis.fl_constant import SecureTrainConst, ServerCommandKey, ServerCommandNames from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import FLCommunicationError from nvflare.apis.shareable import Shareable from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx +from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey, ReturnCode from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.utils import format_size +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.fuel.sec.authn import add_authentication_headers from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec -from nvflare.private.fed.utils.fed_utils import get_scope_prop from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, load_crt_bytes from nvflare.security.logging import secure_format_exception @@ -94,8 +96,40 @@ def __init__( self.timeout = timeout self.maint_msg_timeout = maint_msg_timeout + # token and token_signature are issued by the Server after the client is authenticated + # they are added to every message going to the server as proof of authentication + self.token = None + self.token_signature = None + self.ssid = None + self.client_name = None + self.logger = get_obj_logger(self) + def set_auth(self, client_name, token, token_signature, ssid): + self.ssid = ssid + self.token_signature = token_signature + self.token = token + self.client_name = client_name + + # put auth properties in databus so that they can be used elsewhere + set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN, value=token) + set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=token_signature) + + def set_cell(self, cell): + self.cell = cell + + # set filter to add additional auth headers + cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers) + cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers) + + def _add_auth_headers(self, message: CellMessage): + if self.ssid: + message.set_header(CellMessageHeaderKeys.SSID, self.ssid) + + # Note that auth info (client_name, token and signature) is not available until the client is fully + # authenticated. + add_authentication_headers(message, self.client_name, self.token, self.token_signature) + def _challenge_server(self, client_name, expected_host, root_cert_file): # ask server for its info and make sure that it matches expected host my_nonce = str(uuid.uuid4()) @@ -207,7 +241,7 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): # the provision was done with an old version # to be backward compatible, we expect the host to be the server host we connected to # we get the host name from DataBus! - expected_host = get_scope_prop(scope_name=client_name, key=FLContextKey.SERVER_HOST_NAME) + expected_host = get_scope_property(scope_name=client_name, key=FLContextKey.SERVER_HOST_NAME) if not expected_host: raise RuntimeError("cannot determine expected_host") @@ -260,17 +294,19 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): raise FLCommunicationError("error:client_registration " + reason) token = result.get_header(CellMessageHeaderKeys.TOKEN) + token_signature = result.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA") ssid = result.get_header(CellMessageHeaderKeys.SSID) if not token and not self.should_stop: time.sleep(self.client_register_interval) else: + self.set_auth(client_name, token, token_signature, ssid) break except Exception as ex: traceback.print_exc() raise FLCommunicationError("error:client_registration", ex) - return token, ssid + return token, token_signature, ssid def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): """Get a task from server. @@ -293,9 +329,6 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): client_name = fl_ctx.get_identity_name() task_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: project_name, }, shareable, @@ -370,9 +403,6 @@ def submit_update( task_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: project_name, }, shareable, @@ -420,9 +450,6 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: task_name, }, shareable, @@ -462,9 +489,6 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C job_ids = engine.get_all_job_ids() heartbeat_message = new_cell_message( { - CellMessageHeaderKeys.TOKEN: token, - CellMessageHeaderKeys.SSID: ssid, - CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.PROJECT_NAME: task_name, CellMessageHeaderKeys.JOB_IDS: job_ids, }, diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index ed10e65f9f..6a40717457 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -24,6 +24,7 @@ from nvflare.apis.overseer_spec import SP from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal +from nvflare.fuel.data_event.utils import set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent @@ -31,8 +32,6 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.private.defs import EngineConstant -from nvflare.private.fed.utils.fed_utils import set_scope_prop from nvflare.security.logging import secure_format_exception from .client_status import ClientStatus @@ -77,6 +76,7 @@ def __init__( self.client_name = client_name self.token = None + self.token_signature = None self.ssid = None self.client_args = client_args self.servers = server_args @@ -154,7 +154,7 @@ def set_sp(self, project_name, sp: SP): if server != location: # The SP name is the server host name that we will connect to. # Save this name for this client so that it can be checked by others - set_scope_prop(scope_name=self.client_name, value=sp.name, key=FLContextKey.SERVER_HOST_NAME) + set_scope_property(scope_name=self.client_name, value=sp.name, key=FLContextKey.SERVER_HOST_NAME) self.servers[project_name]["target"] = location self.sp_established = True @@ -222,6 +222,10 @@ def _create_cell(self, location, scheme): DriverParams.CLIENT_CERT.value: ssl_cert, DriverParams.CLIENT_KEY.value: private_key, } + conn_security = self.client_args.get(SecureTrainConst.CONNECTION_SECURITY) + if conn_security: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security + set_scope_property(self.client_name, SecureTrainConst.CONNECTION_SECURITY, conn_security) else: credentials = {} @@ -235,7 +239,7 @@ def _create_cell(self, location, scheme): parent_url=parent_url, ) self.cell.start() - self.communicator.cell = self.cell + self.communicator.set_cell(self.cell) self.net_agent = NetAgent(self.cell) mpm.add_cleanup_cb(self.net_agent.close) mpm.add_cleanup_cb(self.cell.stop) @@ -272,16 +276,17 @@ def client_register(self, project_name, fl_ctx: FLContext): Args: project_name: FL study project name. - register_data: customer defined client register data (in a dict) fl_ctx: FLContext """ if not self.token: try: - self.token, self.ssid = self.communicator.client_registration(self.client_name, project_name, fl_ctx) + self.token, self.token_signature, self.ssid = self.communicator.client_registration( + self.client_name, project_name, fl_ctx + ) + if self.token is not None: self.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, self.client_name, private=False) - self.fl_ctx.set_prop(EngineConstant.FL_TOKEN, self.token, private=False) self.logger.info( "Successfully registered client:{} for project {}. Token:{} SSID:{}".format( self.client_name, project_name, self.token, self.ssid diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 21820212f6..354df3fda3 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -52,6 +52,7 @@ from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm +from nvflare.fuel.sec.authn import add_authentication_headers from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import get_obj_logger @@ -167,6 +168,10 @@ def deploy(self, args, grpc_args=None, secure_train=False): DriverParams.SERVER_CERT.value: ssl_cert, DriverParams.SERVER_KEY.value: private_key, } + + conn_security = grpc_args.get(SecureTrainConst.CONNECTION_SECURITY) + if conn_security: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: credentials = {} parent_url = None @@ -253,7 +258,7 @@ def fl_shutdown(self): self.shutdown = True start = time.time() while self.client_manager.clients: - # Wait for the clients to shutdown and quite first. + # Wait for the clients to shut down and quite first. time.sleep(0.1) if time.time() - start > self.shutdown_period: self.logger.info("There are still clients connected. But shutdown the server after timeout.") @@ -334,6 +339,11 @@ def __init__( self.name_to_reg = {} self.id_asserter = None + # these are used when the server sends a message to itself. + self.my_own_auth_client_name = "server" + self.my_own_token = "server" + self.my_own_token_signature = None + def _register_cellnet_cbs(self): self.cell.register_request_cb( channel=CellChannel.SERVER_MAIN, @@ -382,6 +392,80 @@ def _register_cellnet_cbs(self): reg_checker = threading.Thread(target=self._check_regs, daemon=True) reg_checker.start() + def _add_auth_headers(self, message: Message): + """Add auth headers to the messages sent by the server to itself. + This is such that no one can fake a message to pretend it's from the server to the server. + Args: + message: the message for which to add the headers + Returns: None + """ + origin = message.get_header(MessageHeaderKey.ORIGIN) + dest = message.get_header(MessageHeaderKey.DESTINATION) + if origin == FQCN.ROOT_SERVER and dest == origin: + if not self.my_own_token_signature: + self.my_own_token_signature = self.sign_auth_token(self.my_own_auth_client_name, self.my_own_token) + add_authentication_headers( + message, self.my_own_auth_client_name, self.my_own_token, self.my_own_token_signature + ) + + def _validate_auth_headers(self, message: Message): + """Validate auth headers from messages that go through the server. + Args: + message: the message to validate + Returns: + """ + headers = message.headers + self.logger.debug(f"**** _validate_auth_headers: {headers=}") + topic = message.get_header(MessageHeaderKey.TOPIC) + channel = message.get_header(MessageHeaderKey.CHANNEL) + + origin = message.get_header(MessageHeaderKey.ORIGIN) + + if topic in [CellChannelTopic.Register, CellChannelTopic.Challenge] and channel == CellChannel.SERVER_MAIN: + # skip: client not registered yet + self.logger.debug(f"skip special message {topic=} {channel=}") + return None + + client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) + err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}" + if not client_name: + err = "missing client name" + self.logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + token = message.get_header(CellMessageHeaderKeys.TOKEN) + if not token: + err = "missing auth token" + self.logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + signature = message.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE) + if not signature: + err = "missing auth token signature" + self.logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + if not self.verify_auth_token(client_name, token, signature): + err = "invalid auth token signature" + self.logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + # all good + self.logger.debug(f"auth valid from {origin}: {topic=} {channel=}") + return None + + def sign_auth_token(self, client_name: str, token: str): + id_asserter = self._get_id_asserter() + if not id_asserter: + return "NA" + return id_asserter.sign(client_name + token, return_str=True) + + def verify_auth_token(self, client_name: str, token: str, signature): + id_asserter = self._get_id_asserter() + if not id_asserter: + return True + return id_asserter.verify_signature(client_name + token, signature) + def _check_regs(self): while True: with self.reg_lock: @@ -456,8 +540,13 @@ def create_job_cell(self, job_id, root_url, parent_url, secure_train, server_con DriverParams.SERVER_CERT.value: ssl_cert, DriverParams.SERVER_KEY.value: private_key, } + + conn_security = server_config.get(SecureTrainConst.CONNECTION_SECURITY) + if conn_security: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: credentials = {} + cell = Cell( fqcn=my_fqcn, root_url=root_url, @@ -625,8 +714,10 @@ def register_client(self, request: Message) -> Message: if self.admin_server: self.admin_server.client_heartbeat(client.token, client.name, client.get_fqcn()) + token_signature = self.sign_auth_token(client.name, client.token) headers = { CellMessageHeaderKeys.TOKEN: client.token, + CellMessageHeaderKeys.TOKEN_SIGNATURE: token_signature, CellMessageHeaderKeys.SSID: self.server_state.ssid, } else: @@ -897,6 +988,18 @@ def deploy(self, args, grpc_args=None, secure_train=False): self.engine.initialize_comm(self.cell) self._register_cellnet_cbs() + if secure_train: + core_cell = self.cell.core_cell + core_cell.add_incoming_filter( + channel="*", + topic="*", + cb=self._validate_auth_headers, + ) + + # set filter to add additional auth headers + core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers) + core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers) + self.overseer_agent.start(self.overseer_callback) def _init_agent(self, args=None): diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index f21d836322..fe59de77bc 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -40,7 +40,7 @@ from nvflare.apis.fl_snapshot import RunSnapshot from nvflare.apis.impl.job_def_manager import JobDefManagerSpec from nvflare.apis.job_def import Job -from nvflare.apis.job_launcher_spec import JobLauncherSpec +from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs from nvflare.apis.shareable import ReturnCode, Shareable, make_reply from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx, get_serializable_data @@ -55,7 +55,14 @@ from nvflare.fuel.utils.zip_utils import zip_directory_to_bytes from nvflare.private.admin_defs import Message, MsgHeader from nvflare.private.aux_runner import AuxMsgTarget -from nvflare.private.defs import CellChannel, CellMessageHeaderKeys, RequestHeader, TrainingTopic, new_cell_message +from nvflare.private.defs import ( + AUTH_CLIENT_NAME_FOR_SJ, + CellChannel, + CellMessageHeaderKeys, + RequestHeader, + TrainingTopic, + new_cell_message, +) from nvflare.private.fed.server.server_json_config import ServerJsonConfigurator from nvflare.private.fed.utils.fed_utils import ( get_job_launcher, @@ -221,11 +228,55 @@ def _start_runner_process(self, job, job_clients, snapshot, fl_ctx: FLContext): restore_snapshot = True else: restore_snapshot = False - fl_ctx.set_prop(FLContextKey.SNAPSHOT, restore_snapshot, private=True, sticky=False) - job_handle = job_launcher.launch_job(job.meta, fl_ctx) - self.logger.info(f"Launch job_id: {job.job_id} with job launcher: {type(job_launcher)} ") + # Job process args are the same for all job launchers! Letting each job launcher compute the job + # args would be error-prone and would require access to internal server components ( + # e.g. cell, server_state, self.server, etc.), which violates component layering. + # + # We prepare job process args here and save the prepared result in the fl_ctx. + # This way, the job launcher won't need to compute these args again. + # The job launcher will only need to use the args properly to launch the job process! + # + # Each arg is a tuple of (arg_option, arg_value). + # Note that the arg_option is fixed for each arg, and is not launcher specific! + workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) args = fl_ctx.get_prop(FLContextKey.ARGS) + server = fl_ctx.get_prop(FLContextKey.SITE_OBJ) + job_id = job.job_id + app_root = workspace_obj.get_app_dir(job_id) + cell = server.cell + server_state = server.server_state + command_options = "" + for t in args.set: + command_options += " " + t + command_options += f" restore_snapshot={restore_snapshot} print_conf=True" + args.set.append("print_conf=True") + args.set.append(f"restore_snapshot={restore_snapshot}") + + # create token and signature for SJ + token = job_id # use the run_number as the auth token + client_name = AUTH_CLIENT_NAME_FOR_SJ + signature = self.server.sign_auth_token(client_name, token) + + job_args = { + JobProcessArgs.JOB_ID: ("-n", job_id), + JobProcessArgs.EXE_MODULE: ("-m", "nvflare.private.fed.app.server.runner_process"), + JobProcessArgs.WORKSPACE: ("-m", args.workspace), + JobProcessArgs.STARTUP_CONFIG_FILE: ("-s", "fed_server.json"), + JobProcessArgs.APP_ROOT: ("-r", app_root), + JobProcessArgs.HA_MODE: ("--ha_mode", server.ha_mode), + JobProcessArgs.AUTH_TOKEN: ("-t", token), + JobProcessArgs.TOKEN_SIGNATURE: ("-ts", signature), + JobProcessArgs.PARENT_URL: ("-p", str(cell.get_internal_listener_url())), + JobProcessArgs.ROOT_URL: ("-u", str(cell.get_root_url_for_child())), + JobProcessArgs.SERVICE_HOST: ("--host", str(server_state.host)), + JobProcessArgs.SERVICE_PORT: ("--port", str(server_state.service_port)), + JobProcessArgs.SSID: ("--ssid", str(server_state.ssid)), + JobProcessArgs.OPTIONS: ("--set", command_options), + } + fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) + job_handle = job_launcher.launch_job(job.meta, fl_ctx) + self.logger.info(f"Launch job_id: {job.job_id} with job launcher: {type(job_launcher)} ") if not job_clients: job_clients = self.client_manager.clients @@ -406,7 +457,7 @@ def initialize_comm(self, cell: Cell): Returns: """ - self.logger.info("initialize_comm called!") + self.logger.debug("initialize_comm called!") self.cell = cell if self.run_manager: # Note that the aux_runner is created with the self.run_manager as the "engine". diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index 5b84f55909..42d2ca3840 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -17,7 +17,7 @@ import pkgutil import sys import warnings -from typing import Any, List, Union +from typing import List, Union from nvflare.apis.app_validation import AppValidator from nvflare.apis.client import Client @@ -30,14 +30,12 @@ from nvflare.apis.utils.decomposers import flare_decomposers from nvflare.apis.workspace import Workspace from nvflare.app_common.decomposers import common_decomposers -from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.f3.stats_pool import CsvRecordHandler, StatsPoolManager from nvflare.fuel.sec.audit import AuditService from nvflare.fuel.sec.authz import AuthorizationService from nvflare.fuel.sec.security_content_service import LoadResult, SecurityContentService from nvflare.fuel.utils import fobs from nvflare.fuel.utils.fobs.fobs import register_custom_folder -from nvflare.fuel.utils.validation_utils import check_str from nvflare.private.defs import RequestHeader, SSLConstants from nvflare.private.event import fire_event from nvflare.private.fed.utils.decomposers import private_decomposers @@ -430,43 +428,6 @@ def extract_participants(participants_list): return participants -def _scope_prop_key(scope_name: str, key: str): - return f"{scope_name}::{key}" - - -def set_scope_prop(scope_name: str, key: str, value: Any): - """Save the specified property of the specified scope (globally). - - Args: - scope_name: name of the scope - key: key of the property to be saved - value: value of property - - Returns: None - - """ - check_str("scope_name", scope_name) - check_str("key", key) - data_bus = DataBus() - data_bus.put_data(_scope_prop_key(scope_name, key), value) - - -def get_scope_prop(scope_name: str, key: str) -> Any: - """Get the value of a specified property from the specified scope. - - Args: - scope_name: name of the scope - key: key of the scope - - Returns: - - """ - check_str("scope_name", scope_name) - check_str("key", key) - data_bus = DataBus() - return data_bus.get_data(_scope_prop_key(scope_name, key)) - - def get_job_launcher(job_meta: dict, fl_ctx: FLContext) -> JobLauncherSpec: engine = fl_ctx.get_engine() diff --git a/nvflare/private/fed/utils/identity_utils.py b/nvflare/private/fed/utils/identity_utils.py index e45fb2562b..d8a8a44850 100644 --- a/nvflare/private/fed/utils/identity_utils.py +++ b/nvflare/private/fed/utils/identity_utils.py @@ -70,6 +70,17 @@ def __init__(self, private_key_file: str, cert_file: str): def sign_common_name(self, nonce: str) -> str: return sign_content(self.cn + nonce, self.pri_key, return_str=False) + def sign(self, content, return_str: bool) -> str: + return sign_content(content, self.pri_key, return_str=return_str) + + def verify_signature(self, content, signature) -> bool: + pub_key = self.cert.public_key() + try: + verify_content(content=content, signature=signature, public_key=pub_key) + return True + except Exception: + return False + class IdentityVerifier: def __init__(self, root_cert_file: str): diff --git a/nvflare/utils/job_launcher_utils.py b/nvflare/utils/job_launcher_utils.py index 8e03c12816..6865673869 100644 --- a/nvflare/utils/job_launcher_utils.py +++ b/nvflare/utils/job_launcher_utils.py @@ -17,86 +17,89 @@ from nvflare.apis.fl_constant import FLContextKey, JobConstants, SystemVarName from nvflare.apis.job_def import JobMetaKey -from nvflare.apis.workspace import Workspace - - -def generate_client_command(job_meta, fl_ctx): - workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - args = fl_ctx.get_prop(FLContextKey.ARGS) - client = fl_ctx.get_prop(FLContextKey.SITE_OBJ) - job_id = job_meta.get(JobConstants.JOB_ID) - server_config = fl_ctx.get_prop(FLContextKey.SERVER_CONFIG) - if not server_config: - raise RuntimeError(f"missing {FLContextKey.SERVER_CONFIG} in FL context") - service = server_config[0].get("service", {}) - if not isinstance(service, dict): - raise RuntimeError(f"expect server config data to be dict but got {type(service)}") - command_options = "" - for t in args.set: - command_options += " " + t - command = ( - f"{sys.executable} -m nvflare.private.fed.app.client.worker_process -m " - + args.workspace - + " -w " - + (workspace_obj.get_startup_kit_dir()) - + " -t " - + client.token - + " -d " - + client.ssid - + " -n " - + job_id - + " -c " - + client.client_name - + " -p " - + str(client.cell.get_internal_listener_url()) - + " -g " - + service.get("target") - + " -scheme " - + service.get("scheme", "grpc") - + " -s fed_client.json " - " --set" + command_options + " print_conf=True" +from nvflare.apis.job_launcher_spec import JobProcessArgs + + +def _job_args_str(job_args, arg_names) -> str: + result = "" + sep = "" + for name in arg_names: + n, v = job_args[name] + result += f"{sep}{n} {v}" + sep = " " + return result + + +def get_client_job_args(include_exe_module=True, include_set_options=True): + result = [] + if include_exe_module: + result.append(JobProcessArgs.EXE_MODULE) + + result.extend( + [ + JobProcessArgs.WORKSPACE, + JobProcessArgs.STARTUP_DIR, + JobProcessArgs.AUTH_TOKEN, + JobProcessArgs.TOKEN_SIGNATURE, + JobProcessArgs.SSID, + JobProcessArgs.JOB_ID, + JobProcessArgs.CLIENT_NAME, + JobProcessArgs.PARENT_URL, + JobProcessArgs.TARGET, + JobProcessArgs.SCHEME, + JobProcessArgs.STARTUP_CONFIG_FILE, + ] ) - return command - - -def generate_server_command(job_meta, fl_ctx): - workspace_obj: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - args = fl_ctx.get_prop(FLContextKey.ARGS) - server = fl_ctx.get_prop(FLContextKey.SITE_OBJ) - job_id = job_meta.get(JobConstants.JOB_ID) - restore_snapshot = fl_ctx.get_prop(FLContextKey.SNAPSHOT, False) - app_root = workspace_obj.get_app_dir(job_id) - cell = server.cell - server_state = server.server_state - command_options = "" - for t in args.set: - command_options += " " + t - command = ( - sys.executable - + " -m nvflare.private.fed.app.server.runner_process -m " - + args.workspace - + " -s fed_server.json -r " - + app_root - + " -n " - + str(job_id) - + " -p " - + str(cell.get_internal_listener_url()) - + " -u " - + str(cell.get_root_url_for_child()) - + " --host " - + str(server_state.host) - + " --port " - + str(server_state.service_port) - + " --ssid " - + str(server_state.ssid) - + " --ha_mode " - + str(server.ha_mode) - + " --set" - + command_options - + " print_conf=True restore_snapshot=" - + str(restore_snapshot) + + if include_set_options: + result.append(JobProcessArgs.OPTIONS) + + return result + + +def generate_client_command(fl_ctx) -> str: + job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) + if not job_args: + raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext") + + args_str = _job_args_str(job_args, get_client_job_args()) + return f"{sys.executable} {args_str}" + + +def get_server_job_args(include_exe_module=True, include_set_options=True): + result = [] + if include_exe_module: + result.append(JobProcessArgs.EXE_MODULE) + + result.extend( + [ + JobProcessArgs.WORKSPACE, + JobProcessArgs.STARTUP_CONFIG_FILE, + JobProcessArgs.APP_ROOT, + JobProcessArgs.JOB_ID, + JobProcessArgs.TOKEN_SIGNATURE, + JobProcessArgs.PARENT_URL, + JobProcessArgs.ROOT_URL, + JobProcessArgs.SERVICE_HOST, + JobProcessArgs.SERVICE_PORT, + JobProcessArgs.SSID, + JobProcessArgs.HA_MODE, + ] ) - return command + + if include_set_options: + result.append(JobProcessArgs.OPTIONS) + + return result + + +def generate_server_command(fl_ctx) -> str: + job_args = fl_ctx.get_prop(FLContextKey.JOB_PROCESS_ARGS) + if not job_args: + raise RuntimeError(f"missing {FLContextKey.JOB_PROCESS_ARGS} in FLContext!") + + args_str = _job_args_str(job_args, get_server_job_args()) + return f"{sys.executable} {args_str}" def extract_job_image(job_meta, site_name): diff --git a/tests/unit_test/fuel/f3/communicator_test.py b/tests/unit_test/fuel/f3/communicator_test.py index 8e0436ebb0..68a76e52e0 100644 --- a/tests/unit_test/fuel/f3/communicator_test.py +++ b/tests/unit_test/fuel/f3/communicator_test.py @@ -99,7 +99,7 @@ def test_sfm_message(self, scheme, port_range): comm_a = get_comm_a(comm_state) comm_b = get_comm_b(comm_state) - _, url = comm_a.start_listener(scheme, {"ports": port_range}) + _, url, _ = comm_a.start_listener(scheme, {"ports": port_range}) comm_a.start() # Check port is in the range