Commit
Add fedrag example with embedding training (#2915)
* Add fedrag example with embedding training
* fix link and format
* fix link and format
* fix link and format
* keep rag folder structure, remove the retrieveal placeholder
* keep rag folder structure, remove the retrieveal placeholder
* remove template job preparation
* remove template job preparation
* update JobAPI script
* update eval bash
* update eval bash and result

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
1 parent 5a2668d, commit 926c099
Showing 13 changed files with 699 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Federated Retrieval-Augmented Generation (RAG) | ||
The examples in this directory illustrate how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for RAG tasks, including: | ||
- federated embedding model training |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Embedding Model Tuning via SentenceTransformers Trainer | ||
This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for embedding tuning tasks, a critical component of Retrieval-Augmented Generation (RAG). | ||
|
||
It illustrates how to adapt a local training script that uses the [SentenceTransformers](https://github.com/UKPLab/sentence-transformers) trainer to NVFlare. | |
|
||
## Introduction | ||
[SentenceTransformers](https://sbert.net/) is a widely used framework for computing dense vector representations of text. | |
Its models are based on transformer networks and achieve state-of-the-art performance in various tasks. | |
|
||
One major application is embedding text in a vector space for later clustering and/or retrieval using similarity metrics. | |
|
||
This example illustrates a supervised fine-tuning (SFT) scheme for an embedding model with various training datasets. | ||
|
||
## Setup | ||
Please make sure you have set up a virtual environment following the [example root readme](../../../README.md). | |
Then install the additional requirements (if you already have a specific version of nvflare installed in your environment, you may want to remove nvflare from the requirements to avoid reinstalling it): | |
``` | ||
python3 -m pip install -r requirements.txt | ||
``` | ||
Models and data are loaded directly from Hugging Face, so there is no need to download them manually. | |
|
||
## Centralized Training | ||
### Single-session training | ||
Centralized training, which serves as the baseline for comparison with the FL results, is run with the following command: | |
``` | ||
bash train_single_session.sh | ||
``` | ||
|
||
### Adaptation Step 1: iterative training | ||
To adapt the centralized training script to a federated application under the `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training. | |
For this purpose, we provide `utils/train_iterative.py` as an example, which is a modified version of `utils/train_single_session.py`. | |
|
||
In the iterative training script, the single `trainer.train()` call is replaced by a `for` loop that splits training into six rounds of `unit_train_epochs = 0.25` epochs each, for a total of `0.25 * 6 = 1.5` epochs, the same as in the single-session setting. | |
|
||
The first round is trained with `trainer.train()`; from the second round onward, | |
we call `trainer.train(resume_from_checkpoint=True)` with `args.num_train_epochs` incremented by `unit_train_epochs`, so that training continues from the last checkpoint. | |
|
||
To run iterative training, we use the following command: | ||
``` | ||
bash train_iterative.sh | ||
``` | ||
|
||
The training loss curves are shown below; the single-session and iterative scripts align with each other. | |
|
||
![iter_single](./figs/iter_single.png) | ||
|
||
### Adaptation Step 2: federated with NVFlare | ||
Once the iterative training script is ready and can load a "starting model", it can easily be adapted to an NVFlare trainer by using the [Client API](../../../hello-world/ml-to-fl/pt/README.md). | |
|
||
The major code modifications are receiving the global model, setting it as the starting point for each round's training, and returning the trained model after each local training round. | |
|
||
## Federated Training | ||
We can use the Python JobAPI to create and run the federated training job. | ||
``` | ||
python3 train_fed.py | ||
``` | ||
|
||
## Results | ||
Below are the evaluation results on two test datasets: [stsb](https://huggingface.co/datasets/sentence-transformers/stsb) with embedding similarity evaluation, and [NLI](https://huggingface.co/datasets/sentence-transformers/all-nli) with triplet accuracy evaluation. The candidate models are: | |
- NLI: single-site training using NLI data | |
- Squad: single-site training using Squad data | |
- Quora: single-site training using Quora data | |
- All: centralized training using the combined data (see `utils/train_single_session.py`) | |
- Federated: three-site federated learning, where each site holds its own data (NLI, Squad, or Quora) | |
|
||
We list two similarity metrics for each of the two test datasets. The evaluation is run with: | |
```commandline | ||
bash eval_all.sh | ||
``` | ||
|
||
TrainData | STSB_pearson_cos | STSB_spearman_euc | NLI_cos_acc | NLI_euc_acc | ||
--- |------------------|-------------------|-------------| --- | ||
NLI | 0.7586 | 0.7895 | 0.8033 | 0.8045 | ||
Squad | 0.8206 | 0.8154 | 0.8051 | 0.8042 | ||
Quora | 0.8161 | 0.8121 | 0.7891 | 0.7854 | ||
All | 0.8497 | 0.8523 | 0.8426 | 0.8384 | ||
Federated | 0.8443 | 0.8367 | 0.8261 | 0.8249 | ||
|
||
As shown, the federated training results are better than those of any individual site and come close to the centralized training results, demonstrating the effectiveness of NVFlare in embedding model tuning tasks. |
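
The comparison stated above can be checked directly from the table. The following is a minimal sketch that copies the scores from the results table and computes, per metric, how much the federated model gains over the best single-site model and how far it remains from the centralized model. The names `RESULTS` and `compare` are illustrative and not part of the example code.

```python
# Scores copied from the results table in the README.
METRICS = ["STSB_pearson_cos", "STSB_spearman_euc", "NLI_cos_acc", "NLI_euc_acc"]
RESULTS = {
    "NLI": [0.7586, 0.7895, 0.8033, 0.8045],
    "Squad": [0.8206, 0.8154, 0.8051, 0.8042],
    "Quora": [0.8161, 0.8121, 0.7891, 0.7854],
    "All": [0.8497, 0.8523, 0.8426, 0.8384],
    "Federated": [0.8443, 0.8367, 0.8261, 0.8249],
}
SINGLE_SITE = ["NLI", "Squad", "Quora"]


def compare(results=RESULTS):
    """Return, per metric, the federated gain over the best single site
    and the remaining gap to centralized training."""
    rows = []
    for i, metric in enumerate(METRICS):
        best_single = max(results[name][i] for name in SINGLE_SITE)
        fed = results["Federated"][i]
        cen = results["All"][i]
        rows.append(
            {
                "metric": metric,
                "gain_over_best_single": round(fed - best_single, 4),
                "gap_to_centralized": round(cen - fed, 4),
            }
        )
    return rows


if __name__ == "__main__":
    for row in compare():
        print(row)
```

With the numbers above, the federated model gains roughly 0.02 over the best single-site model on every metric and stays within 0.02 of the centralized model.
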
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
for dataset_name in nli squad quora all | ||
do | ||
    echo "Evaluation on model ${dataset_name}" | |
    python utils/eval_model.py --model_path /tmp/embed/cen/models_single/mpnet-base-${dataset_name}/final | |
done | ||
|
||
echo "Evaluation on model federated" | ||
python utils/eval_model.py --model_path /tmp/embed/nvflare/workspace_api/site-1/models/mpnet-base-nli/global |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
nvflare~=2.5.0 | ||
torch | ||
datasets | ||
scikit-learn | ||
tensorboard | ||
transformers | ||
sentence-transformers |
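
The README suggests removing nvflare from the requirements when a specific version is already installed. One possible way to do that without editing the file by hand is sketched below; `drop_requirement` is a hypothetical helper, not part of this example, and it only handles simple one-package-per-line requirement files like the one above.

```python
import re


def drop_requirement(lines, package="nvflare"):
    """Return the requirement lines with every entry for `package` removed."""
    kept = []
    for line in lines:
        # The package name ends at the first version specifier, extra, marker, or space.
        name = re.split(r"[~=<>!\[;\s]", line.strip(), maxsplit=1)[0]
        if name.lower() != package.lower():
            kept.append(line)
    return kept


# The entries of requirements.txt as listed in this commit.
requirements = [
    "nvflare~=2.5.0",
    "torch",
    "datasets",
    "scikit-learn",
    "tensorboard",
    "transformers",
    "sentence-transformers",
]
filtered = drop_requirement(requirements)
```

The filtered lines could then be written to a separate file and passed to `python3 -m pip install -r` in place of the original.
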
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
from sentence_transformers import SentenceTransformer | ||
|
||
|
||
class SenTransModel(torch.nn.Module): | ||
    def __init__(self, model_name): | |
        super(SenTransModel, self).__init__() | |
        self.model = SentenceTransformer(model_name) | |
|
||
    def forward(self, input_id): | |
        output = self.model(input_ids=input_id, return_dict=False) | |
        return output |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import copy | ||
|
||
from datasets import load_dataset | ||
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments | ||
from sentence_transformers.losses import MultipleNegativesRankingLoss | ||
from sentence_transformers.training_args import BatchSamplers | ||
from transformers import trainer_utils | ||
|
||
import nvflare.client as flare | ||
|
||
|
||
def main(): | ||
    # argparse | |
    parser = argparse.ArgumentParser(description="Train a model on a dataset") | |
    parser.add_argument( | |
        "--model_name", | |
        type=str, | |
        default="microsoft/mpnet-base", | |
    ) | |
    parser.add_argument( | |
        "--dataset_name", | |
        type=str, | |
        default="nli", | |
    ) | |
    args = parser.parse_args() | |
    model_name = args.model_name | |
    dataset_name = args.dataset_name | |
|
||
    # Load a model to finetune with | |
    model = SentenceTransformer(model_name) | |
|
||
    # Load training datasets | |
    if dataset_name == "nli": | |
        # (anchor, positive, negative) | |
        dataset_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:16000]") | |
        dataset_val = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") | |
    elif dataset_name == "squad": | |
        # (question, answer) | |
        dataset_train = load_dataset("sentence-transformers/squad", split="train[:16000]") | |
        dataset_val = load_dataset("sentence-transformers/squad", split="train[16000:18000]") | |
    elif dataset_name == "quora": | |
        # (anchor, positive) | |
        dataset_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:16000]") | |
        dataset_val = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[16000:18000]") | |
    else: | |
        raise ValueError(f"Unknown dataset name: {dataset_name}") | |
|
||
    # Load loss function | |
    loss = MultipleNegativesRankingLoss(model) | |
|
||
    base_model_name = model_name.split("/")[-1] | |
    output_dir = f"./models/{base_model_name}-{dataset_name}" | |
    unit_train_epochs = 0.25 | |
    # Specify training arguments | |
    args = SentenceTransformerTrainingArguments( | |
        # Required parameter: | |
        output_dir=output_dir, | |
        # Optional training parameters: | |
        num_train_epochs=unit_train_epochs, | |
        per_device_train_batch_size=16, | |
        per_device_eval_batch_size=16, | |
        learning_rate=1e-6, | |
        lr_scheduler_type="constant", | |
        bf16=True, | |
        batch_sampler=BatchSamplers.NO_DUPLICATES, | |
        # Optional tracking/debugging parameters: | |
        eval_strategy="steps", | |
        eval_steps=50, | |
        save_strategy="steps", | |
        save_steps=50, | |
        save_total_limit=1, | |
        # logging parameters: | |
        logging_dir=f"{output_dir}/logs", | |
        logging_strategy="steps", | |
        logging_steps=50, | |
        report_to="tensorboard", | |
    ) | |
|
||
    # Define trainer | |
    trainer = SentenceTransformerTrainer( | |
        model=model, | |
        args=args, | |
        train_dataset=dataset_train, | |
        eval_dataset=dataset_val, | |
        loss=loss, | |
    ) | |
|
||
    # initializes NVFlare client API | |
    flare.init() | |
|
||
    while flare.is_running(): | |
        # receives FLModel from NVFlare | |
        input_model = flare.receive() | |
        curr_round = input_model.current_round | |
        print(f"current_round={curr_round}") | |
|
||
        # Update the key name received from global model if using model def file | |
        global_model = copy.deepcopy(input_model.params) | |
        for key in list(global_model.keys()): | |
            global_model[key.replace("model.", "", 1)] = global_model.pop(key) | |
|
||
        # evaluate on received global model | |
        trainer.model.load_state_dict(global_model) | |
        eval_loss_dict = trainer.evaluate() | |
        eval_loss = float(eval_loss_dict["eval_loss"]) | |
        print(f"Evaluation loss: {eval_loss}") | |
        # Save the global model | |
        model.save_pretrained(f"{output_dir}/global") | |
|
||
        # Train the model | |
        if curr_round == 0: | |
            # First round: start from scratch | |
            trainer.train() | |
        else: | |
            # Subsequent rounds: start from the previous model | |
            # Since we perform iterative training by using "resume" functionality | |
            # we need to replace the resume weights with global weights every round | |
            resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir) | |
            # update local record with global model weights | |
            trainer.model.save_pretrained(resume_from_checkpoint_folder) | |
            # increment the number of training epochs so that the trainer will continue training | |
            args.num_train_epochs += unit_train_epochs | |
            # continue training | |
            trainer.train(resume_from_checkpoint=True) | |
|
||
        # update the key name sent to global model | |
        out_param = trainer.model.state_dict() | |
        for key in list(out_param.keys()): | |
            out_param["model." + key] = out_param.pop(key).cpu() | |
        num_steps = trainer.train_dataset.num_rows * unit_train_epochs | |
|
||
        # construct trained FL model | |
        output_model = flare.FLModel( | |
            params=out_param, | |
            metrics={"eval_loss": eval_loss}, | |
            meta={"NUM_STEPS_CURRENT_ROUND": num_steps}, | |
        ) | |
        # send model back to NVFlare | |
        flare.send(output_model) | |
|
||
|
||
if __name__ == "__main__": | |
    main() |
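
The two key-renaming loops in the script above exist because the server-side wrapper stores the SentenceTransformer under the attribute `model`, so every server-side parameter name carries a `model.` prefix that the client-side model does not expect. A minimal sketch of that round trip on plain dictionaries follows; the helper names and the example key are illustrative assumptions, and real key names depend on the model.

```python
PREFIX = "model."  # attribute name used by the server-side wrapper, plus the dot


def strip_wrapper_prefix(params, prefix=PREFIX):
    """Server to client: remove the wrapper prefix, but only where it leads the key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in params.items()
    }


def add_wrapper_prefix(params, prefix=PREFIX):
    """Client to server: restore the prefix expected by the server-side wrapper."""
    return {prefix + key: value for key, value in params.items()}


# Illustrative key name; lists stand in for tensors.
server_params = {"model.0.auto_model.embeddings.weight": [0.1, 0.2]}
client_params = strip_wrapper_prefix(server_params)
round_trip = add_wrapper_prefix(client_params)
```

The script uses `key.replace("model.", "", 1)`, which removes the first occurrence anywhere in the key. The `startswith` check in this sketch is a slightly more defensive variant: a key that lacks the prefix but contains `model.` elsewhere is left unchanged.
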
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from nvflare import FedJob | ||
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector | ||
from nvflare.app_common.workflows.fedavg import FedAvg | ||
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor | ||
from nvflare.job_config.script_runner import ScriptRunner | ||
|
||
if __name__ == "__main__": | ||
    n_clients = 3 | |
    num_rounds = 7 | |
    train_script = "src/train_fl.py" | |
|
||
    # Create the FedJob | |
    job = FedJob(name="embed_fedavg", min_clients=3, mandatory_clients=["site-1", "site-2", "site-3"]) | |
|
||
    # Define the FedAvg controller workflow and send to server | |
    controller = FedAvg( | |
        num_clients=n_clients, | |
        num_rounds=num_rounds, | |
    ) | |
    job.to(controller, "server") | |
|
||
    # Define the model persistor and send to server | |
    # First send the model to the server | |
    job.to("src/st_model.py", "server") | |
    # Then send the model persistor to the server | |
    model_args = {"path": "src.st_model.SenTransModel", "args": {"model_name": "microsoft/mpnet-base"}} | |
    job.to(PTFileModelPersistor(model=model_args), "server", id="persistor") | |
|
||
    # Add model selection widget and send to server | |
    job.to(IntimeModelSelector(key_metric="eval_loss", negate_key_metric=True), "server", id="model_selector") | |
|
||
    # Send ScriptRunner to all clients | |
    runner = ScriptRunner(script=train_script, script_args="--dataset_name nli") | |
    job.to(runner, "site-1") | |
    runner = ScriptRunner(script=train_script, script_args="--dataset_name squad") | |
    job.to(runner, "site-2") | |
    runner = ScriptRunner(script=train_script, script_args="--dataset_name quora") | |
    job.to(runner, "site-3") | |
|
||
    job.export_job("/tmp/embed/nvflare/job_api") | |
    job.simulator_run("/tmp/embed/nvflare/workspace_api") |
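
The job above relies on the `FedAvg` workflow to combine the three site models each round. As a rough illustration of the idea, and not of NVFlare's actual implementation, the sketch below averages client parameters weighted by a per-client count, assuming the `NUM_STEPS_CURRENT_ROUND` value sent by each client is used as that weight. The function name and the list-based parameters are hypothetical simplifications.

```python
def weighted_average(updates):
    """Average client parameters, weighting each client by its reported count.

    `updates` is a list of (params, weight) pairs, where `params` maps a
    parameter name to a flat list of floats (a stand-in for a tensor).
    """
    total_weight = sum(weight for _, weight in updates)
    if total_weight <= 0:
        raise ValueError("total weight must be positive")

    first_params = updates[0][0]
    averaged = {}
    for name, values in first_params.items():
        averaged[name] = [
            sum(params[name][i] * weight for params, weight in updates) / total_weight
            for i in range(len(values))
        ]
    return averaged


# Each site trains on 16000 rows for 0.25 epochs, so each reports 4000.0.
site_updates = [
    ({"w": [1.0, 10.0]}, 4000.0),
    ({"w": [2.0, 20.0]}, 4000.0),
    ({"w": [3.0, 30.0]}, 4000.0),
]
global_params = weighted_average(site_updates)
```

Because all three sites report the same count in this example, the weighted average reduces to a plain mean; the weighting only matters when sites train on different amounts of data.
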
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
for dataset_name in nli squad quora | ||
do | ||
    echo "Training on ${dataset_name}" | |
    python utils/train_iterative.py --dataset_name ${dataset_name} | |
done |
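
`utils/train_iterative.py` itself is not shown in this view. Based on the README's description of Adaptation Step 1, the round schedule might look like the sketch below. `StubTrainer` and `StubArgs` are hypothetical stand-ins for `SentenceTransformerTrainer` and its training arguments that only track epoch counts, so the snippet runs without any model or data.

```python
from dataclasses import dataclass, field


@dataclass
class StubArgs:
    """Hypothetical stand-in for the training arguments."""

    num_train_epochs: float


@dataclass
class StubTrainer:
    """Hypothetical stand-in for the trainer; records the epoch range of each call."""

    args: StubArgs
    epochs_done: float = 0.0
    calls: list = field(default_factory=list)

    def train(self, resume_from_checkpoint=False):
        # A resumed run continues from where the last one stopped;
        # a fresh run starts at epoch 0.
        start = self.epochs_done if resume_from_checkpoint else 0.0
        self.calls.append((start, self.args.num_train_epochs))
        self.epochs_done = self.args.num_train_epochs


def run_iterative(num_rounds=6, unit_train_epochs=0.25):
    """Split training into rounds of `unit_train_epochs` epochs each."""
    trainer = StubTrainer(args=StubArgs(num_train_epochs=unit_train_epochs))
    for round_idx in range(num_rounds):
        if round_idx == 0:
            # First round: plain training call.
            trainer.train()
        else:
            # Later rounds: raise the epoch target, then resume from the checkpoint.
            trainer.args.num_train_epochs += unit_train_epochs
            trainer.train(resume_from_checkpoint=True)
    return trainer


trainer = run_iterative()
```

With the defaults taken from the README (six rounds of 0.25 epochs), the stub ends at 1.5 epochs, matching the single-session total.
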
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
for dataset_name in nli squad quora all | ||
do | ||
    echo "Training on ${dataset_name}" | |
    python utils/train_single_session.py --dataset_name ${dataset_name} | |
done |