Add fedrag example with embedding training (#2915)
* Add fedrag example with embedding training

* fix link and format

* fix link and format

* fix link and format

* keep rag folder structure, remove the retrieval placeholder

* keep rag folder structure, remove the retrieval placeholder

* remove template job preparation

* remove template job preparation

* update JobAPI script

* update eval bash

* update eval bash and result

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
ZiyueXu77 and YuanTingHsieh authored Sep 26, 2024
1 parent 5a2668d commit 926c099
Showing 13 changed files with 699 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/advanced/rag/README.md
@@ -0,0 +1,3 @@
# Federated Retrieval-Augmented Generation (RAG)
The examples in this directory illustrate how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for RAG tasks, including:
- federated embedding model training
79 changes: 79 additions & 0 deletions examples/advanced/rag/embedding/README.md
@@ -0,0 +1,79 @@
# Embedding Model Tuning via SentenceTransformers Trainer
This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for embedding tuning tasks, a critical component of Retrieval-Augmented Generation (RAG).

It illustrates how to adapt a local training script that uses the [SentenceTransformers](https://github.com/UKPLab/sentence-transformers) trainer to NVFlare.

## Introduction
[SentenceTransformers](https://sbert.net/) is a widely used framework for computing dense vector representations of text.
Its models are based on the transformer architecture and achieve state-of-the-art performance on a variety of tasks.

One major application is embedding text in a vector space for later clustering and/or retrieval using similarity metrics.

This example illustrates a supervised fine-tuning (SFT) scheme for an embedding model with various training datasets.

## Setup
Please make sure you set up a virtual environment following the [example root readme](../../../README.md).
Install the additional requirements (if you already have a specific version of nvflare installed in your environment, you may want to remove nvflare from the requirements to avoid reinstalling it):
```
python3 -m pip install -r requirements.txt
```
Models and data are loaded directly from Hugging Face, so there is no need to download them manually.

## Centralized Training
### Single-session training
Centralized training, which serves as the baseline for comparison with the FL results, is run with the following command:
```
bash train_single_session.sh
```

### Adaptation Step 1: iterative training
To adapt the centralized training script to a federated application under the `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
For this purpose, we provide `utils/train_iterative.py` as an example, which is a modified version of `utils/train_single_session.py`.

In the iterative training script, the `trainer.train()` call is replaced by a `for` loop, and the training epochs are split into six rounds of `unit_train_epochs = 0.25` epochs each, for a total of `0.25 * 6 = 1.5` epochs, the same as in the single-session setting.

The first round is trained with `trainer.train()`. From the second round onward,
we call `trainer.train(resume_from_checkpoint=True)` with `args.num_train_epochs` incremented by `unit_train_epochs` to continue training from the last checkpoint.
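
The control flow described above can be sketched as follows. This is a minimal illustration, not the content of `utils/train_iterative.py`: `run_iterative` and `StubTrainer` are hypothetical names, and the stub only records calls so that the snippet runs without downloading a model. In practice a real `SentenceTransformerTrainer` would take the stub's place.

```python
from types import SimpleNamespace


class StubTrainer:
    """Hypothetical stand-in for a trainer; it records calls instead of training."""

    def __init__(self, num_train_epochs):
        self.args = SimpleNamespace(num_train_epochs=num_train_epochs)
        self.calls = []

    def train(self, resume_from_checkpoint=None):
        # Record the resume flag and the epoch target seen by this call.
        self.calls.append((resume_from_checkpoint, self.args.num_train_epochs))


def run_iterative(trainer, unit_train_epochs=0.25, num_rounds=6):
    """Split one long training run into `num_rounds` shorter calls."""
    for round_idx in range(num_rounds):
        if round_idx == 0:
            # First round: plain call, trains up to `unit_train_epochs`.
            trainer.train()
        else:
            # Later rounds: raise the epoch target, then resume from the last checkpoint.
            trainer.args.num_train_epochs += unit_train_epochs
            trainer.train(resume_from_checkpoint=True)
    return trainer.args.num_train_epochs


trainer = StubTrainer(num_train_epochs=0.25)
total_epochs = run_iterative(trainer)
print(total_epochs)  # 1.5
```

Raising the epoch target before each resumed call matters because a resumed run that has already reached its target would have nothing left to train.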

To run iterative training, we use the following command:
```
bash train_iterative.sh
```

The training loss curves are shown below; the single-session and iterative scripts align with each other.

![iter_single](./figs/iter_single.png)

### Adaptation Step 2: federated with NVFlare
Once the iterative training script is ready with "starting model" loading capability, it can be easily adapted to an NVFlare trainer by using the [Client API](../../../hello-world/ml-to-fl/pt/README.md).

The major code modifications are receiving the global model, setting it as the starting point for each round's training, and returning the trained model after each local training round.
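
One detail of this exchange is key naming. In `src/train_fl.py` the `model.` prefix is removed from the received parameter names before they are loaded, and added back before the trained weights are returned, because the server-side model definition wraps the SentenceTransformer as `self.model`. A minimal sketch of that round trip follows; the helper names `strip_prefix` and `add_prefix` are hypothetical, and plain lists stand in for tensors.

```python
PREFIX = "model."


def strip_prefix(params, prefix=PREFIX):
    """Global -> local: drop the wrapper prefix from each parameter name."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in params.items()
    }


def add_prefix(params, prefix=PREFIX):
    """Local -> global: restore the wrapper prefix on each parameter name."""
    return {prefix + key: value for key, value in params.items()}


# Toy parameters; real values would be tensors from a state dict.
global_params = {"model.0.weight": [1.0, 2.0], "model.0.bias": [0.5]}
local_params = strip_prefix(global_params)
print(sorted(local_params))  # ['0.bias', '0.weight']
```

Unlike `str.replace`, checking `startswith` only touches a leading prefix, so a parameter name that happens to contain `model.` elsewhere is left unchanged.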

## Federated Training
We can use the Python JobAPI to create and run the federated training job.
```
python3 train_fed.py
```
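
The job uses the `FedAvg` workflow, and each client reports `NUM_STEPS_CURRENT_ROUND` in its result metadata. The sketch below illustrates weighted averaging under the assumption that the server weights each site's parameters by that reported value; it is not NVFlare's implementation, `weighted_average` is a hypothetical helper, and scalars stand in for tensors.

```python
def weighted_average(site_params, site_weights):
    """Average each parameter across sites, weighted by the site's step count."""
    total = sum(site_weights)
    return {
        key: sum(params[key] * weight for params, weight in zip(site_params, site_weights)) / total
        for key in site_params[0]
    }


# Toy values: three sites, one scalar parameter each.
site_params = [{"w": 1.0}, {"w": 2.0}, {"w": 4.0}]
site_weights = [4000.0, 4000.0, 8000.0]
averaged = weighted_average(site_params, site_weights)
print(averaged)  # {'w': 2.75}
```

In this example every site trains on 16000 rows for 0.25 epochs per round, so each reports `16000 * 0.25 = 4000` and equal weights would reduce the result to a plain mean.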

## Results
Below are the evaluation results on two test datasets: [stsb](https://huggingface.co/datasets/sentence-transformers/stsb) with embedding similarity evaluation, and [NLI](https://huggingface.co/datasets/sentence-transformers/all-nli) with triplet accuracy evaluation. The candidate models are:
- NLI: single-site training using NLI data
- Squad: single-site training using Squad data
- Quora: single-site training using Quora data
- All: centralized training using the combined data (see `utils/train_single_session.py`)
- Federated: three-site federated learning, where each site holds its own data (NLI, Squad, or Quora)

The evaluation is run with the following command, and two similarity metrics are reported for each of the two test datasets:
```commandline
bash eval_all.sh
```

TrainData | STSB_pearson_cos | STSB_spearman_euc | NLI_cos_acc | NLI_euc_acc
--- |------------------|-------------------|-------------| ---
NLI | 0.7586 | 0.7895 | 0.8033 | 0.8045
Squad | 0.8206 | 0.8154 | 0.8051 | 0.8042
Quora | 0.8161 | 0.8121 | 0.7891 | 0.7854
All | 0.8497 | 0.8523 | 0.8426 | 0.8384
Federated | 0.8443 | 0.8367 | 0.8261 | 0.8249

As shown, the federated training results are better than those of every individual site and close to the centralized training results, demonstrating the effectiveness of NVFlare in embedding model tuning tasks.
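
For reference, the idea behind the triplet accuracy columns can be sketched as follows, assuming `NLI_cos_acc` is the fraction of (anchor, positive, negative) triplets in which the anchor is more similar to the positive than to the negative under cosine similarity. The helper names and the two-dimensional toy embeddings are hypothetical; `utils/eval_model.py` is not reproduced here.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def triplet_accuracy(triplets):
    """Fraction of triplets where the anchor is closer to the positive than the negative."""
    correct = sum(
        1
        for anchor, positive, negative in triplets
        if cosine_similarity(anchor, positive) > cosine_similarity(anchor, negative)
    )
    return correct / len(triplets)


# Toy embeddings: the first triplet is ranked correctly, the second is not.
triplets = [
    ([1.0, 0.0], [0.9, 0.1], [0.0, 1.0]),
    ([0.0, 1.0], [1.0, 0.0], [0.1, 0.9]),
]
print(triplet_accuracy(triplets))  # 0.5
```

The Euclidean variant would follow the same pattern with a distance in place of the similarity and the comparison reversed.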
8 changes: 8 additions & 0 deletions examples/advanced/rag/embedding/eval_all.sh
@@ -0,0 +1,8 @@
for dataset_name in nli squad quora all
do
    echo "Evaluation on model ${dataset_name}"
    python utils/eval_model.py --model_path /tmp/embed/cen/models_single/mpnet-base-${dataset_name}/final
done

echo "Evaluation on model federated"
python utils/eval_model.py --model_path /tmp/embed/nvflare/workspace_api/site-1/models/mpnet-base-nli/global
7 changes: 7 additions & 0 deletions examples/advanced/rag/embedding/requirements.txt
@@ -0,0 +1,7 @@
nvflare~=2.5.0
torch
datasets
scikit-learn
tensorboard
transformers
sentence-transformers
26 changes: 26 additions & 0 deletions examples/advanced/rag/embedding/src/st_model.py
@@ -0,0 +1,26 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from sentence_transformers import SentenceTransformer


class SenTransModel(torch.nn.Module):
    def __init__(self, model_name):
        super(SenTransModel, self).__init__()
        self.model = SentenceTransformer(model_name)

    def forward(self, input_id):
        output = self.model(input_ids=input_id, return_dict=False)
        return output
158 changes: 158 additions & 0 deletions examples/advanced/rag/embedding/src/train_fl.py
@@ -0,0 +1,158 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from transformers import trainer_utils

import nvflare.client as flare


def main():
    # argparse
    parser = argparse.ArgumentParser(description="Train a model on a dataset")
    parser.add_argument(
        "--model_name",
        type=str,
        default="microsoft/mpnet-base",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="nli",
    )
    args = parser.parse_args()
    model_name = args.model_name
    dataset_name = args.dataset_name

    # Load a model to finetune with
    model = SentenceTransformer(model_name)

    # Load training datasets
    if dataset_name == "nli":
        # (anchor, positive, negative)
        dataset_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:16000]")
        dataset_val = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
    elif dataset_name == "squad":
        # (question, answer)
        dataset_train = load_dataset("sentence-transformers/squad", split="train[:16000]")
        dataset_val = load_dataset("sentence-transformers/squad", split="train[16000:18000]")
    elif dataset_name == "quora":
        # (anchor, positive)
        dataset_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:16000]")
        dataset_val = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[16000:18000]")
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    # Load loss function
    loss = MultipleNegativesRankingLoss(model)

    base_model_name = model_name.split("/")[-1]
    output_dir = f"./models/{base_model_name}-{dataset_name}"
    unit_train_epochs = 0.25
    # Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=output_dir,
        # Optional training parameters:
        num_train_epochs=unit_train_epochs,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=1e-6,
        lr_scheduler_type="constant",
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=1,
        # logging parameters:
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=50,
        report_to="tensorboard",
    )

    # Define trainer
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        loss=loss,
    )

    # initializes NVFlare client API
    flare.init()

    while flare.is_running():
        # receives FLModel from NVFlare
        input_model = flare.receive()
        curr_round = input_model.current_round
        print(f"current_round={curr_round}")

        # Update the key name received from global model if using model def file
        global_model = copy.deepcopy(input_model.params)
        for key in list(global_model.keys()):
            global_model[key.replace("model.", "", 1)] = global_model.pop(key)

        # evaluate on received global model
        trainer.model.load_state_dict(global_model)
        eval_loss_dict = trainer.evaluate()
        eval_loss = float(eval_loss_dict["eval_loss"])
        print(f"Evaluation loss: {eval_loss}")
        # Save the global model
        model.save_pretrained(f"{output_dir}/global")

        # Train the model
        if curr_round == 0:
            # First round: start from scratch
            trainer.train()
        else:
            # Subsequent rounds: start from the previous model
            # Since we perform iterative training by using "resume" functionality
            # we need to replace the resume weights with global weights every round
            resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir)
            # update local record with global model weights
            trainer.model.save_pretrained(resume_from_checkpoint_folder)
            # increment the number of training epochs so that the trainer will continue training
            args.num_train_epochs += unit_train_epochs
            # continue training
            trainer.train(resume_from_checkpoint=True)

        # update the key name sent to global model
        out_param = trainer.model.state_dict()
        for key in list(out_param.keys()):
            out_param["model." + key] = out_param.pop(key).cpu()
        num_steps = trainer.train_dataset.num_rows * unit_train_epochs

        # construct trained FL model
        output_model = flare.FLModel(
            params=out_param,
            metrics={"eval_loss": eval_loss},
            meta={"NUM_STEPS_CURRENT_ROUND": num_steps},
        )
        # send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
56 changes: 56 additions & 0 deletions examples/advanced/rag/embedding/train_fed.py
@@ -0,0 +1,56 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from nvflare import FedJob
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 3
    num_rounds = 7
    train_script = "src/train_fl.py"

    # Create the FedJob
    job = FedJob(name="embed_fedavg", min_clients=3, mandatory_clients=["site-1", "site-2", "site-3"])

    # Define the FedAvg controller workflow and send to server
    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

    # Define the model persistor and send to server
    # First send the model to the server
    job.to("src/st_model.py", "server")
    # Then send the model persistor to the server
    model_args = {"path": "src.st_model.SenTransModel", "args": {"model_name": "microsoft/mpnet-base"}}
    job.to(PTFileModelPersistor(model=model_args), "server", id="persistor")

    # Add model selection widget and send to server
    job.to(IntimeModelSelector(key_metric="eval_loss", negate_key_metric=True), "server", id="model_selector")

    # Send ScriptRunner to all clients
    runner = ScriptRunner(script=train_script, script_args="--dataset_name nli")
    job.to(runner, "site-1")
    runner = ScriptRunner(script=train_script, script_args="--dataset_name squad")
    job.to(runner, "site-2")
    runner = ScriptRunner(script=train_script, script_args="--dataset_name quora")
    job.to(runner, "site-3")

    job.export_job("/tmp/embed/nvflare/job_api")
    job.simulator_run("/tmp/embed/nvflare/workspace_api")
5 changes: 5 additions & 0 deletions examples/advanced/rag/embedding/train_iterative.sh
@@ -0,0 +1,5 @@
for dataset_name in nli squad quora
do
    echo "Training on ${dataset_name}"
    python utils/train_iterative.py --dataset_name ${dataset_name}
done
5 changes: 5 additions & 0 deletions examples/advanced/rag/embedding/train_single_session.sh
@@ -0,0 +1,5 @@
for dataset_name in nli squad quora all
do
    echo "Training on ${dataset_name}"
    python utils/train_single_session.py --dataset_name ${dataset_name}
done