diff --git a/docs/programming_guide/fed_job_api.rst b/docs/programming_guide/fed_job_api.rst index 8f9e3cec01..b84b131969 100644 --- a/docs/programming_guide/fed_job_api.rst +++ b/docs/programming_guide/fed_job_api.rst @@ -366,7 +366,7 @@ The FedAvgJob automatically adds the FedAvg controller, PTFileModelPersistor and For more examples of job patterns, see: -* :class:`BaseFedJob` +* :class:`CommonJob` * :class:`FedAvgJob` (pytorch) * :class:`FedAvgJob` (tensorflow) * :class:`CCWFJob` diff --git a/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb b/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb index 6f26be3100..c2858e7946 100644 --- a/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb +++ b/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb @@ -321,8 +321,8 @@ "#### 2. Define a FedJob\n", "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", "\n", - "Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n", - "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." + "Here we use a PyTorch `PTJob`, where we can define the job name and the initial global model.\n", + "The `PTJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." 
] }, { @@ -335,10 +335,10 @@ "from src.lit_net import LitNet\n", "\n", "from nvflare.app_common.workflows.fedavg import FedAvg\n", - "from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.app_opt.pt.job_config.pt_job import PTJob\n", "from nvflare.job_config.script_runner import ScriptRunner\n", "\n", - "job = BaseFedJob(\n", + "job = PTJob(\n", " name=\"cifar10_lightning_fedavg\",\n", " initial_model=LitNet(),\n", ")" diff --git a/examples/getting_started/pt/nvflare_pt_getting_started.ipynb b/examples/getting_started/pt/nvflare_pt_getting_started.ipynb index 79e1b99259..fb7f25a6cc 100644 --- a/examples/getting_started/pt/nvflare_pt_getting_started.ipynb +++ b/examples/getting_started/pt/nvflare_pt_getting_started.ipynb @@ -263,8 +263,8 @@ "#### 2. Define a FedJob\n", "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", "\n", - "Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n", - "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." + "Here we use a PyTorch `PTJob`, where we can define the job name and the initial global model.\n", + "The `PTJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." 
] }, { @@ -277,10 +277,10 @@ "from src.net import Net\n", "\n", "from nvflare.app_common.workflows.fedavg import FedAvg\n", - "from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.app_opt.pt.job_config.pt_job import PTJob\n", "from nvflare.job_config.script_runner import ScriptRunner\n", "\n", - "job = BaseFedJob(\n", + "job = PTJob(\n", " name=\"cifar10_pt_fedavg\",\n", " initial_model=Net(),\n", ")" diff --git a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb index 61afb4f870..a7870aa1db 100644 --- a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb +++ b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb @@ -253,8 +253,8 @@ "#### 2. Define a FedJob\n", "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", "\n", - "Here we use a TensorFlow `BaseFedJob`, where we can define the job name and the initial global model.\n", - "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." + "Here we use a TensorFlow `TFJob`, where we can define the job name and the initial global model.\n", + "The `TFJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." 
] }, { @@ -267,10 +267,10 @@ "from src.tf_net import TFNet\n", "\n", "from nvflare.app_common.workflows.fedavg import FedAvg\n", - "from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.app_opt.tf.job_config.tf_job import TFJob\n", "from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n", "\n", - "job = BaseFedJob(\n", + "job = TFJob(\n", " name=\"cifar10_tf_fedavg\",\n", " initial_model=TFNet(),\n", ")" diff --git a/nvflare/app_opt/pt/job_config/fed_avg.py b/nvflare/app_opt/pt/job_config/fed_avg.py index 58f10e27a5..300f401ba4 100644 --- a/nvflare/app_opt/pt/job_config/fed_avg.py +++ b/nvflare/app_opt/pt/job_config/fed_avg.py @@ -16,10 +16,10 @@ import torch.nn as nn from nvflare.app_common.workflows.fedavg import FedAvg -from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob +from nvflare.app_opt.pt.job_config.pt_job import PTJob -class FedAvgJob(BaseFedJob): +class FedAvgJob(PTJob): def __init__( self, initial_model: nn.Module, diff --git a/nvflare/app_opt/pt/job_config/fed_sag_mlflow.py b/nvflare/app_opt/pt/job_config/fed_sag_mlflow.py index 131ad11a2f..42ac67535d 100644 --- a/nvflare/app_opt/pt/job_config/fed_sag_mlflow.py +++ b/nvflare/app_opt/pt/job_config/fed_sag_mlflow.py @@ -19,12 +19,12 @@ from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator from nvflare.app_common.shareablegenerators import FullModelShareableGenerator from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather -from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob +from nvflare.app_opt.pt.job_config.pt_job import PTJob from nvflare.app_opt.tracking.mlflow.mlflow_receiver import MLflowReceiver from nvflare.app_opt.tracking.mlflow.mlflow_writer import MLflowWriter -class SAGMLFlowJob(BaseFedJob): +class SAGMLFlowJob(PTJob): def __init__( self, initial_model: nn.Module, diff --git a/nvflare/app_opt/pt/job_config/pt_job.py b/nvflare/app_opt/pt/job_config/pt_job.py new file 
mode 100644 index 0000000000..b3f050a500 --- /dev/null +++ b/nvflare/app_opt/pt/job_config/pt_job.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from torch import nn as nn + +from nvflare.app_common.abstract.model_locator import ModelLocator +from nvflare.app_common.abstract.model_persistor import ModelPersistor +from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent +from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector +from nvflare.app_common.widgets.streaming import AnalyticsReceiver +from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator +from nvflare.app_opt.pt.job_config.model import PTModel +from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver +from nvflare.job_config.api import validate_object_for_job +from nvflare.job_config.common_job import CommonJob + + +class PTJob(CommonJob): + def __init__( + self, + initial_model: nn.Module = None, + name: str = "fed_job", + min_clients: int = 1, + mandatory_clients: Optional[List[str]] = None, + key_metric: str = "accuracy", + validation_json_generator: Optional[ValidationJsonGenerator] = None, + intime_model_selector: Optional[IntimeModelSelector] = None, + convert_to_fed_event: Optional[ConvertToFedEvent] = None, + analytics_receiver: Optional[AnalyticsReceiver] = None, + model_persistor: 
Optional[ModelPersistor] = None, + model_locator: Optional[ModelLocator] = None, + ): + """PyTorch CommonJob. + + Configures ValidationJsonGenerator, IntimeModelSelector, AnalyticsReceiver, ConvertToFedEvent. + + User must add controllers and executors. + + Args: + initial_model (nn.Module): initial PyTorch Model. Defaults to None. + name (str, optional): name of the job. Defaults to "fed_job". + min_clients (int, optional): the minimum number of clients for the job. Defaults to 1. + mandatory_clients (List[str], optional): mandatory clients to run the job. Default None. + key_metric (str, optional): Metric used to determine if the model is globally best. + if metrics are a `dict`, `key_metric` can select the metric used for global model selection. + Defaults to "accuracy". + validation_json_generator (ValidationJsonGenerator, optional): A component for generating validation results. + if not provided, a ValidationJsonGenerator will be configured. + intime_model_selector (IntimeModelSelector, optional): A component for selecting the model. + if not provided, an IntimeModelSelector will be configured. + convert_to_fed_event (ConvertToFedEvent, optional): A component to convert certain events to fed events. + if not provided, a ConvertToFedEvent object will be created. + analytics_receiver (AnalyticsReceiver, optional): Receive analytics. + If not provided, a TBAnalyticsReceiver will be configured. + model_persistor (optional, ModelPersistor): how to persist the model. + model_locator (optional, ModelLocator): how to locate the model. 
+ """ + super().__init__( + name=name, + min_clients=min_clients, + mandatory_clients=mandatory_clients, + key_metric=key_metric, + validation_json_generator=validation_json_generator, + intime_model_selector=intime_model_selector, + convert_to_fed_event=convert_to_fed_event, + ) + + self.initial_model = initial_model + self.comp_ids = {} + + if analytics_receiver: + validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver) + else: + analytics_receiver = TBAnalyticsReceiver() + + self.to_server( + id="receiver", + obj=analytics_receiver, + ) + + if initial_model: + self.comp_ids.update( + self.to_server(PTModel(model=initial_model, persistor=model_persistor, locator=model_locator)) + ) diff --git a/nvflare/app_opt/tf/job_config/fed_avg.py b/nvflare/app_opt/tf/job_config/fed_avg.py index e25d87f574..070aac6159 100644 --- a/nvflare/app_opt/tf/job_config/fed_avg.py +++ b/nvflare/app_opt/tf/job_config/fed_avg.py @@ -16,10 +16,10 @@ import tensorflow as tf from nvflare.app_common.workflows.fedavg import FedAvg -from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob +from nvflare.app_opt.tf.job_config.tf_job import TFJob -class FedAvgJob(BaseFedJob): +class FedAvgJob(TFJob): def __init__( self, initial_model: tf.keras.Model, diff --git a/nvflare/app_opt/tf/job_config/base_fed_job.py b/nvflare/app_opt/tf/job_config/tf_job.py similarity index 75% rename from nvflare/app_opt/tf/job_config/base_fed_job.py rename to nvflare/app_opt/tf/job_config/tf_job.py index bf77cd1092..259a3ea319 100644 --- a/nvflare/app_opt/tf/job_config/base_fed_job.py +++ b/nvflare/app_opt/tf/job_config/tf_job.py @@ -17,17 +17,17 @@ import tensorflow as tf from nvflare.app_common.abstract.model_persistor import ModelPersistor -from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector from 
nvflare.app_common.widgets.streaming import AnalyticsReceiver from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator from nvflare.app_opt.tf.job_config.model import TFModel from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver -from nvflare.job_config.api import FedJob, validate_object_for_job +from nvflare.job_config.api import validate_object_for_job +from nvflare.job_config.common_job import CommonJob -class BaseFedJob(FedJob): +class TFJob(CommonJob): def __init__( self, initial_model: tf.keras.Model = None, @@ -41,7 +41,7 @@ def __init__( analytics_receiver: Optional[AnalyticsReceiver] = None, model_persistor: Optional[ModelPersistor] = None, ): - """TensorFlow BaseFedJob. + """TensorFlow CommonJob. Configures ValidationJsonGenerator, IntimeModelSelector, TBAnalyticsReceiver, ConvertToFedEvent. @@ -69,29 +69,15 @@ def __init__( name=name, min_clients=min_clients, mandatory_clients=mandatory_clients, + key_metric=key_metric, + validation_json_generator=validation_json_generator, + intime_model_selector=intime_model_selector, + convert_to_fed_event=convert_to_fed_event, ) self.initial_model = initial_model self.comp_ids = {} - if validation_json_generator: - validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator) - else: - validation_json_generator = ValidationJsonGenerator() - self.to_server(id="json_generator", obj=validation_json_generator) - - if intime_model_selector: - validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector) - self.to_server(id="model_selector", obj=intime_model_selector) - elif key_metric: - self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric)) - - if convert_to_fed_event: - validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent) - else: - convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE]) - 
self.convert_to_fed_event = convert_to_fed_event - if analytics_receiver: validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver) else: @@ -104,6 +90,3 @@ def __init__( if initial_model: self.comp_ids["persistor_id"] = self.to_server(TFModel(model=initial_model, persistor=model_persistor)) - - def set_up_client(self, target: str): - self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target) diff --git a/nvflare/app_opt/pt/job_config/base_fed_job.py b/nvflare/job_config/common_job.py similarity index 82% rename from nvflare/app_opt/pt/job_config/base_fed_job.py rename to nvflare/job_config/common_job.py index 0225064967..c0d6d0b836 100644 --- a/nvflare/app_opt/pt/job_config/base_fed_job.py +++ b/nvflare/job_config/common_job.py @@ -14,8 +14,6 @@ from typing import List, Optional -from torch import nn as nn - from nvflare.app_common.abstract.model_locator import ModelLocator from nvflare.app_common.abstract.model_persistor import ModelPersistor from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE @@ -23,15 +21,12 @@ from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector from nvflare.app_common.widgets.streaming import AnalyticsReceiver from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator -from nvflare.app_opt.pt.job_config.model import PTModel -from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver from nvflare.job_config.api import FedJob, validate_object_for_job -class BaseFedJob(FedJob): +class CommonJob(FedJob): def __init__( self, - initial_model: nn.Module = None, name: str = "fed_job", min_clients: int = 1, mandatory_clients: Optional[List[str]] = None, @@ -43,14 +38,15 @@ def __init__( model_persistor: Optional[ModelPersistor] = None, model_locator: Optional[ModelLocator] = None, ): - """PyTorch BaseFedJob. + """CommonJob. + + By default configures ValidationJsonGenerator, IntimeModelSelector, ConvertToFedEvent. 
- Configures ValidationJsonGenerator, IntimeModelSelector, AnalyticsReceiver, ConvertToFedEvent. + If provided, configures AnalyticsReceiver, ModelPersistor, ModelLocator. User must add controllers and executors. Args: - initial_model (nn.Module): initial PyTorch Model. Defaults to None. name (name, optional): name of the job. Defaults to "fed_job". min_clients (int, optional): the minimum number of clients for the job. Defaults to 1. mandatory_clients (List[str], optional): mandatory clients to run the job. Default None. @@ -64,8 +60,7 @@ def __init__( convert_to_fed_event: (ConvertToFedEvent, optional): A component to covert certain events to fed events. if not provided, a ConvertToFedEvent object will be created. analytics_receiver (AnlyticsReceiver, optional): Receive analytics. - If not provided, a TBAnalyticsReceiver will be configured. - model_persistor (optional, ModelPersistor): how to persistor the model. + model_persistor (optional, ModelPersistor): how to persist the model. model_locator (optional, ModelLocator): how to locate the model. 
""" super().__init__( @@ -74,9 +69,6 @@ def __init__( mandatory_clients=mandatory_clients, ) - self.initial_model = initial_model - self.comp_ids = {} - if validation_json_generator: validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator) else: @@ -97,18 +89,15 @@ def __init__( if analytics_receiver: validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver) - else: - analytics_receiver = TBAnalyticsReceiver() + self.to_server(id="receiver", obj=analytics_receiver) - self.to_server( - id="receiver", - obj=analytics_receiver, - ) + if model_persistor: + validate_object_for_job("persistor", model_persistor, ModelPersistor) + self.to_server(id="persistor", obj=model_persistor) - if initial_model: - self.comp_ids.update( - self.to_server(PTModel(model=initial_model, persistor=model_persistor, locator=model_locator)) - ) + if model_locator: + validate_object_for_job("locator", model_locator, ModelLocator) + self.to_server(id="locator", obj=model_locator) def set_up_client(self, target: str): self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target) diff --git a/web/src/components/code.astro b/web/src/components/code.astro index c7e8e69c23..c8c8d95ce5 100644 --- a/web/src/components/code.astro +++ b/web/src/components/code.astro @@ -206,7 +206,7 @@ const jobCode_pt = ` from cifar10_pt_fl import Net from nvflare.app_common.workflows.fedavg import FedAvg -from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob +from nvflare.app_opt.pt.job_config.pt_job import PTJob from nvflare.job_config.script_runner import ScriptRunner if __name__ == "__main__": @@ -214,8 +214,8 @@ if __name__ == "__main__": num_rounds = 2 train_script = "cifar10_pt_fl.py" - # Create BaseFedJob with initial model - job = BaseFedJob( + # Create PTJob with initial model + job = PTJob( name="cifar10_pt_fedavg", initial_model=Net(), ) @@ -425,7 +425,7 @@ const jobCode_lt = ` from cifar10_lightning_fl import 
LitNet from nvflare.app_common.workflows.fedavg import FedAvg -from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob +from nvflare.app_opt.pt.job_config.pt_job import PTJob from nvflare.job_config.script_runner import ScriptRunner if __name__ == "__main__": @@ -433,8 +433,8 @@ if __name__ == "__main__": num_rounds = 2 train_script = "cifar10_lightning_fl.py" - # Create BaseFedJob with initial model - job = BaseFedJob( + # Create PTJob with initial model + job = PTJob( name="cifar10_lightning_fedavg", initial_model=LitNet(), ) @@ -587,7 +587,7 @@ const jobCode_tf = ` from cifar10_tf_fl import TFNet from nvflare.app_common.workflows.fedavg import FedAvg -from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob +from nvflare.app_opt.tf.job_config.tf_job import TFJob from nvflare.job_config.script_runner import FrameworkType, ScriptRunner if __name__ == "__main__": @@ -595,8 +595,8 @@ if __name__ == "__main__": num_rounds = 2 train_script = "cifar10_tf_fl.py" - # Create BaseFedJob with initial model - job = BaseFedJob( + # Create TFJob with initial model + job = TFJob( name="cifar10_tf_fedavg", initial_model=TFNet(input_shape=(None, 32, 32, 3)), ) @@ -665,7 +665,7 @@ const frameworks = [ framework: "pytorch", title: "Job Code (fedavg_cifar10_pt_job.py)", description: - "Lastly we construct the job with our 'cifar10_pt_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", + "Lastly we construct the job with our 'cifar10_pt_fl.py' client script and 'FedAvg' server controller. The PTJob automatically configures components for model persistence, model selection, and TensorBoard streaming. 
We then run the job with the FL simulator.", code: jobCode_pt, }, { @@ -719,7 +719,7 @@ const frameworks = [ framework: "lightning", title: "Job Code (fedavg_cifar10_lightning_job.py)", description: - "Lastly we construct the job with our 'cifar10_lightning_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", + "Lastly we construct the job with our 'cifar10_lightning_fl.py' client script and 'FedAvg' server controller. The PTJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", code: jobCode_lt, }, { @@ -773,7 +773,7 @@ const frameworks = [ framework: "tensorflow", title: "Job Code (fedavg_cifar10_tf_job.py)", description: - "Lastly we construct the job with our 'cifar10_tf_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", + "Lastly we construct the job with our 'cifar10_tf_fl.py' client script and 'FedAvg' server controller. The TFJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.", code: jobCode_tf, }, {