From 38157c3622af1f808b30e2e00bc95b13b4f24070 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Fri, 13 Dec 2024 17:05:09 -0500 Subject: [PATCH] Directly send tensor via jit serialization (#3088) * Added support for bfloat16 tensor using JIT * directly send tensor via jit serialization * polish sft_job * polish sft_job * polish local training script * polish tensor params converter * polish decomposer * format correction * header update * update decomposer * end to end tensor communication passed --------- Co-authored-by: Zhihong Zhang --- examples/advanced/llm_hf/README.md | 13 +- examples/advanced/llm_hf/sft_job.py | 42 +++- .../advanced/llm_hf/src/hf_sft_peft_fl.py | 14 +- .../in_process_client_api_executor.py | 1 - nvflare/app_opt/pt/decomposers.py | 53 ++++- .../app_opt/{ => pt}/quantization/__init__.py | 0 .../app_opt/{ => pt}/quantization/constant.py | 3 + .../app_opt/pt/quantization/dequantizor.py | 197 +++++++++++++++++ nvflare/app_opt/pt/quantization/quantizor.py | 206 ++++++++++++++++++ nvflare/app_opt/pt/tensor_params_converter.py | 65 ++++++ .../app_opt/quantization/numpy_dequantizor.py | 136 ------------ .../app_opt/quantization/numpy_quantizor.py | 148 ------------- nvflare/job_config/script_runner.py | 9 +- .../app_opt/quantization/quantization_test.py | 27 ++- 14 files changed, 606 insertions(+), 308 deletions(-) rename nvflare/app_opt/{ => pt}/quantization/__init__.py (100%) rename nvflare/app_opt/{ => pt}/quantization/constant.py (94%) create mode 100644 nvflare/app_opt/pt/quantization/dequantizor.py create mode 100644 nvflare/app_opt/pt/quantization/quantizor.py create mode 100644 nvflare/app_opt/pt/tensor_params_converter.py delete mode 100644 nvflare/app_opt/quantization/numpy_dequantizor.py delete mode 100644 nvflare/app_opt/quantization/numpy_quantizor.py diff --git a/examples/advanced/llm_hf/README.md b/examples/advanced/llm_hf/README.md index 6cbe1ae994..173536d61f 100644 --- a/examples/advanced/llm_hf/README.md +++ b/examples/advanced/llm_hf/README.md @@ -99,7 +99,7 @@ Similar patterns can be observed from the PEFT curves, purple for centralized re ![peft](./figs/fl_peft.png) ## Model Quantization for Communication -In the above example, we used float32 for communication. To reduce the message size, we can use model precision conversion and quantization +In the above example, we used numpy in float32 for communication. To reduce the message size, we can use model precision conversion and quantization from float32 to 16-bit, 8-bit, and 4-bit for communication. Quantization is enabled by NVFlare's [filter mechanism](https://nvflare.readthedocs.io/en/main/programming_guide/filters.html). We can use the following command to run the federated training with model quantization. 16-bit is a direct precision conversion, while 8-bit, 4-bit quantization is performed by [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes/tree/main). Note that 4-bit quantizations (`fp4` or `nf4`) need device support. @@ -125,6 +125,17 @@ For message reduce, from float32 to 16-/8-/4-bit, the message size (in MB) of Ll Note that quantization will generate additional meta data, which can be significant for 4-bit cases. +## Model Communication with Tensor +In addition, since the model is trained with bf16, instead of first converting to numpy in float32, we can directly communicate with tensor in bf16 to avoid the message size inflation due to the conversion. +We can use the following command to run the federated training with direct tensor communication. 
+``` +python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor --job_dir ${PWD}/workspace/jobs/hf_sft_tensor --train_mode SFT --message_mode tensor +``` +Similarly, quantization can be applied to tensor communication as well. +``` +python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor_fp4 --job_dir ${PWD}/workspace/jobs/hf_sft_tensor_fp4 --train_mode SFT --message_mode tensor --quantize_mode float4 +``` + ## Federated Training with Multiple Clients With the above example, we can easily extend the federated training to multiple clients. We can use the following command to run the federated training with multiple clients: ``` diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index 3b5221baec..f14a0ce4cb 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -19,9 +19,10 @@ from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector from nvflare.app_common.workflows.fedavg import FedAvg from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor -from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor -from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor -from nvflare.job_config.script_runner import ScriptRunner +from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor +from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor +from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter +from nvflare.job_config.script_runner import BaseScriptRunner def main(): @@ -46,6 +47,7 @@ def main(): job_dir = args.job_dir model_name_or_path = args.model_name_or_path train_mode = args.train_mode + message_mode = args.message_mode # Create the FedJob if train_mode.lower() == "sft": @@ -66,8 +68,8 @@ def main(): if args.quantize_mode: # If using quantization, add quantize filters. 
- quantizor = NumpyModelQuantizor(quantization_type=args.quantize_mode) - dequantizor = NumpyModelDequantizor(source_data_type="float32") + quantizor = ModelQuantizor(quantization_type=args.quantize_mode) + dequantizor = ModelDequantizor() job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA) job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT) @@ -87,11 +89,27 @@ def main(): site_name = f"site-{client_id}" data_path_train = os.path.join(args.data_path, client_id, "training.jsonl") data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl") - runner = ScriptRunner( - script=train_script, - script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", - ) + + script_args = f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}" + if message_mode == "tensor": + # Add params converters and send to client + job.to(PTSendParamsConverter(), site_name, id="pt_send") + job.to(PTReceiveParamsConverter(), site_name, id="pt_receive") + runner = BaseScriptRunner( + script=train_script, + script_args=script_args, + from_nvflare_converter_id="pt_receive", + to_nvflare_converter_id="pt_send", + ) + elif message_mode == "numpy": + runner = BaseScriptRunner( + script=train_script, + script_args=script_args, + ) + else: + raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.") job.to(runner, site_name, tasks=["train"]) + if args.quantize_mode: job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT) job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA) @@ -157,6 +175,12 @@ def define_parser(): default=None, help="quantization mode, float16 or blockwise8, default to None (no quantization)", ) + parser.add_argument( + "--message_mode", + type=str, + default="numpy", + help="message mode, numpy or tensor, default to numpy", + ) parser.add_argument( "--threads", type=int, diff --git a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py index 1411fabd95..49113190ff 100755 --- a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py +++ b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -69,6 +69,12 @@ def main(): default="SFT", help="training mode, SFT or PEFT, default to SFT", ) + parser.add_argument( + "--message_mode", + type=str, + default="numpy", + help="message mode, numpy or tensor, default to numpy", + ) parser.add_argument("--local_epoch", type=int, default=1) parser.add_argument("--clean_up", type=int, default=0) args = parser.parse_args() @@ -232,8 +238,10 @@ def evaluate(input_weights, mode): for key in list(out_param.keys()): out_param["model." 
+ key] = out_param.pop(key).cpu() - # cast out_param to float32 preparing for communication - out_param = {k: v.to(torch.float32) for k, v in out_param.items()} + if args.message_mode.lower() == "numpy": + # cast out_param to float32 preparing for communication with numpy + # otherwise do nothing + out_param = {k: v.to(torch.float32) for k, v in out_param.items()} # construct trained FL model output_model = flare.FLModel( diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index e04ff5673e..c89233904b 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -204,7 +204,6 @@ def _init_converter(self, fl_ctx: FLContext): if from_nvflare_converter is not None: check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter) self._from_nvflare_converter = from_nvflare_converter - to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id) if to_nvflare_converter is not None: check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter) diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index f009a6b1c2..a7f071f33d 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -22,18 +22,63 @@ from nvflare.fuel.utils.fobs.datum import DatumManager +class SerializationModule(torch.nn.Module): + def __init__(self, tensor): + super().__init__() + self.register_buffer("saved_tensor", tensor) + + class TensorDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: + if target.dtype == torch.bfloat16: + return self._jit_serialize(target) + else: + return self._numpy_serialize(target) + + def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: + if isinstance(data, dict): + if data["dtype"] == "torch.bfloat16": + return self._jit_deserialize(data) + else: + buf = data["buffer"] + else: + buf = data + + return self._numpy_deserialize(buf) + + @staticmethod + def _numpy_serialize(tensor: torch.Tensor) -> dict: stream = BytesIO() - # torch.save uses Pickle so converting Tensor to ndarray first - array = target.detach().cpu().numpy() + # supported ScalarType, use numpy to avoid Pickle + array = tensor.detach().cpu().numpy() np.save(stream, array, allow_pickle=False) - return stream.getvalue() + return { + "buffer": stream.getvalue(), + "dtype": str(tensor.dtype), + } - def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: + @staticmethod + def _numpy_deserialize(data: Any) -> torch.Tensor: stream = BytesIO(data) array = np.load(stream, allow_pickle=False) return torch.from_numpy(array) + + @staticmethod + def _jit_serialize(tensor: torch.Tensor) -> dict: + stream = BytesIO() + # unsupported ScalarType by numpy, use torch.jit to avoid Pickle + module = SerializationModule(tensor) + torch.jit.save(torch.jit.script(module), stream) + return { + "buffer": stream.getvalue(), + "dtype": str(tensor.dtype), + } + + @staticmethod + def _jit_deserialize(data: Any) -> torch.Tensor: + stream = BytesIO(data["buffer"]) + loaded_module = torch.jit.load(stream) + return loaded_module.saved_tensor diff --git a/nvflare/app_opt/quantization/__init__.py b/nvflare/app_opt/pt/quantization/__init__.py similarity index 100% rename from 
nvflare/app_opt/quantization/__init__.py rename to nvflare/app_opt/pt/quantization/__init__.py diff --git a/nvflare/app_opt/quantization/constant.py b/nvflare/app_opt/pt/quantization/constant.py similarity index 94% rename from nvflare/app_opt/quantization/constant.py rename to nvflare/app_opt/pt/quantization/constant.py index 06c422c655..e1e48ea779 100644 --- a/nvflare/app_opt/quantization/constant.py +++ b/nvflare/app_opt/pt/quantization/constant.py @@ -13,7 +13,10 @@ # limitations under the License. DATA_TYPE = [ + "FLOAT64", "FLOAT32", + "FLOAT16", + "BFLOAT16", ] QUANTIZATION_TYPE = [ diff --git a/nvflare/app_opt/pt/quantization/dequantizor.py b/nvflare/app_opt/pt/quantization/dequantizor.py new file mode 100644 index 0000000000..d19a9584ec --- /dev/null +++ b/nvflare/app_opt/pt/quantization/dequantizor.py @@ -0,0 +1,197 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Union + +import numpy as np +import torch +from bitsandbytes.functional import QuantState, dequantize_4bit, dequantize_blockwise + +from nvflare.apis.dxo import DXO, DataKind, MetaKey +from nvflare.apis.dxo_filter import DXOFilter +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_opt.pt.quantization.constant import QUANTIZATION_TYPE + + +class ModelDequantizor(DXOFilter): + def __init__(self): + """Filter to dequantize Shareable object to recover from quantization + + Args: + None + + """ + + # support weight and weight_diff data kinds + data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF] + super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) + self.logger.info("Using model dequantizator.") + + def dequantization( + self, params: dict, quant_state: dict, quantization_type: str, source_datatype: dict, fl_ctx: FLContext + ): + n_params = len(params.keys()) + self.log_info(fl_ctx, f"Running dequantization on {n_params} variables") + n_bytes_before = 0 + n_bytes_after = 0 + n_bytes_meta = 0 + n_quant_params = 0 + for i, param_name in enumerate(params.keys()): + source_data_type = source_datatype[param_name] + + # get the bits information + source_date_bits = int(re.findall(r"\d+", source_data_type)[0]) + quantization_bits = int(re.findall(r"\d+", quantization_type)[0]) + + # only dequantize if the quantization type is lower than the source data type + if quantization_bits >= source_date_bits: + self.log_info( + fl_ctx, + f"Skipping dequantization for {param_name}, quantization bit {quantization_type} >= source data bit {source_data_type}", + ) + continue + else: + values = params[param_name] + n_bytes_before += values.nbytes + for item in quant_state[param_name].values(): + if isinstance(item, np.ndarray) or isinstance(item, torch.Tensor): + n_bytes_meta += item.nbytes + + if isinstance(values, np.ndarray): + # if numpy, convert to torch + source_data_format = "numpy" + elif isinstance(values, torch.Tensor): + source_data_format = "torch" 
+ else: + raise ValueError(f"Invalid source data type: {type(values)}, valid: numpy or torch") + + n_quant_params += 1 + if quantization_type == "float16": + # direct assign and convert back to higher precision + params[param_name] = values + elif quantization_type in ["blockwise8", "float4", "normfloat4"]: + # use bitsandbytes to dequantize the values + # extract quantization state + if quantization_type == "blockwise8": + if source_data_format == "numpy": + # first convert numpy array to tensor if numpy + quantized = torch.as_tensor(values) + absmax = torch.as_tensor(quant_state[param_name]["absmax"]) + code = torch.as_tensor(quant_state[param_name]["code"]) + elif source_data_format == "torch": + quantized = values + absmax = quant_state[param_name]["absmax"] + code = quant_state[param_name]["code"] + # de-quanitze + dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code) + else: + if source_data_format == "numpy": + # first convert numpy array to tensor, need to use GPU + quantized = torch.as_tensor(values).cuda() + # create QuantState object + quantize_state = QuantState( + quant_type=quant_state[param_name]["quant_type"], + absmax=torch.as_tensor(quant_state[param_name]["absmax"]).cuda(), + blocksize=quant_state[param_name]["blocksize"], + code=torch.as_tensor(quant_state[param_name]["quant_map"]).cuda(), + dtype=getattr(torch, quant_state[param_name]["dtype"]), + shape=torch.Size(quant_state[param_name]["shape"]), + ) + elif source_data_format == "torch": + quantized = values.cuda() + quantize_state = QuantState( + quant_type=quant_state[param_name]["quant_type"], + absmax=quant_state[param_name]["absmax"].cuda(), + blocksize=quant_state[param_name]["blocksize"], + code=quant_state[param_name]["quant_map"].cuda(), + dtype=getattr(torch, quant_state[param_name]["dtype"]), + shape=torch.Size(quant_state[param_name]["shape"]), + ) + # de-quanitze + if quantization_type == "float4": + dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4") + else: + dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4") + if source_data_format == "numpy": + params[param_name] = dequantized.cpu().numpy() + elif source_data_format == "torch": + params[param_name] = dequantized.cpu() + + # assign back + if source_data_format == "numpy": + # convert back to original data type + if source_data_type == "float32": + params[param_name] = params[param_name].astype(np.float32) + elif source_data_type == "float64": + params[param_name] = params[param_name].astype(np.float64) + elif source_data_type == "float16": + params[param_name] = params[param_name].astype(np.float16) + elif source_data_format == "torch": + # convert back to original data type + if source_data_type == "float32": + params[param_name] = params[param_name].float() + elif source_data_type == "float64": + params[param_name] = params[param_name].double() + elif source_data_type == "float16": + params[param_name] = params[param_name].half() + elif source_data_type == "bfloat16": + params[param_name] = params[param_name].bfloat16() + + n_bytes_after += params[param_name].nbytes + + self.log_info( + fl_ctx, + f"Dequantized {n_quant_params}/{n_params} params." + f" Before dequantization: {n_bytes_before / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB." 
+ f" After dequantization: {n_bytes_after / (1024 ** 2):.2f} MB.", + ) + return params + + def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: + """Filter process apply to the Shareable object. + + Args: + dxo: data to be processed + shareable: that the dxo belongs to + fl_ctx: FLContext + + Returns: DXO object with dequantized weights + + """ + + self.log_info(fl_ctx, "Running dequantization...") + + # check config + quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None) + if quantization_type.upper() not in QUANTIZATION_TYPE: + raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") + + dequantized_params = self.dequantization( + params=dxo.data, + quant_state=dxo.meta["quant_state"], + quantization_type=quantization_type, + source_datatype=dxo.meta["source_datatype"], + fl_ctx=fl_ctx, + ) + # Compose new DXO with dequantized data + dxo.data = dequantized_params + dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM) + dxo.remove_meta_props("quant_state") + dxo.remove_meta_props("source_datatype") + dxo.update_shareable(shareable) + self.log_info(fl_ctx, "Dequantized back") + + return dxo diff --git a/nvflare/app_opt/pt/quantization/quantizor.py b/nvflare/app_opt/pt/quantization/quantizor.py new file mode 100644 index 0000000000..083ce8bde5 --- /dev/null +++ b/nvflare/app_opt/pt/quantization/quantizor.py @@ -0,0 +1,206 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Union + +import numpy as np +import torch +from bitsandbytes.functional import quantize_4bit, quantize_blockwise + +from nvflare.apis.dxo import DXO, DataKind, MetaKey +from nvflare.apis.dxo_filter import DXOFilter +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE + + +class ModelQuantizor(DXOFilter): + def __init__( + self, + quantization_type="float16", + ): + """Filter to quantize Shareable object to reduce communication burden. 
+ + Args: + quantization_type: method used for quantization + + """ + + # support weight and weight_diff data kinds + data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF] + super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) + + # assign quantization type and check if it is valid + self.logger.info("Using model quantizator.") + if quantization_type.upper() not in QUANTIZATION_TYPE: + raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") + else: + self.quantization_type = quantization_type + + # quantization constants + self.NP_FP16_MIN = np.finfo(np.float16).min + self.NP_FP16_MAX = np.finfo(np.float16).max + self.TS_FP16_MIN = torch.finfo(torch.float16).min + self.TS_FP16_MAX = torch.finfo(torch.float16).max + + def quantization(self, params: dict, fl_ctx: FLContext): + n_params = len(params.keys()) + self.log_info(fl_ctx, f"Running quantization on {n_params} variables") + n_bytes_before = 0 + n_bytes_after = 0 + n_bytes_meta = 0 + n_quant_params = 0 + quant_state = {} + source_datatype = {} + for i, param_name in enumerate(params.keys()): + values = params[param_name] + quant_state[param_name] = {} + + # check the data type, numpy or torch + # otherwise error + if isinstance(values, np.ndarray): + # if numpy, convert to torch + source_data_format = "numpy" + elif isinstance(values, torch.Tensor): + source_data_format = "torch" + else: + raise ValueError(f"Invalid source data type: {type(values)}, valid: numpy or torch") + + # get the data type of the values + if source_data_format == "numpy": + source_data_type = values.dtype.name + elif source_data_format == "torch": + source_data_type = str(values.dtype).split(".")[1] + source_datatype[param_name] = source_data_type + + # check if the data type is valid + if source_data_type.upper() not in DATA_TYPE: + raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}") + + # get the bits information + source_data_bits = int(re.findall(r"\d+", source_data_type)[0]) + quantization_bits = int(re.findall(r"\d+", self.quantization_type)[0]) + + # add the number of bytes of the values + n_bytes_before += values.nbytes + # only quantize if the quantization type is lower than the source data type + if quantization_bits >= source_data_bits: + self.log_info( + fl_ctx, + f"Skipping quantization for {param_name}, quantization bit {self.quantization_type} >= source data bit {source_data_type}", + ) + continue + else: + n_quant_params += 1 + if self.quantization_type == "float16": + if source_data_format == "numpy": + # first clamp the values to the range of float16 + values = np.clip(values, self.NP_FP16_MIN, self.NP_FP16_MAX) + # then convert to float16 + values = values.astype(np.float16) + elif source_data_format == "torch": + # first clamp the values to the range of float16 + values = torch.clamp(values, self.TS_FP16_MIN, self.TS_FP16_MAX) + # then convert to float16 + values = values.to(torch.float16) + params[param_name] = values + elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]: + # use bitsandbytes to quantize the values + # input is a tensor, output is a tuple of (quantized tensor, quantized_state) + if self.quantization_type == "blockwise8": + if source_data_format == "numpy": + # if numpy, first convert numpy array to tensor + values_tensor = torch.as_tensor(values) + elif source_data_format == "torch": + values_tensor = values + + # then quantize the tensor + quantized, quantized_state = quantize_blockwise(values_tensor) + # 
add the quantization state and values, keep source data format + if source_data_format == "numpy": + quant_state[param_name]["absmax"] = quantized_state.absmax.numpy() + quant_state[param_name]["code"] = quantized_state.code.numpy() + values = quantized.numpy() + elif source_data_format == "torch": + quant_state[param_name]["absmax"] = quantized_state.absmax + quant_state[param_name]["code"] = quantized_state.code + values = quantized + n_bytes_meta += quant_state[param_name]["absmax"].nbytes + n_bytes_meta += quant_state[param_name]["code"].nbytes + else: + if source_data_format == "numpy": + # if numpy, first convert numpy array to tensor, need to use GPU + values_tensor = torch.as_tensor(values).cuda() + elif source_data_format == "torch": + # if torch, directly use the tensor, need to use GPU + values_tensor = values.cuda() + # then quantize the tensor + if self.quantization_type == "float4": + quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4") + else: + quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4") + # add the quantization state and values, keep source data format + quantized_state = quantized_state.as_dict() + + for state_name, state in quantized_state.items(): + if isinstance(state, torch.Tensor): + if source_data_format == "numpy": + # if the state is a tensor, convert it to numpy array + quant_state[param_name][state_name] = state.cpu().numpy() + elif source_data_format == "torch": + # if the state is a tensor, keep it as tensor + quant_state[param_name][state_name] = state.cpu() + n_bytes_meta += state.nbytes + else: + quant_state[param_name][state_name] = state + # add values + if source_data_format == "numpy": + values = quantized.cpu().numpy() + elif source_data_format == "torch": + values = quantized.cpu() + params[param_name] = values + n_bytes_after += params[param_name].nbytes + + self.log_info( + fl_ctx, + f"Quantized {n_quant_params}/{n_params} params." + f" Before quantization: {n_bytes_before / (1024 ** 2):.2f} MB." + f" After quantization: {n_bytes_after / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB.", + ) + return params, quant_state, source_datatype + + def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: + """Filter process apply to the Shareable object. + + Args: + dxo: data to be processed + shareable: that the dxo belongs to + fl_ctx: FLContext + + Returns: DXO object with quantized weights + + """ + + self.log_info(fl_ctx, "Running quantization...") + quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx) + # Compose new DXO with quantized data + # Add quant_state to the new DXO meta + new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta) + new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type) + new_dxo.set_meta_prop(key="quant_state", value=quant_state) + new_dxo.set_meta_prop(key="source_datatype", value=source_datatype) + self.log_info(fl_ctx, f"Quantized to {self.quantization_type}") + + return new_dxo diff --git a/nvflare/app_opt/pt/tensor_params_converter.py b/nvflare/app_opt/pt/tensor_params_converter.py new file mode 100644 index 0000000000..6c87cdf9d3 --- /dev/null +++ b/nvflare/app_opt/pt/tensor_params_converter.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import numpy as np +import torch + +from nvflare.app_common.abstract.params_converter import ParamsConverter + + +class PTReceiveParamsConverter(ParamsConverter): + def convert(self, params: Dict, fl_ctx) -> Dict: + tensor_shapes = fl_ctx.get_prop("tensor_shapes") + exclude_vars = fl_ctx.get_prop("exclude_vars") + + return_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + return_params[k] = v + else: + # "PT receive, so potentially also need to handle numpy to tensor" + if tensor_shapes: + if k in tensor_shapes: + return_params[k] = torch.as_tensor(np.reshape(v, tensor_shapes[k])) + else: + return_params[k] = torch.as_tensor(v) + else: + return_params[k] = torch.as_tensor(v) + + if exclude_vars: + for k, v in exclude_vars.items(): + return_params[k] = v + + return return_params + + +class PTSendParamsConverter(ParamsConverter): + def convert(self, params: Dict, fl_ctx) -> Dict: + return_tensors = {} + exclude_vars = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + return_tensors[k] = v.cpu() + else: + exclude_vars[k] = v + + if exclude_vars: + fl_ctx.set_prop("exclude_vars", exclude_vars) + self.logger.warning( + f"{len(exclude_vars)} vars excluded as they were non-tensor type: " f"{list(exclude_vars.keys())}" + ) + + return return_tensors diff --git a/nvflare/app_opt/quantization/numpy_dequantizor.py b/nvflare/app_opt/quantization/numpy_dequantizor.py deleted file mode 100644 index 0409ac45c2..0000000000 --- a/nvflare/app_opt/quantization/numpy_dequantizor.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Union - -import numpy as np -import torch -from bitsandbytes.functional import QuantState, dequantize_4bit, dequantize_blockwise - -from nvflare.apis.dxo import DXO, DataKind, MetaKey -from nvflare.apis.dxo_filter import DXOFilter -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.app_opt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE - - -class NumpyModelDequantizor(DXOFilter): - def __init__(self, source_data_type="float32"): - """Filter to dequantize Shareable object to recover from quantization - - Args: - source_data_type: original data type of the model - - """ - - # support weight and weight_diff data kinds - data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF] - super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) - - # assign data type and check if it is valid - self.logger.info("Using model dequantizator.") - if source_data_type.upper() not in DATA_TYPE: - raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}") - else: - self.source_data_type = source_data_type - - def dequantization(self, params: dict, quant_state: dict, quant_type: str, fl_ctx: FLContext): - n_params = len(params.keys()) - self.log_info(fl_ctx, f"Running dequantization on {n_params} variables") - n_bytes_before = 0 - n_bytes_after = 0 - n_bytes_meta = 0 - n_quant_params = 0 - for i, param_name in enumerate(params.keys()): - if self.source_data_type == "float32": - values = params[param_name] - n_bytes_before += values.nbytes - for item in quant_state[param_name].values(): - if isinstance(item, np.ndarray): - n_bytes_meta += item.nbytes - if self.source_data_type != quant_type: - # if the source data type is not the same as the quantization type, convert it - n_quant_params += 1 - if quant_type == "float16": - # direct convert - values = values.astype(np.float32) - params[param_name] = values - elif quant_type in ["blockwise8", "float4", "normfloat4"]: - # use bitsandbytes to dequantize the values - # extract quantization state - if quant_type == "blockwise8": - quantized = torch.as_tensor(values) - absmax = torch.as_tensor(quant_state[param_name]["absmax"]) - code = torch.as_tensor(quant_state[param_name]["code"]) - # de-quanitze - dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code) - params[param_name] = dequantized.numpy() - else: - # first convert numpy array to tensor, need to use GPU - quantized = torch.as_tensor(values).cuda() - # create QuantState object - quantize_state = QuantState( - quant_type=quant_state[param_name]["quant_type"], - absmax=torch.as_tensor(quant_state[param_name]["absmax"]).cuda(), - blocksize=quant_state[param_name]["blocksize"], - code=torch.as_tensor(quant_state[param_name]["quant_map"]).cuda(), - dtype=getattr(torch, quant_state[param_name]["dtype"]), - shape=torch.Size(quant_state[param_name]["shape"]), - ) - # de-quanitze - if quant_type == "float4": - dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4") - else: - dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4") - params[param_name] = dequantized.cpu().numpy() - n_bytes_after += params[param_name].nbytes - - self.log_info( - fl_ctx, - f"Dequantized {n_quant_params}/{n_params} params." - f" Before dequantization: {n_bytes_before / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB." 
- f" After dequantization: {n_bytes_after / (1024 ** 2):.2f} MB.", - ) - return params - - def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: - """Filter process apply to the Shareable object. - - Args: - dxo: data to be processed - shareable: that the dxo belongs to - fl_ctx: FLContext - - Returns: DXO object with dequantized weights - - """ - - self.log_info(fl_ctx, "Running dequantization...") - - # check config - quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None) - if quantization_type.upper() not in QUANTIZATION_TYPE: - raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") - - dequantized_params = self.dequantization( - params=dxo.data, quant_state=dxo.meta["quant_state"], quant_type=quantization_type, fl_ctx=fl_ctx - ) - # Compose new DXO with dequantized data - dxo.data = dequantized_params - dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM) - dxo.remove_meta_props("quant_state") - dxo.update_shareable(shareable) - self.log_info(fl_ctx, f"Dequantized back to {self.source_data_type}") - - return dxo diff --git a/nvflare/app_opt/quantization/numpy_quantizor.py b/nvflare/app_opt/quantization/numpy_quantizor.py deleted file mode 100644 index 13b0b44379..0000000000 --- a/nvflare/app_opt/quantization/numpy_quantizor.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union - -import numpy as np -import torch -from bitsandbytes.functional import quantize_4bit, quantize_blockwise - -from nvflare.apis.dxo import DXO, DataKind, MetaKey -from nvflare.apis.dxo_filter import DXOFilter -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.app_opt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE - - -class NumpyModelQuantizor(DXOFilter): - def __init__( - self, - quantization_type="float16", - ): - """Filter to quantize Shareable object to reduce communication burden. 
- - Args: - quantization_type: method used for quantization - - """ - - # support weight and weight_diff data kinds - data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF] - super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) - - # assign quantization type and check if it is valid - self.logger.info("Using model quantizator.") - if quantization_type.upper() not in QUANTIZATION_TYPE: - raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") - else: - self.quantization_type = quantization_type - - # quantization constants - self.FP16_MIN = np.finfo(np.float16).min - self.FP16_MAX = np.finfo(np.float16).max - - def quantization(self, params: dict, fl_ctx: FLContext): - n_params = len(params.keys()) - self.log_info(fl_ctx, f"Running quantization on {n_params} variables") - n_bytes_before = 0 - n_bytes_after = 0 - n_bytes_meta = 0 - n_quant_params = 0 - quant_state = {} - for i, param_name in enumerate(params.keys()): - values = params[param_name] - quant_state[param_name] = {} - # check the data type of the values and if it is valid - source_data_type = values.dtype.name - if source_data_type.upper() not in DATA_TYPE: - raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}") - # add the number of bytes of the values - n_bytes_before += values.nbytes - if source_data_type != self.quantization_type: - # if the source data type is not the same as the quantization type, convert it - n_quant_params += 1 - if source_data_type == "float32": - if self.quantization_type == "float16": - # first clamp the values to the range of float16 - values = np.clip(values, self.FP16_MIN, self.FP16_MAX) - # then convert to float16 - values = values.astype(np.float16) - params[param_name] = values - elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]: - # use bitsandbytes to quantize the values - # input is a tensor, output is a tuple of (quantized tensor, quantized_state) - if self.quantization_type == "blockwise8": - # first convert numpy array to tensor - values_tensor = torch.as_tensor(values) - # then quantize the tensor - quantized, quantized_state = quantize_blockwise(values_tensor) - # add the quantization state - quant_state[param_name]["absmax"] = quantized_state.absmax.numpy() - n_bytes_meta += quant_state[param_name]["absmax"].nbytes - quant_state[param_name]["code"] = quantized_state.code.numpy() - n_bytes_meta += quant_state[param_name]["code"].nbytes - # add values - values = quantized.numpy() - else: - # first convert numpy array to tensor, need to use GPU - values_tensor = torch.as_tensor(values).cuda() - # then quantize the tensor - if self.quantization_type == "float4": - quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4") - else: - quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4") - # add the quantization state - quantized_state = quantized_state.as_dict() - for state_name, state in quantized_state.items(): - # if the state is a tensor, convert it to numpy array - if isinstance(state, torch.Tensor): - quant_state[param_name][state_name] = state.cpu().numpy() - n_bytes_meta += state.nbytes - else: - quant_state[param_name][state_name] = state - # add values - values = quantized.cpu().numpy() - params[param_name] = values - n_bytes_after += params[param_name].nbytes - - self.log_info( - fl_ctx, - f"Quantized {n_quant_params}/{n_params} params." - f" Before quantization: {n_bytes_before / (1024 ** 2):.2f} MB." 
- f" After quantization: {n_bytes_after / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB.", - ) - return params, quant_state - - def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: - """Filter process apply to the Shareable object. - - Args: - dxo: data to be processed - shareable: that the dxo belongs to - fl_ctx: FLContext - - Returns: DXO object with quantized weights - - """ - - self.log_info(fl_ctx, "Running quantization...") - quantized_params, quant_state = self.quantization(params=dxo.data, fl_ctx=fl_ctx) - # Compose new DXO with quantized data - # Add quant_state to the new DXO meta - new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta) - new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type) - new_dxo.set_meta_prop(key="quant_state", value=quant_state) - self.log_info(fl_ctx, f"Quantized to {self.quantization_type}") - - return new_dxo diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index 5fa79e3fb0..68fbd8a2a9 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -45,6 +45,8 @@ def __init__( framework: FrameworkType = FrameworkType.PYTORCH, params_transfer_type: str = TransferType.FULL, executor: Union[ClientAPILauncherExecutor, InProcessClientAPIExecutor, None] = None, + from_nvflare_converter_id: Optional[str] = None, + to_nvflare_converter_id: Optional[str] = None, task_pipe: Optional[Pipe] = None, launcher: Optional[Launcher] = None, metric_relay: Optional[MetricRelay] = None, @@ -96,7 +98,8 @@ def __init__( self._launch_external_process = launch_external_process self._framework = framework self._params_transfer_type = params_transfer_type - + self._from_nvflare_converter_id = from_nvflare_converter_id + self._to_nvflare_converter_id = to_nvflare_converter_id self._params_exchange_format = None if self._framework == FrameworkType.PYTORCH: @@ -186,6 +189,8 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): launcher_id=launcher_id, params_exchange_format=self._params_exchange_format, params_transfer_type=self._params_transfer_type, + from_nvflare_converter_id=self._from_nvflare_converter_id, + to_nvflare_converter_id=self._to_nvflare_converter_id, heartbeat_timeout=0, ) ) @@ -231,6 +236,8 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): task_script_args=self._script_args, params_exchange_format=self._params_exchange_format, params_transfer_type=self._params_transfer_type, + from_nvflare_converter_id=self._from_nvflare_converter_id, + to_nvflare_converter_id=self._to_nvflare_converter_id, ) ) job.add_executor(executor, tasks=tasks, ctx=ctx) diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py index 9fdc00d54a..b8b2a6fb35 100644 --- a/tests/unit_test/app_opt/quantization/quantization_test.py +++ b/tests/unit_test/app_opt/quantization/quantization_test.py @@ -14,11 +14,12 @@ import numpy as np import pytest +import torch from nvflare.apis.dxo import DXO, DataKind from nvflare.apis.fl_context import FLContext -from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor -from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor +from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor +from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor TEST_CASES = [ ( @@ -31,6 +32,16 @@ "blockwise8", {"a": np.array([0.99062496, 2.003125, 
3.015625, 4.0], dtype="float32")}, ), + ( + {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)}, + "float16", + {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)}, + ), + ( + {"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)}, + "blockwise8", + {"a": torch.tensor([0.99062496, 2.003125, 3.015625, 4.0], dtype=torch.float32)}, + ), ] @@ -42,12 +53,18 @@ def test_quantization(self, input_data, quantization_type, expected_data): data=input_data, ) fl_ctx = FLContext() - f_quant = NumpyModelQuantizor(quantization_type=quantization_type) + f_quant = ModelQuantizor(quantization_type=quantization_type) quant_dxo = f_quant.process_dxo(dxo, dxo.to_shareable(), fl_ctx) - f_dequant = NumpyModelDequantizor(source_data_type="float32") + f_dequant = ModelDequantizor() dequant_dxo = f_dequant.process_dxo(quant_dxo, dxo.to_shareable(), fl_ctx) dequant_data = dequant_dxo.data for key in dequant_data.keys(): dequant_array = dequant_data[key] expected_array = expected_data[key] - assert np.allclose(dequant_array, expected_array) + # print the values + print(f"dequant_array: {dequant_array}") + print(f"expected_array: {expected_array}") + if isinstance(dequant_array, torch.Tensor): + assert torch.allclose(dequant_array, expected_array) + else: + assert np.allclose(dequant_array, expected_array)
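For reference, the core of the new tensor path is that numpy has no bfloat16 dtype, so converting a bf16 tensor with `.numpy()` is not possible; the `TensorDecomposer` in this patch therefore wraps the tensor in a scripted module and serializes it with `torch.jit.save`, avoiding both Pickle and a float32 conversion. Below is a minimal standalone sketch of that round trip; the helper names (`_SerializationModule`, `jit_roundtrip`) are illustrative only and not part of the patch.
```
from io import BytesIO

import torch


class _SerializationModule(torch.nn.Module):
    """Holds a tensor as a buffer so torch.jit.save can serialize it without Pickle."""

    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        self.register_buffer("saved_tensor", tensor)


def jit_roundtrip(tensor: torch.Tensor) -> torch.Tensor:
    # Serialize: script the wrapper module and save it into an in-memory buffer
    stream = BytesIO()
    torch.jit.save(torch.jit.script(_SerializationModule(tensor)), stream)
    # Deserialize: load the scripted module back and read the saved buffer out
    stream.seek(0)
    return torch.jit.load(stream).saved_tensor


if __name__ == "__main__":
    original = torch.randn(4, dtype=torch.bfloat16)
    restored = jit_roundtrip(original)
    assert restored.dtype == torch.bfloat16
    assert torch.equal(original, restored)
```
Keeping the weights in bf16 end to end roughly halves the payload compared with casting to float32 numpy arrays before sending, which is the message-size inflation the README section above refers to.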
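Similarly, the new send/receive params converters are what keep those tensors intact across the Client API boundary in `tensor` message mode: on send, non-tensor entries are set aside and only CPU tensors go out; on receive, numpy values are coerced back to tensors and the excluded entries are restored. A rough sketch of that split/merge logic as plain functions (illustrative only; the patch implements it as the `PTSendParamsConverter` and `PTReceiveParamsConverter` subclasses of `ParamsConverter`, using the FL context to carry the excluded entries).
```
from typing import Dict, Tuple

import torch


def split_for_send(params: Dict) -> Tuple[Dict, Dict]:
    """Keep torch tensors (moved to CPU) for transmission; set aside everything else."""
    tensors, excluded = {}, {}
    for name, value in params.items():
        if isinstance(value, torch.Tensor):
            tensors[name] = value.cpu()
        else:
            excluded[name] = value
    return tensors, excluded


def merge_on_receive(params: Dict, excluded: Dict) -> Dict:
    """Coerce non-tensor values (e.g. numpy arrays) back to tensors, then restore excluded entries."""
    merged = {
        name: value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
        for name, value in params.items()
    }
    merged.update(excluded)
    return merged


if __name__ == "__main__":
    update = {"layer.weight": torch.randn(2, 2, dtype=torch.bfloat16), "round": 3}
    tensors, rest = split_for_send(update)
    restored = merge_on_receive(tensors, rest)
    assert restored["round"] == 3
    assert restored["layer.weight"].dtype == torch.bfloat16
```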