Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Directly send tensor via jit serialization #3088

Merged
merged 21 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion examples/advanced/llm_hf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Similar patterns can be observed from the PEFT curves, purple for centralized re
![peft](./figs/fl_peft.png)

## Model Quantization for Communication
In the above example, we used float32 for communication. To reduce the message size, we can use model precision conversion and quantization
In the above example, we used numpy in float32 for communication. To reduce the message size, we can use model precision conversion and quantization
from float32 to 16-bit, 8-bit, and 4-bit for communication. Quantization is enabled by NVFlare's [filter mechanism](https://nvflare.readthedocs.io/en/main/programming_guide/filters.html). We can use the following command to run the federated training with model quantization.
16-bit is a direct precision conversion, while 8-bit and 4-bit quantization are performed by [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes/tree/main).
Note that 4-bit quantizations (`fp4` or `nf4`) need device support.
Expand All @@ -125,6 +125,17 @@ For message reduce, from float32 to 16-/8-/4-bit, the message size (in MB) of Ll

Note that quantization generates additional metadata, which can be significant in the 4-bit cases.

## Model Communication with Tensor
In addition, since the model is trained in bf16, we can communicate the tensors directly in bf16 instead of first converting them to numpy arrays in float32, which avoids the message size inflation caused by the conversion.
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
We can use the following command to run the federated training with direct tensor communication.
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor --job_dir ${PWD}/workspace/jobs/hf_sft_tensor --train_mode SFT --message_mode tensor
```
Similarly, quantization can be applied to tensor communication as well.
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor_fp4 --job_dir ${PWD}/workspace/jobs/hf_sft_tensor_fp4 --train_mode SFT --message_mode tensor --quantize_mode float4
```

## Federated Training with Multiple Clients
With the above example, we can easily extend the federated training to multiple clients. We can use the following command to run the federated training with multiple clients:
```
Expand Down
42 changes: 33 additions & 9 deletions examples/advanced/llm_hf/sft_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor
from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
from nvflare.job_config.script_runner import BaseScriptRunner


def main():
Expand All @@ -46,6 +47,7 @@ def main():
job_dir = args.job_dir
model_name_or_path = args.model_name_or_path
train_mode = args.train_mode
message_mode = args.message_mode

# Create the FedJob
if train_mode.lower() == "sft":
Expand All @@ -66,8 +68,8 @@ def main():

if args.quantize_mode:
# If using quantization, add quantize filters.
quantizor = NumpyModelQuantizor(quantization_type=args.quantize_mode)
dequantizor = NumpyModelDequantizor(source_data_type="float32")
quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
dequantizor = ModelDequantizor()
job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

Expand All @@ -87,11 +89,27 @@ def main():
site_name = f"site-{client_id}"
data_path_train = os.path.join(args.data_path, client_id, "training.jsonl")
data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
runner = ScriptRunner(
script=train_script,
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}",
)

script_args = f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}"
if message_mode == "tensor":
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
# Add params converters and send to client
job.to(PTSendParamsConverter(), site_name, id="pt_send")
job.to(PTReceiveParamsConverter(), site_name, id="pt_receive")
runner = BaseScriptRunner(
script=train_script,
script_args=script_args,
from_nvflare_converter_id="pt_receive",
to_nvflare_converter_id="pt_send",
)
elif message_mode == "numpy":
runner = BaseScriptRunner(
chesterxgchen marked this conversation as resolved.
Show resolved Hide resolved
script=train_script,
script_args=script_args,
)
else:
raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.")
job.to(runner, site_name, tasks=["train"])

if args.quantize_mode:
job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
Expand Down Expand Up @@ -157,6 +175,12 @@ def define_parser():
default=None,
help="quantization mode, float16 or blockwise8, default to None (no quantization)",
)
parser.add_argument(
"--message_mode",
type=str,
default="numpy",
help="message mode, numpy or tensor, default to numpy",
)
parser.add_argument(
"--threads",
type=int,
Expand Down
14 changes: 11 additions & 3 deletions examples/advanced/llm_hf/src/hf_sft_peft_fl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -69,6 +69,12 @@ def main():
default="SFT",
help="training mode, SFT or PEFT, default to SFT",
)
parser.add_argument(
"--message_mode",
type=str,
default="numpy",
help="message mode, numpy or tensor, default to numpy",
)
parser.add_argument("--local_epoch", type=int, default=1)
parser.add_argument("--clean_up", type=int, default=0)
args = parser.parse_args()
Expand Down Expand Up @@ -232,8 +238,10 @@ def evaluate(input_weights, mode):
for key in list(out_param.keys()):
out_param["model." + key] = out_param.pop(key).cpu()

# cast out_param to float32 preparing for communication
out_param = {k: v.to(torch.float32) for k, v in out_param.items()}
if args.message_mode.lower() == "numpy":
# cast out_param to float32 preparing for communication with numpy
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
# otherwise do nothing
out_param = {k: v.to(torch.float32) for k, v in out_param.items()}

# construct trained FL model
output_model = flare.FLModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def _init_converter(self, fl_ctx: FLContext):
if from_nvflare_converter is not None:
chesterxgchen marked this conversation as resolved.
Show resolved Hide resolved
check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter)
self._from_nvflare_converter = from_nvflare_converter

to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id)
if to_nvflare_converter is not None:
check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter)
Expand Down
53 changes: 49 additions & 4 deletions nvflare/app_opt/pt/decomposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,63 @@
from nvflare.fuel.utils.fobs.datum import DatumManager


class SerializationModule(torch.nn.Module):
    """Minimal module that carries a single tensor as a registered buffer.

    Wrapping the tensor in a module lets it be scripted and written with
    ``torch.jit.save``; after loading, the tensor is available again as the
    ``saved_tensor`` attribute.

    Args:
        tensor: the tensor to store under the buffer name ``saved_tensor``.
    """

    def __init__(self, tensor):
        super().__init__()
        # Registered as a buffer (not a parameter) so it is saved with the module.
        self.register_buffer("saved_tensor", tensor)


class TensorDecomposer(fobs.Decomposer):
    """FOBS decomposer for ``torch.Tensor`` that avoids pickle.

    Tensors whose dtype is ``torch.bfloat16`` are serialized through
    ``torch.jit`` (numpy does not support that scalar type); every other
    tensor is converted to a numpy array and written with ``np.save``.
    The serialized form is a dict with two keys:

    * ``"buffer"``: the serialized bytes
    * ``"dtype"``: ``str(tensor.dtype)``, used to pick the deserializer
    """

    def supported_type(self):
        """Return the type handled by this decomposer."""
        return torch.Tensor

    def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any:
        """Serialize ``target`` into a ``{"buffer", "dtype"}`` dict.

        Args:
            target: the tensor to serialize.
            manager: unused here; part of the decomposer signature.

        Returns:
            The dict produced by the jit path for bfloat16 tensors, or by the
            numpy path for all other dtypes.
        """
        if target.dtype == torch.bfloat16:
            return self._jit_serialize(target)
        else:
            return self._numpy_serialize(target)

    def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor:
        """Rebuild a tensor from the output of ``decompose``.

        Args:
            data: either the ``{"buffer", "dtype"}`` dict, or a bare bytes
                payload holding an ``np.save`` stream (no dict envelope).
            manager: unused here; part of the decomposer signature.

        Returns:
            The deserialized tensor.
        """
        if isinstance(data, dict):
            if data["dtype"] == "torch.bfloat16":
                return self._jit_deserialize(data)
            else:
                buf = data["buffer"]
        else:
            # Bare buffer without the dict envelope: treat it as numpy bytes.
            buf = data

        return self._numpy_deserialize(buf)

    @staticmethod
    def _numpy_serialize(tensor: torch.Tensor) -> dict:
        """Serialize a numpy-compatible tensor with ``np.save`` (no pickle)."""
        stream = BytesIO()
        # supported ScalarType, use numpy to avoid Pickle
        array = tensor.detach().cpu().numpy()
        np.save(stream, array, allow_pickle=False)
        return {
            "buffer": stream.getvalue(),
            "dtype": str(tensor.dtype),
        }

    @staticmethod
    def _numpy_deserialize(data: Any) -> torch.Tensor:
        """Load an ``np.save`` byte stream and wrap the array as a tensor."""
        stream = BytesIO(data)
        array = np.load(stream, allow_pickle=False)
        return torch.from_numpy(array)

    @staticmethod
    def _jit_serialize(tensor: torch.Tensor) -> dict:
        """Serialize a tensor by scripting a module that holds it as a buffer."""
        stream = BytesIO()
        # unsupported ScalarType by numpy, use torch.jit to avoid Pickle
        module = SerializationModule(tensor)
        torch.jit.save(torch.jit.script(module), stream)
        return {
            "buffer": stream.getvalue(),
            "dtype": str(tensor.dtype),
        }

    @staticmethod
    def _jit_deserialize(data: Any) -> torch.Tensor:
        """Load the scripted module from ``data["buffer"]`` and return its tensor."""
        stream = BytesIO(data["buffer"])
        loaded_module = torch.jit.load(stream)
        return loaded_module.saved_tensor
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
# limitations under the License.

# Floating-point data type names, written as upper-case strings.
# NOTE(review): presumably the source dtypes the quantization filters accept;
# confirm against the quantizor/dequantizor that read this list.
DATA_TYPE = [
"FLOAT64",
"FLOAT32",
"FLOAT16",
"BFLOAT16",
]

QUANTIZATION_TYPE = [
Expand Down
Loading
Loading