Skip to content

Commit

Permalink
Merge pull request #38 from Iacob-Alexandru-Andrei/main
Browse files Browse the repository at this point in the history
Allow custom Ray actors
  • Loading branch information
relogu authored Apr 6, 2024
2 parents fca9b96 + e964563 commit 1a44417
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 61 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ wandb/*
run_scripts/launch_template.sh
run_scripts/launch_mnist.sh
pollen_worker
*.pt
*.npz
*.bin

# Slurm logs
slurm-*.out
Expand Down
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ exclude: |
multirun/|
dev/|
data/|
run_scripts/|
.pre-commit-config.yaml
)$
Expand All @@ -22,6 +21,14 @@ repos:
rev: 1.16.0
hooks:
- id: yamlfix
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.10.0.1
hooks:
- id: shellcheck
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.8.0-1
hooks:
- id: shfmt
- repo: https://github.com/psf/black
rev: 24.1.1
hooks:
Expand Down
43 changes: 36 additions & 7 deletions project/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import random
from pathlib import Path
from typing import Any
from typing import Any, cast

import flwr as fl
from flwr.common import NDArrays
from flwr.server import History
from omegaconf import DictConfig
from pydantic import BaseModel
from torch import nn
Expand All @@ -22,14 +23,23 @@
CID,
ClientDataloaderGen,
ClientGen,
ConfigStructure,
DataStructure,
EvalRes,
FitRes,
GetClientGen,
ClientTypeGen,
NetGen,
TestFunc,
TrainFunc,
ServerRNG,
TrainStructure,
)
from project.utils.utils import obtain_device
from flwr.simulation.ray_transport.ray_actor import (
DefaultActor,
VirtualClientEngineActor,
)
from flwr.common import Parameters


class ClientConfig(BaseModel):
Expand Down Expand Up @@ -366,7 +376,22 @@ def client_generator(cid: CID) -> fl.client.NumPyClient:
return client_generator


def dispatch_client_gen(cfg: DictConfig, **kwargs: Any) -> GetClientGen | None:
def dispatch_client_gen(
cfg: DictConfig,
saved_state: tuple[Parameters | None, ServerRNG, History],
working_dir: Path,
data_structure: DataStructure,
train_structure: TrainStructure,
config_structure: ConfigStructure,
**kwargs: Any,
) -> (
tuple[
ClientTypeGen,
type[VirtualClientEngineActor],
dict[str, Any] | None,
]
| None
):
"""Dispatch the get_client_generator function based on the hydra config.
Parameters
Expand All @@ -377,16 +402,20 @@ def dispatch_client_gen(cfg: DictConfig, **kwargs: Any) -> GetClientGen | None:
Returns
-------
Optional[GetClientGen]
The get_client_generator function.
Return None if you cannot match the cfg.
tuple[GetClientGen, type[VirtualClientEngineActor], dict[str, Any]] | None
The get_client_generator function and the actor type.
Together with actor kwargs.
"""
client_gen: str | None = cfg.get("task", None).get("client_gen", None)

if client_gen is None:
return None

if client_gen.upper() == "DEFAULT":
return get_client_generator
return (
get_client_generator,
cast(type[VirtualClientEngineActor], DefaultActor),
None,
)

return None
22 changes: 19 additions & 3 deletions project/dispatch/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@
dispatch_train as dispatch_mnist_train,
)
from project.types.common import (
ClientAndActorStructure,
ConfigStructure,
DataStructure,
GetClientGen,
ClientTypeGen,
TrainStructure,
)
from flwr.simulation.ray_transport.ray_actor import (
VirtualClientEngineActor,
)


def dispatch_train(cfg: DictConfig, **kwargs: Any) -> TrainStructure:
Expand Down Expand Up @@ -154,7 +158,9 @@ def dispatch_config(cfg: DictConfig, **kwargs: Any) -> ConfigStructure:
)


def dispatch_get_client_generator(cfg: DictConfig, **kwargs: Any) -> GetClientGen:
def dispatch_get_client_generator(
cfg: DictConfig, **kwargs: Any
) -> ClientAndActorStructure:
"""Dispatch the get_client_generator function based on the hydra config.
Functionality should be added to the dispatch.py
Expand All @@ -177,7 +183,17 @@ def dispatch_get_client_generator(cfg: DictConfig, **kwargs: Any) -> GetClientGe
The get_client_generators function.
"""
# Create the list of task dispatches to try
task_get_client_generators: list[Callable[..., GetClientGen | None]] = [
task_get_client_generators: list[
Callable[
...,
tuple[
ClientTypeGen,
type[VirtualClientEngineActor],
dict[str, Any] | None,
]
| None,
]
] = [
dispatch_default_client_gen,
]

Expand Down
49 changes: 42 additions & 7 deletions project/fed/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ClientGen,
Ext,
Files,
InitialParameterGen,
IsolatedRNGState,
IsolatedRNG,
ServerRNG,
Expand Down Expand Up @@ -94,6 +95,39 @@ def generic_get_parameters(net: nn.Module) -> NDArrays:
return parameters


def generate_initial_params_from_net_generator(
net_gen: NetGen | None,
config: dict,
server_rng: IsolatedRNG,
hydra_config: DictConfig | None,
) -> Parameters | None:
"""Generate initial parameters from a network generator.
Parameters
----------
net_gen : NetGen
The network generator.
config : dict
The configuration.
server_rng_tuple : IsolatedRNG
The server RNG tuple.
hydra_config : DictConfig
The Hydra configuration.
Returns
-------
Parameter s
The initial parameters.
"""
return (
ndarrays_to_parameters(
generic_get_parameters(net_gen(config, server_rng, hydra_config)),
)
if net_gen is not None
else None
)


def load_parameters_from_file(path: Path) -> Parameters:
"""Load parameters from a binary file.
Expand Down Expand Up @@ -131,6 +165,7 @@ def load_parameters_from_file(path: Path) -> Parameters:

def get_state(
net_generator: NetGen | None,
initial_parameter_gen: InitialParameterGen | None,
config: dict,
load_parameters_from: Path | None,
load_rng_from: Path | None,
Expand Down Expand Up @@ -167,12 +202,10 @@ def get_state(

return (
(
ndarrays_to_parameters(
generic_get_parameters(
net_generator(config, server_rng_tuple[0], hydra_config)
),
initial_parameter_gen(
net_generator, config, server_rng_tuple[0], hydra_config
)
if net_generator is not None
if initial_parameter_gen is not None
else None
),
server_rng_tuple,
Expand All @@ -182,6 +215,7 @@ def get_state(
if server_round is None:
return get_state(
net_generator,
initial_parameter_gen,
config,
None,
load_rng_from=None,
Expand Down Expand Up @@ -220,8 +254,9 @@ def get_state(

return get_state(
net_generator,
config,
None,
initial_parameter_gen,
config=config,
load_parameters_from=None,
load_rng_from=None,
load_history_from=None,
seed=seed,
Expand Down
25 changes: 19 additions & 6 deletions project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,11 @@ def main(cfg: DictConfig) -> None:
# Change the cfg.task.model_and_data str to change functionality
(
net_generator,
initial_parameter_gen,
client_dataloader_gen,
fed_dataloader_gen,
init_working_dir,
) = dispatch_data(
) = data_structure = dispatch_data(
cfg,
)
# The folder starts either empty or only with restored files
Expand All @@ -168,7 +169,8 @@ def main(cfg: DictConfig) -> None:

saved_state = get_state(
net_generator,
cast(
initial_parameter_gen,
config=cast(
dict,
OmegaConf.to_container(
cfg.task.net_config_initial_parameters,
Expand Down Expand Up @@ -212,7 +214,7 @@ def main(cfg: DictConfig) -> None:
train_func,
test_func,
get_fed_eval_fn,
) = dispatch_train(cfg)
) = train_structure = dispatch_train(cfg)

# Obtain the on_fit config and on_eval config
# generation functions
Expand All @@ -221,9 +223,18 @@ def main(cfg: DictConfig) -> None:
(
on_fit_config_fn,
on_evaluate_config_fn,
) = dispatch_config(cfg)

get_client_generator = dispatch_get_client_generator(cfg)
) = config_structure = dispatch_config(cfg)

get_client_generator, actor_type, actor_kwargs = (
dispatch_get_client_generator(
cfg,
saved_state=saved_state,
working_dir=working_dir,
data_structure=data_structure,
train_structure=train_structure,
config_structure=config_structure,
)
)

# Build the evaluate function from the given components
# This is the function that is called on the server
Expand Down Expand Up @@ -372,6 +383,8 @@ def main(cfg: DictConfig) -> None:
if cfg.ray_address is not None
else {"include_dashboard": False}
),
actor_type=actor_type,
actor_kwargs=actor_kwargs,
)

# Sync the entire results dir to wandb if enabled
Expand Down
5 changes: 5 additions & 0 deletions project/task/default/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
)
from project.types.common import ConfigStructure, DataStructure, TrainStructure

from project.fed.utils.utils import (
generate_initial_params_from_net_generator as get_initial_parameters,
)


def dispatch_train(
cfg: DictConfig,
Expand Down Expand Up @@ -111,6 +115,7 @@ def dispatch_data(cfg: DictConfig, **kwargs: Any) -> DataStructure | None:
if client_model_and_data is not None and client_model_and_data.upper() == "DEFAULT":
ret_tuple: DataStructure = (
get_net,
get_initial_parameters,
get_client_dataloader,
get_fed_dataloader,
init_working_dir,
Expand Down
14 changes: 14 additions & 0 deletions project/task/default/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,20 @@ def fed_eval_fn(
return fed_eval_fn


# Get NONE fed eval fn
def get_none_fed_eval_fn(
net_generator: NetGen | None,
fed_dataloader_generator: FedDataloaderGen | None,
test_func: TestFunc,
_config: dict,
working_dir: Path,
rng_tuple: IsolatedRNG,
hydra_config: DictConfig | None,
) -> FedEvalFN | None:
"""Get an empty federated evaluation function."""
return None


def get_on_fit_config_fn(fit_config: dict) -> OnFitConfigFN:
"""Generate on_fit_config_fn based on a dict from the hydra config,.
Expand Down
5 changes: 5 additions & 0 deletions project/task/mnist_classification/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from typing import Any

from omegaconf import DictConfig
from project.fed.utils.utils import (
generate_initial_params_from_net_generator as get_initial_parameters,
)

from project.task.default.dispatch import (
dispatch_config as dispatch_default_config,
Expand Down Expand Up @@ -131,13 +134,15 @@ def dispatch_data(cfg: DictConfig, **kwargs: Any) -> DataStructure | None:
if client_model_and_data.upper() == "MNIST_CNN":
return (
get_net,
get_initial_parameters,
client_dataloader_gen,
fed_dataloader_gen,
init_working_dir_default,
)
elif client_model_and_data.upper() == "MNIST_LR":
return (
get_logistic_regression,
get_initial_parameters,
client_dataloader_gen,
fed_dataloader_gen,
init_working_dir_default,
Expand Down
Loading

0 comments on commit 1a44417

Please sign in to comment.