Skip to content

Commit

Permalink
add additional BaseFedJob layer
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed Dec 10, 2024
1 parent e594dcb commit 91a0bea
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 79 deletions.
2 changes: 1 addition & 1 deletion docs/programming_guide/fed_job_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ The FedAvgJob automatically adds the FedAvg controller, PTFileModelPersistor and
For more examples of job patterns, see:

* :class:`BaseFedJob<nvflare.app_opt.pt.job_config.base_fed_job.BaseFedJob>`
* :class:`BaseFedJob<nvflare.job_config.base_fed_job.BaseFedJob>`
* :class:`FedAvgJob<nvflare.app_opt.pt.job_config.fed_avg.FedAvgJob>` (pytorch)
* :class:`FedAvgJob<nvflare.app_opt.tf.job_config.fed_avg.FedAvgJob>` (tensorflow)
* :class:`CCWFJob<nvflare.app_common.ccwf.ccwf_job.CCWFJob>`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@
"#### 2. Define a FedJob\n",
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n",
"\n",
"Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n",
"The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
"Here we use a PyTorch `PTBaseFedJob`, where we can define the job name and the initial global model.\n",
"The `PTBaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
]
},
{
Expand All @@ -335,10 +335,10 @@
"from src.lit_net import LitNet\n",
"\n",
"from nvflare.app_common.workflows.fedavg import FedAvg\n",
"from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n",
"from nvflare.app_opt.pt.job_config.base_fed_job import PTBaseFedJob\n",
"from nvflare.job_config.script_runner import ScriptRunner\n",
"\n",
"job = BaseFedJob(\n",
"job = PTBaseFedJob(\n",
" name=\"cifar10_lightning_fedavg\",\n",
" initial_model=LitNet(),\n",
")"
Expand Down
8 changes: 4 additions & 4 deletions examples/getting_started/pt/nvflare_pt_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@
"#### 2. Define a FedJob\n",
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n",
"\n",
"Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n",
"The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
"Here we use a PyTorch `PTBaseFedJob`, where we can define the job name and the initial global model.\n",
"The `PTBaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
]
},
{
Expand All @@ -277,10 +277,10 @@
"from src.net import Net\n",
"\n",
"from nvflare.app_common.workflows.fedavg import FedAvg\n",
"from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n",
"from nvflare.app_opt.pt.job_config.base_fed_job import PTBaseFedJob\n",
"from nvflare.job_config.script_runner import ScriptRunner\n",
"\n",
"job = BaseFedJob(\n",
"job = PTBaseFedJob(\n",
" name=\"cifar10_pt_fedavg\",\n",
" initial_model=Net(),\n",
")"
Expand Down
8 changes: 4 additions & 4 deletions examples/getting_started/tf/nvflare_tf_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@
"#### 2. Define a FedJob\n",
"The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n",
"\n",
"Here we use a TensorFlow `BaseFedJob`, where we can define the job name and the initial global model.\n",
"The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
"Here we use a TensorFlow `TFBaseFedJob`, where we can define the job name and the initial global model.\n",
"The `TFBaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience."
]
},
{
Expand All @@ -267,10 +267,10 @@
"from src.tf_net import TFNet\n",
"\n",
"from nvflare.app_common.workflows.fedavg import FedAvg\n",
"from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob\n",
"from nvflare.app_opt.tf.job_config.base_fed_job import TFBaseFedJob\n",
"from nvflare.job_config.script_runner import FrameworkType, ScriptRunner\n",
"\n",
"job = BaseFedJob(\n",
"job = TFBaseFedJob(\n",
" name=\"cifar10_tf_fedavg\",\n",
" initial_model=TFNet(),\n",
")"
Expand Down
31 changes: 7 additions & 24 deletions nvflare/app_opt/pt/job_config/base_fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@

from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.pt.job_config.model import PTModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.job_config.api import FedJob, validate_object_for_job
from nvflare.job_config.api import validate_object_for_job
from nvflare.job_config.base_fed_job import BaseFedJob


class BaseFedJob(FedJob):
class PTBaseFedJob(BaseFedJob):
def __init__(
self,
initial_model: nn.Module = None,
Expand Down Expand Up @@ -72,29 +72,15 @@ def __init__(
name=name,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
key_metric=key_metric,
validation_json_generator=validation_json_generator,
intime_model_selector=intime_model_selector,
convert_to_fed_event=convert_to_fed_event,
)

self.initial_model = initial_model
self.comp_ids = {}

if validation_json_generator:
validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator)
else:
validation_json_generator = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=validation_json_generator)

if intime_model_selector:
validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector)
self.to_server(id="model_selector", obj=intime_model_selector)
elif key_metric:
self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric))

if convert_to_fed_event:
validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent)
else:
convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE])
self.convert_to_fed_event = convert_to_fed_event

if analytics_receiver:
validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver)
else:
Expand All @@ -109,6 +95,3 @@ def __init__(
self.comp_ids.update(
self.to_server(PTModel(model=initial_model, persistor=model_persistor, locator=model_locator))
)

def set_up_client(self, target: str):
self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)
4 changes: 2 additions & 2 deletions nvflare/app_opt/pt/job_config/fed_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import torch.nn as nn

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.app_opt.pt.job_config.base_fed_job import PTBaseFedJob


class FedAvgJob(BaseFedJob):
class FedAvgJob(PTBaseFedJob):
def __init__(
self,
initial_model: nn.Module,
Expand Down
4 changes: 2 additions & 2 deletions nvflare/app_opt/pt/job_config/fed_sag_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator
from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.app_opt.pt.job_config.base_fed_job import PTBaseFedJob
from nvflare.app_opt.tracking.mlflow.mlflow_receiver import MLflowReceiver
from nvflare.app_opt.tracking.mlflow.mlflow_writer import MLflowWriter


class SAGMLFlowJob(BaseFedJob):
class SAGMLFlowJob(PTBaseFedJob):
def __init__(
self,
initial_model: nn.Module,
Expand Down
31 changes: 7 additions & 24 deletions nvflare/app_opt/tf/job_config/base_fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
import tensorflow as tf

from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.tf.job_config.model import TFModel
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.job_config.api import FedJob, validate_object_for_job
from nvflare.job_config.api import validate_object_for_job
from nvflare.job_config.base_fed_job import BaseFedJob


class BaseFedJob(FedJob):
class TFBaseFedJob(BaseFedJob):
def __init__(
self,
initial_model: tf.keras.Model = None,
Expand Down Expand Up @@ -69,29 +69,15 @@ def __init__(
name=name,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
key_metric=key_metric,
validation_json_generator=validation_json_generator,
intime_model_selector=intime_model_selector,
convert_to_fed_event=convert_to_fed_event,
)

self.initial_model = initial_model
self.comp_ids = {}

if validation_json_generator:
validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator)
else:
validation_json_generator = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=validation_json_generator)

if intime_model_selector:
validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector)
self.to_server(id="model_selector", obj=intime_model_selector)
elif key_metric:
self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric))

if convert_to_fed_event:
validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent)
else:
convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE])
self.convert_to_fed_event = convert_to_fed_event

if analytics_receiver:
validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver)
else:
Expand All @@ -104,6 +90,3 @@ def __init__(

if initial_model:
self.comp_ids["persistor_id"] = self.to_server(TFModel(model=initial_model, persistor=model_persistor))

def set_up_client(self, target: str):
self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)
4 changes: 2 additions & 2 deletions nvflare/app_opt/tf/job_config/fed_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import tensorflow as tf

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob
from nvflare.app_opt.tf.job_config.base_fed_job import TFBaseFedJob


class FedAvgJob(BaseFedJob):
class FedAvgJob(TFBaseFedJob):
def __init__(
self,
initial_model: tf.keras.Model,
Expand Down
103 changes: 103 additions & 0 deletions nvflare/job_config/base_fed_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.job_config.api import FedJob, validate_object_for_job


class BaseFedJob(FedJob):
def __init__(
self,
name: str = "fed_job",
min_clients: int = 1,
mandatory_clients: Optional[List[str]] = None,
key_metric: str = "accuracy",
validation_json_generator: Optional[ValidationJsonGenerator] = None,
intime_model_selector: Optional[IntimeModelSelector] = None,
convert_to_fed_event: Optional[ConvertToFedEvent] = None,
analytics_receiver: Optional[AnalyticsReceiver] = None,
model_persistor: Optional[ModelPersistor] = None,
model_locator: Optional[ModelLocator] = None,
):
"""BaseFedJob.
By default configures ValidationJsonGenerator, IntimeModelSelector, ConvertToFedEvent.
If provided, configures AnalyticsReceiver, ModelPersistor, ModelLocator.
User must add controllers and executors.
Args:
name (name, optional): name of the job. Defaults to "fed_job".
min_clients (int, optional): the minimum number of clients for the job. Defaults to 1.
mandatory_clients (List[str], optional): mandatory clients to run the job. Default None.
key_metric (str, optional): Metric used to determine if the model is globally best.
if metrics are a `dict`, `key_metric` can select the metric used for global model selection.
Defaults to "accuracy".
validation_json_generator (ValidationJsonGenerator, optional): A component for generating validation results.
if not provided, a ValidationJsonGenerator will be configured.
intime_model_selector: (IntimeModelSelector, optional): A component for select the model.
if not provided, an IntimeModelSelector will be configured.
convert_to_fed_event: (ConvertToFedEvent, optional): A component to covert certain events to fed events.
if not provided, a ConvertToFedEvent object will be created.
analytics_receiver (AnlyticsReceiver, optional): Receive analytics.
model_persistor (optional, ModelPersistor): how to persist the model.
model_locator (optional, ModelLocator): how to locate the model.
"""
super().__init__(
name=name,
min_clients=min_clients,
mandatory_clients=mandatory_clients,
)

if validation_json_generator:
validate_object_for_job("validation_json_generator", validation_json_generator, ValidationJsonGenerator)
else:
validation_json_generator = ValidationJsonGenerator()
self.to_server(id="json_generator", obj=validation_json_generator)

if intime_model_selector:
validate_object_for_job("intime_model_selector", intime_model_selector, IntimeModelSelector)
self.to_server(id="model_selector", obj=intime_model_selector)
elif key_metric:
self.to_server(id="model_selector", obj=IntimeModelSelector(key_metric=key_metric))

if convert_to_fed_event:
validate_object_for_job("convert_to_fed_event", convert_to_fed_event, ConvertToFedEvent)
else:
convert_to_fed_event = ConvertToFedEvent(events_to_convert=[ANALYTIC_EVENT_TYPE])
self.convert_to_fed_event = convert_to_fed_event

if analytics_receiver:
validate_object_for_job("analytics_receiver", analytics_receiver, AnalyticsReceiver)
self.to_server(id="receiver", obj=analytics_receiver)

if model_persistor:
validate_object_for_job("persistor", model_persistor, ModelPersistor)
self.to_server(id="persistor", obj=model_persistor)

if model_locator:
validate_object_for_job("locator", model_locator, ModelLocator)
self.to_server(id="locator", obj=model_locator)

def set_up_client(self, target: str):
self.to(id="event_to_fed", obj=self.convert_to_fed_event, target=target)
Loading

0 comments on commit 91a0bea

Please sign in to comment.