Federated learning is a machine learning approach where many clients collaboratively train a model under the orchestration of a central server while keeping the training data decentralized. This helps address data privacy concerns and enables lower-latency inference. However, the communication cost can be prohibitively high, especially for clients with slow or expensive communication links.
This repository implements communication-efficient federated learning algorithms in PyTorch, with support for Intel Gaudi HPUs.
The suite includes the following algorithms:
- FedAvg: Baseline Federated Averaging (see the aggregation sketch after this list).
- Subsampling: Federated Averaging with Subsampling.
- Quantization: Federated Averaging with Quantization.
- EvoFed: Evolutionary Federated Learning.
- MAPA: Model-Agnostic Projection Adaptation.
- FA-LoRA: Frozen-A Low-Rank Adaptation.
- MA-LoRA: Model-Agnostic Low-Rank Adaptation.
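For orientation, here is a minimal sketch of the weighted parameter averaging at the core of FedAvg. It is an illustration only, not the repository's implementation; the helper name `fedavg_aggregate` is hypothetical.

```python
import torch

def fedavg_aggregate(client_states, client_sizes):
    """Average client state_dicts, weighted by local dataset size (illustrative)."""
    total = float(sum(client_sizes))
    avg = {name: torch.zeros_like(param, dtype=torch.float32)
           for name, param in client_states[0].items()}
    for state, n in zip(client_states, client_sizes):
        for name, param in state.items():
            avg[name] += (n / total) * param.float()
    return avg
```

The communication-efficient variants above differ mainly in what clients transmit in this loop: subsampled or quantized updates, or low-rank (LoRA-style) factors, rather than full model weights.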
The algorithms are benchmarked on the following datasets:
- MNIST: A dataset of handwritten digits (0-9) commonly used for image classification tasks.
- Fashion-MNIST (FMNIST): A dataset of Zalando's article images used for image classification tasks.
- CIFAR-10/100: Datasets containing 10 or 100 classes of 32x32 color images, widely used for image recognition tasks.
- Shakespeare Dataset: A character-level dataset built from Shakespeare’s plays, used for next-character prediction tasks.
- Sentiment140: A dataset for sentiment analysis containing 1.6 million tweets labeled as positive, negative, or neutral.
- Reddit Dataset: A dataset of user comments from Reddit structured for federated learning tasks like next-word prediction or topic modeling.
The full roadmap is available in the project board.
Set up the environment using conda:
```bash
conda create -n fl python=3.12
conda activate fl
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -r requirements.txt
```
Run experiments:
```bash
# Run the training script on GPU
python train.py

# Run the training script on CPU
python train.py --gpu -1
```
To install PyTorch on Gaudi-v2, use Intel Habana's setup:
- Follow the Installation Guide.
- Supported PyTorch versions are listed in the Support Matrix.
Example installation for PyTorch 2.4.0 on Ubuntu 22.04:
```bash
export OS="ubuntu22.04"
export PTV="2.4.0"

docker pull vault.habana.ai/gaudi-docker/1.18.0/${OS}/habanalabs/pytorch-installer-${PTV}:latest

docker run --name torch${PTV} -it --runtime=habana \
  -e HABANA_VISIBLE_DEVICES=all \
  -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
  --cap-add=sys_nice --net=host --ipc=host \
  -v /home/irteamsu/data:/data \
  vault.habana.ai/gaudi-docker/1.18.0/$OS/habanalabs/pytorch-installer-${PTV}:latest
```
For more details, see Getting Started with PyTorch and Gaudi.
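As a quick sanity check inside the container, a minimal sketch using Habana's documented PyTorch bridge (the tiny model is illustrative):

```python
import torch
import habana_frameworks.torch.core as htcore  # Habana PyTorch bridge

device = torch.device("hpu")
model = torch.nn.Linear(10, 2).to(device)
x = torch.randn(4, 10).to(device)
loss = model(x).sum()
loss.backward()
htcore.mark_step()  # in lazy mode, triggers execution of the accumulated graph
print(loss.item())
```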
```bash
# Set up environment variables
export WANDB_API_KEY=<your_wandb_api_key>
export WANDB_MODE="online"
export PT_HPU_LAZY_MODE=0  # Eager mode

# Build the Docker image
./docker_build_run.sh <image_name> Dockerfile . ./data
```
Run the code:
```bash
# Run training on Gaudi
./docker_run.sh python train.py --gaudi

# Run training on Gaudi in eager mode
PT_HPU_LAZY_MODE=0 ./docker_run.sh python train.py --gaudi --eager

# Run training on CPU
./docker_run.sh python train.py --gpu -1
```
```text
project/
├── Dockerfile            # Defines the container environment and dependencies.
├── setup_env.sh          # Custom environment setup script for the container; executed in the Dockerfile.
├── requirements.txt      # Python dependencies for the project; installed in the Dockerfile.
├── docker_build_run.sh   # Builds and optionally runs a Docker container.
└── docker_run.sh         # Runs a pre-built Docker container with mounts and script execution.
```
Usage of `docker_build_run.sh`:
```bash
./docker_build_run.sh <image_name> <dockerfile_path> [source_folder] [dataset_folder] [script_to_run] [script_args...]
```
Example:
```bash
./docker_build_run.sh my_image ./Dockerfile ./src ./data ./train.py arg1 arg2
```
Usage of `docker_run.sh`:
```bash
./docker_run.sh <image_name> [source_folder] [dataset_folder] [script_to_run] [script_args...]
```
Example:
```bash
./docker_run.sh my_image ./src ./data ./train.py arg1 arg2
./docker_run.sh my_image ./src ./data
```
Run training with specific arguments:
```bash
python train.py --dataset mnist --iid --model cnn --epochs 50 --gaudi --all_clients
```
See the arguments in config.py.
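As a rough, hypothetical illustration of the flags used in the commands above (names match the commands; defaults and help strings are assumptions, the authoritative definitions are in config.py):

```python
import argparse

def args_parser():
    # Hypothetical sketch of config.py's parser; defaults are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="mnist", help="dataset name")
    parser.add_argument("--model", type=str, default="cnn", help="model architecture")
    parser.add_argument("--epochs", type=int, default=50, help="number of training rounds")
    parser.add_argument("--iid", action="store_true", help="use an IID client data split")
    parser.add_argument("--all_clients", action="store_true", help="aggregate over all clients")
    parser.add_argument("--gaudi", action="store_true", help="run on Gaudi HPU")
    parser.add_argument("--eager", action="store_true", help="use eager mode on Gaudi")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index; -1 runs on CPU")
    return parser.parse_args()
```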
- `WANDB_API_KEY`: API key for Weights & Biases logging.
- `WANDB_MODE`: Logging mode (`online` or `offline`).
- `PT_HPU_LAZY_MODE`: Controls Habana lazy execution mode (default: `1`).
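These variables are read at startup; for example, Weights & Biases picks up `WANDB_API_KEY` and `WANDB_MODE` from the environment, so no key has to be hard-coded. A minimal sketch (the project name is an assumption):

```python
import wandb

# WANDB_API_KEY and WANDB_MODE are read from the environment by wandb.
run = wandb.init(project="fl")  # hypothetical project name
run.log({"round": 1, "test_acc": 0.0})
run.finish()
```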
TBD
TBD
This repository draws inspiration from: