Directly send tensor via jit serialization (#3088)
* Added support for bfloat16 tensor using JIT

* directly send tensor via jit serialization

* polish sft_job

* polish sft_job

* polish local training script

* polish tensor params converter

* polish decomposer

* format correction

* header update

* update decomposer

* end to end tensor communication passed

---------

Co-authored-by: Zhihong Zhang <[email protected]>
ZiyueXu77 and nvidianz authored Dec 13, 2024
1 parent 20ffadb commit 38157c3
Showing 14 changed files with 606 additions and 308 deletions.
13 changes: 12 additions & 1 deletion examples/advanced/llm_hf/README.md
@@ -99,7 +99,7 @@ Similar patterns can be observed from the PEFT curves, purple for centralized re
![peft](./figs/fl_peft.png)

## Model Quantization for Communication
In the above example, we used float32 for communication. To reduce the message size, we can use model precision conversion and quantization
In the above example, we used numpy in float32 for communication. To reduce the message size, we can use model precision conversion and quantization
from float32 to 16-bit, 8-bit, and 4-bit for communication. Quantization is enabled by NVFlare's [filter mechanism](https://nvflare.readthedocs.io/en/main/programming_guide/filters.html). We can use the following command to run the federated training with model quantization.
16-bit is a direct precision conversion, while 8-bit and 4-bit quantization are performed by [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes/tree/main).
Note that 4-bit quantizations (`fp4` or `nf4`) need device support.
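A sketch of the invocation referenced above, with hypothetical workspace and job directory names; the `--quantize_mode` value follows the flag defined in `sft_job.py` later in this commit:
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_fp16 --job_dir ${PWD}/workspace/jobs/hf_sft_fp16 --train_mode SFT --quantize_mode float16
```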
@@ -125,6 +125,17 @@ For message reduce, from float32 to 16-/8-/4-bit, the message size (in MB) of Ll

Note that quantization generates additional metadata, whose overhead can be significant for the 4-bit cases.

## Model Communication with Tensor
In addition, since the model is trained in bf16, we can communicate the model weights directly as bf16 tensors, rather than first converting them to float32 numpy arrays, avoiding the message size inflation caused by the conversion.
We can use the following command to run the federated training with direct tensor communication.
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor --job_dir ${PWD}/workspace/jobs/hf_sft_tensor --train_mode SFT --message_mode tensor
```
Quantization can similarly be applied to tensor communication.
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor_fp4 --job_dir ${PWD}/workspace/jobs/hf_sft_tensor_fp4 --train_mode SFT --message_mode tensor --quantize_mode float4
```
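To make the size argument concrete, a rough back-of-the-envelope comparison; the parameter count below is a hypothetical placeholder, not the model used in this example:
```
# Rough payload comparison: float32 numpy arrays vs. bfloat16 tensors.
# The parameter count is a made-up placeholder for illustration.
num_params = 1_000_000_000  # hypothetical 1B-parameter model

bytes_fp32 = num_params * 4  # float32: 4 bytes per element
bytes_bf16 = num_params * 2  # bfloat16: 2 bytes per element

print(f"float32 payload:  {bytes_fp32 / 2**20:.0f} MB")  # ~3815 MB
print(f"bfloat16 payload: {bytes_bf16 / 2**20:.0f} MB")  # ~1907 MB
```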

## Federated Training with Multiple Clients
With the above example, we can easily extend the federated training to multiple clients. We can use the following command to run the federated training with multiple clients:
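An illustrative invocation, assuming `--client_ids` accepts multiple IDs and that each client ID has its own `training.jsonl`/`validation.jsonl` under the data path as in `sft_job.py`; the client IDs and workspace paths here are hypothetical placeholders:
```
python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_multi --job_dir ${PWD}/workspace/jobs/hf_sft_multi --train_mode SFT
```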
42 changes: 33 additions & 9 deletions examples/advanced/llm_hf/sft_job.py
@@ -19,9 +19,10 @@
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor
from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
from nvflare.job_config.script_runner import BaseScriptRunner


def main():
@@ -46,6 +47,7 @@ def main():
job_dir = args.job_dir
model_name_or_path = args.model_name_or_path
train_mode = args.train_mode
message_mode = args.message_mode

# Create the FedJob
if train_mode.lower() == "sft":
@@ -66,8 +68,8 @@

if args.quantize_mode:
# If using quantization, add quantize filters.
quantizor = NumpyModelQuantizor(quantization_type=args.quantize_mode)
dequantizor = NumpyModelDequantizor(source_data_type="float32")
quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
dequantizor = ModelDequantizor()
job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

@@ -87,11 +89,27 @@
site_name = f"site-{client_id}"
data_path_train = os.path.join(args.data_path, client_id, "training.jsonl")
data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
runner = ScriptRunner(
script=train_script,
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}",
)

script_args = f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}"
if message_mode == "tensor":
# Add params converters and send to client
job.to(PTSendParamsConverter(), site_name, id="pt_send")
job.to(PTReceiveParamsConverter(), site_name, id="pt_receive")
runner = BaseScriptRunner(
script=train_script,
script_args=script_args,
from_nvflare_converter_id="pt_receive",
to_nvflare_converter_id="pt_send",
)
elif message_mode == "numpy":
runner = BaseScriptRunner(
script=train_script,
script_args=script_args,
)
else:
raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.")
job.to(runner, site_name, tasks=["train"])

if args.quantize_mode:
job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
@@ -157,6 +175,12 @@ def define_parser():
default=None,
help="quantization mode, float16 or blockwise8, default to None (no quantization)",
)
parser.add_argument(
"--message_mode",
type=str,
default="numpy",
help="message mode, numpy or tensor, default to numpy",
)
parser.add_argument(
"--threads",
type=int,
14 changes: 11 additions & 3 deletions examples/advanced/llm_hf/src/hf_sft_peft_fl.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -69,6 +69,12 @@ def main():
default="SFT",
help="training mode, SFT or PEFT, default to SFT",
)
parser.add_argument(
"--message_mode",
type=str,
default="numpy",
help="message mode, numpy or tensor, default to numpy",
)
parser.add_argument("--local_epoch", type=int, default=1)
parser.add_argument("--clean_up", type=int, default=0)
args = parser.parse_args()
@@ -232,8 +238,10 @@ def evaluate(input_weights, mode):
for key in list(out_param.keys()):
out_param["model." + key] = out_param.pop(key).cpu()

# cast out_param to float32 preparing for communication
out_param = {k: v.to(torch.float32) for k, v in out_param.items()}
if args.message_mode.lower() == "numpy":
# cast out_param to float32 preparing for communication with numpy
# otherwise do nothing
out_param = {k: v.to(torch.float32) for k, v in out_param.items()}

# construct trained FL model
output_model = flare.FLModel(
@@ -204,7 +204,6 @@ def _init_converter(self, fl_ctx: FLContext):
if from_nvflare_converter is not None:
check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter)
self._from_nvflare_converter = from_nvflare_converter

to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id)
if to_nvflare_converter is not None:
check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter)
53 changes: 49 additions & 4 deletions nvflare/app_opt/pt/decomposers.py
@@ -22,18 +22,63 @@
from nvflare.fuel.utils.fobs.datum import DatumManager


class SerializationModule(torch.nn.Module):
def __init__(self, tensor):
super().__init__()
self.register_buffer("saved_tensor", tensor)


class TensorDecomposer(fobs.Decomposer):
def supported_type(self):
return torch.Tensor

def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any:
if target.dtype == torch.bfloat16:
return self._jit_serialize(target)
else:
return self._numpy_serialize(target)

def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor:
if isinstance(data, dict):
if data["dtype"] == "torch.bfloat16":
return self._jit_deserialize(data)
else:
buf = data["buffer"]
else:
buf = data

return self._numpy_deserialize(buf)

@staticmethod
def _numpy_serialize(tensor: torch.Tensor) -> dict:
stream = BytesIO()
# torch.save uses Pickle so converting Tensor to ndarray first
array = target.detach().cpu().numpy()
# supported ScalarType, use numpy to avoid Pickle
array = tensor.detach().cpu().numpy()
np.save(stream, array, allow_pickle=False)
return stream.getvalue()
return {
"buffer": stream.getvalue(),
"dtype": str(tensor.dtype),
}

def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor:
@staticmethod
def _numpy_deserialize(data: Any) -> torch.Tensor:
stream = BytesIO(data)
array = np.load(stream, allow_pickle=False)
return torch.from_numpy(array)

@staticmethod
def _jit_serialize(tensor: torch.Tensor) -> dict:
stream = BytesIO()
# unsupported ScalarType by numpy, use torch.jit to avoid Pickle
module = SerializationModule(tensor)
torch.jit.save(torch.jit.script(module), stream)
return {
"buffer": stream.getvalue(),
"dtype": str(tensor.dtype),
}

@staticmethod
def _jit_deserialize(data: Any) -> torch.Tensor:
stream = BytesIO(data["buffer"])
loaded_module = torch.jit.load(stream)
return loaded_module.saved_tensor
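
For reference, a minimal standalone sketch of the torch.jit round-trip that the bf16 branch above relies on (illustrative only; in NVFlare the tensors flow through fobs and this decomposer rather than through torch.jit directly):
```
from io import BytesIO

import torch


class SerializationModule(torch.nn.Module):
    """Wraps a tensor as a buffer so TorchScript can serialize it without Pickle."""

    def __init__(self, tensor):
        super().__init__()
        self.register_buffer("saved_tensor", tensor)


# bfloat16 has no numpy equivalent, so the np.save path above cannot be used here
original = torch.randn(4, 4).to(torch.bfloat16)

# serialize: script the wrapper module and save it into an in-memory buffer
stream = BytesIO()
torch.jit.save(torch.jit.script(SerializationModule(original)), stream)

# deserialize: load the scripted module back and read its registered buffer
stream.seek(0)
restored = torch.jit.load(stream).saved_tensor

assert restored.dtype == torch.bfloat16
assert torch.equal(restored, original)
```
The dict wrapper with a `dtype` field in the decomposer above is what lets `recompose` choose between this JIT path and the numpy path on the receiving side.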
File renamed without changes.
@@ -13,7 +13,10 @@
# limitations under the License.

DATA_TYPE = [
"FLOAT64",
"FLOAT32",
"FLOAT16",
"BFLOAT16",
]

QUANTIZATION_TYPE = [
Expand Down
