
Commit

add spinquant
jambayk committed Jan 22, 2025
1 parent ab72d69 commit 706371d
Showing 8 changed files with 787 additions and 222 deletions.
@@ -40,7 +40,7 @@ Please refer to [AutoAWQQuantizer](awq_quantizer) for more details about the pas
## QuaRot
`QuaRot` is a technique that rotates the weights of a model to make them more conducive to quantization. It is based on the [QuaRot paper](https://arxiv.org/abs/2404.00456) but only performs offline weight rotation. It can be followed by a pass such as GPTQ to quantize the rotated model weights.

This pass only supports HuggingFace transformer PyTorch models. Please refer to [QuaRot](quarot) for more details on the types of transformers models supported.
This pass only supports HuggingFace transformer PyTorch models.

### Example Configuration
```json
@@ -49,3 +49,17 @@ This pass only supports HuggingFace transformer PyTorch models. Please refer to
"rotate_mode": "hadamard"
}
```
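
Since the rotation leaves the model weights unquantized, a typical workflow follows this pass with a weight quantization pass. Below is a minimal sketch of such a chain, assuming the usual Olive workflow layout where named passes appear under a top-level `passes` section and run in order; the pass names and the bare-bones `GptqQuantizer` entry are illustrative, not prescribed defaults.

```json
{
    "passes": {
        "rotate": { "type": "QuaRot", "rotate_mode": "hadamard" },
        "quantize": { "type": "GptqQuantizer" }
    }
}
```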

## SpinQuant
`SpinQuant` is a technique similar to QuaRot that rotates the weights of a model to make them more conducive to quantization. Its rotation matrices are trained on a calibration dataset to improve activation quantization quality. It is based on the [SpinQuant paper](https://arxiv.org/pdf/2405.16406) but only performs offline weight rotation. It can be followed by a pass such as GPTQ to quantize the rotated model weights.

This pass only supports HuggingFace transformer PyTorch models.

### Example Configuration
```json
{
    "type": "SpinQuant",
    "rotate_mode": "hadamard",
    "a_bits": 8
}
```
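
Because SpinQuant learns its rotations, the pass also needs calibration data on top of the quantization settings above. A hedged sketch of what that could look like, assuming the pass takes a named data config the way the fine-tuning passes do; the `train_data_config` field and the `wikitext2_train` data config name are illustrative assumptions, not documented options.

```json
{
    "type": "SpinQuant",
    "rotate_mode": "hadamard",
    "a_bits": 8,
    "train_data_config": "wikitext2_train"
}
```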
6 changes: 6 additions & 0 deletions docs/source/reference/pass.rst
@@ -258,6 +258,12 @@ QuaRot
------
.. autoconfigclass:: olive.passes.QuaRot

.. _spinquant:

SpinQuant
---------
.. autoconfigclass:: olive.passes.SpinQuant

.. _gptq_quantizer:

GptqQuantizer
6 changes: 6 additions & 0 deletions olive/olive_config.json
@@ -297,6 +297,12 @@
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ]
},
"SpinQuant": {
"module_path": "olive.passes.pytorch.rotate.SpinQuant",
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ]
},
"SliceGPT": {
"module_path": "olive.passes.pytorch.slicegpt.SliceGPT",
"supported_providers": [ "*" ],
231 changes: 23 additions & 208 deletions olive/passes/pytorch/lora.py
@@ -7,7 +7,6 @@
# QLoRA: https://github.com/artidoro/qlora/blob/main/qlora.py
# https://arxiv.org/abs/2305.14314
# --------------------------------------------------------------------------
import dataclasses
import logging
import tempfile
from abc import abstractmethod
@@ -19,7 +18,7 @@
import transformers
from packaging import version

from olive.common.config_utils import ConfigBase, NestedConfig
from olive.common.config_utils import ConfigBase
from olive.common.hf.mappings import MODELS_TO_LORA_TARGET_MODULES_MAPPING
from olive.common.hf.utils import get_peft_task_type_from_task
from olive.common.pydantic_v1 import Field, validator
@@ -31,6 +30,13 @@
from olive.model.config.hf_config import HfLoadKwargs
from olive.passes import Pass
from olive.passes.olive_pass import PassConfigParam
from olive.passes.pytorch.train_utils import (
BaseHFTrainingArguments,
count_trainable_parameters,
get_training_dataset,
load_hf_base_model,
prepare_model_for_finetuning,
)
from olive.strategy.search_parameter import Categorical

if TYPE_CHECKING:
@@ -41,25 +47,16 @@

logger = logging.getLogger(__name__)

# pylint: disable=unused-import


# ruff: noqa: B010
# creating a Config class since transformers.TrainingArguments is a dataclass
# pydantic handles dataclasses differently and causes issues with validation
# this also allows us to handle and validate extra_args better
class HFTrainingArguments(NestedConfig):
class HFTrainingArguments(BaseHFTrainingArguments):
"""Training arguments for transformers.Trainer.
Has the same fields as transformers.TrainingArguments with recommended default values for QLoRA fine-tuning.
"""

_nested_field_name = "extra_args"

# TODO(jambayk): is this default optim required? does it work for regular lora? what about lr_scheduler_type?
optim: str = Field("paged_adamw_32bit", description="The optimizer to use.")
learning_rate: float = Field(0.0002, description="The initial learning rate for AdamW.")
gradient_checkpointing: bool = Field(True, description="Use gradient checkpointing. Recommended.")
lr_scheduler_type: str = Field(
"constant",
description="Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis.",
@@ -72,10 +69,6 @@ class HFTrainingArguments(NestedConfig):
" set to 'epoch'."
),
)
report_to: Union[str, List[str]] = Field(
"none", description="The list of integrations to report the results and logs to."
)
output_dir: str = Field(None, description="The output dir for logs and checkpoints. If None, will use a temp dir.")
overwrite_output_dir: bool = Field(
False,
description=(
@@ -89,48 +82,14 @@ class HFTrainingArguments(NestedConfig):
"The path to a folder with a valid checkpoint for the model. Supercedes any checkpoint found in output_dir."
),
)
deepspeed: Union[bool, str, Dict] = Field(
None,
description=(
"Use [Deepspeed](https://github.com/microsoft/deepspeed). If True, will use default deepspeed config. Else,"
" it is a path to a deepspeed config file or a dict with deepspeed config."
),
)
extra_args: Dict[str, Any] = Field(
None,
description=(
"Extra arguments to pass to the trainer. Values can be provided directly to this field as a dict or as"
" keyword arguments to the config. See transformers.TrainingArguments for more details on the available"
" arguments."
),
)

@validator("extra_args", pre=True, always=True)
def validate_extra_args(cls, v):
if v is None:
v = {}
# make sure extra args are fields of transformers.Trainer
training_args_fields = {f.name for f in dataclasses.fields(transformers.TrainingArguments) if f.init}
for k in list(v): # need a copy of the keys since we are mutating the dict
if k == "fp16":
logger.warning("Extra arg %s is not allowed. Please use `torch_dtype` instead.", k)
del v[k]
elif k not in training_args_fields:
logger.warning("Extra arg %s is not a field of transformers.TrainingArguments. Ignoring.", k)
del v[k]
def validate_torch_dtype(cls, v):
if v and "fp16" in v:
logger.warning("Extra arg 'fp16' is not allowed. Please use `torch_dtype` instead.")
del v["fp16"]
return v

def create_training_args(self) -> transformers.TrainingArguments:
args = self.dict()
if not args["output_dir"]:
raise ValueError("output_dir must be provided.")
if args["deepspeed"] is True:
args["deepspeed"] = deepcopy(DEFAULT_DEEPSPEED_CONFIG)
elif args["deepspeed"] is False:
del args["deepspeed"]
extra_args = args.pop("extra_args")
return transformers.TrainingArguments(**args, **extra_args)


class LoRABase(Pass):
"""Base class for LoRA and QLoRA fine-tuning passes."""
@@ -255,70 +214,16 @@ def get_datasets(
) -> Tuple["Dataset", Optional["Dataset"]]:
"""Load training and evaluation datasets."""
# we return dataset.Dataset object since the trainer works better with it
from datasets import Dataset

train_data_config = config.train_data_config
eval_data_config = config.eval_data_config

def data_generator(data_config):
data_container = data_config.to_data_container()
dataset = data_container.pre_process(data_container.load_dataset())

for idx in range(len(dataset)): # pylint: disable=consider-using-enumerate
example = dataset[idx]
if isinstance(example, tuple):
# if example = {**example[0], "labels": example[1]}, the attention_mask is not the same
# for some reason, so yield a new dict
yield {**example[0], "labels": example[1]}
else:
yield example

# load training dataset
train_dataset = Dataset.from_generator(data_generator, gen_kwargs={"data_config": train_data_config})
train_dataset.set_format("torch")
train_dataset = get_training_dataset(config.train_data_config)

# load evaluation dataset if needed
eval_dataset = None
if eval_data_config:
eval_dataset = Dataset.from_generator(data_generator, gen_kwargs={"data_config": eval_data_config})
eval_dataset.set_format("torch")
if config.eval_data_config:
eval_dataset = get_training_dataset(config.eval_data_config)

return train_dataset, eval_dataset

@staticmethod
def prepare_model_for_lora_finetuning(
model: "PreTrainedModel", use_gradient_checkpointing: bool
) -> "PreTrainedModel":
"""Prepare the model for fine-tuning.
Freeze base model's layers and prepare model for gradient checkpointing if necessary.
Similar to peft.prepare_model_for_kbit_training but no casting to fp32 and gradient checkpointing is
also supported for non-quantized models.
:param model: The Hugging Face PyTorch model to prepare for fine-tuning.
:param use_gradient_checkpointing: Whether to use gradient checkpointing.
:return: The prepared model.
"""
for param in model.parameters():
# freeze base model's layers
param.requires_grad = False

if use_gradient_checkpointing:
# For backward compatibility
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:

def make_inputs_require_grad(module_, input_, output_):
output_.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

# enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()

return model

def load_base_pytorch_model(self, model_handler: HfModelHandler, config: ConfigBase, **kwargs) -> "PreTrainedModel":
"""Load a base PyTorch model for fine-tuning.
@@ -329,39 +234,13 @@ def load_base_pytorch_model(self, model_handler: HfModelHandler, config: ConfigB
"""
import torch

# model cannot have its own adapter
if model_handler.adapter_path:
raise ValueError("Model already has an adapter. Please provide a model without an adapter.")

# don't want the original loaded model
# also frees gpu memory if original model is on gpu
model_handler.model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()

# create copy of the input model, will modify this model
# also resets adapter_path
new_model_handler = deepcopy(model_handler)

torch_dtype = self.get_torch_dtype(config.torch_dtype)
# will use mixed precision since full fp16 is unstable
model_dtype = torch_dtype if torch_dtype != torch.float16 else torch.float32

# load model, reset load_kwargs and adapter_path
load_kwargs = new_model_handler.load_kwargs.dict() if new_model_handler.load_kwargs else {}
load_kwargs.update(
{
"torch_dtype": model_dtype,
# TODO(jambayk): Worry about `use_multi_gpu` and distributed training later
# "auto": uses all available GPUs, model parallel
"device_map": "auto",
}
)
# overwrite load_kwargs with kwargs
load_kwargs.update(kwargs)
new_model_handler.load_kwargs = HfLoadKwargs(**load_kwargs)

return new_model_handler.load_model(cache_model=False)
# TODO(jambayk): Worry about `use_multi_gpu` and distributed training later
# "auto": uses all available GPUs, model parallel
return load_hf_base_model(model_handler, torch_dtype=model_dtype, device_map="auto", **kwargs)

def init_lora_adapters(
self,
@@ -430,33 +309,21 @@ def enable_lora(
from peft import PeftModel

logger.debug("Enabling LoRA fine-tuning")
if config.training_args.gradient_checkpointing and not model.supports_gradient_checkpointing:
logger.warning(
"gradient_checkpointing is True, but model does not support gradient checkpointing! Setting"
" gradient_checkpoing to False"
)
config.training_args.gradient_checkpointing = False

model = self.prepare_model_for_lora_finetuning(model, config.training_args.gradient_checkpointing)
prepare_model_for_finetuning(model, config.training_args)

# set model_parallel and is_parallelizable to True
# we are using "auto" device_map, so model_parallel is True or doing DDP
# don't want the trainer to do Data Parallel
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)
model.model_parallel = True
model.is_parallelizable = True

logger.debug(
"The number of trainable parameters in the original model: %s", self.count_trainable_parameters(model)
)
if not adapter_path:
logger.debug("Initializing LoRA adapters from config")
lora_model = self.init_lora_adapters(model, task, config, target_modules=target_modules)
else:
logger.debug("Loading LoRA adapters from %s", adapter_path)
lora_model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
logger.debug(
"The number of trainable parameters in the LoRA model: %s", self.count_trainable_parameters(lora_model)
)
logger.debug("The number of trainable parameters in the LoRA model: %s", count_trainable_parameters(lora_model))
# no need to cast lora modules to model's dtype, we don't do peft.prepare_model_for_kbit_training so the modules
# are already in the same dtype as the model
# casting to dtype is risky since for awq quant linear, it also casts the scales to dtype and but the qlinear
@@ -573,20 +440,6 @@ def get_torch_dtype(torch_dtype: str) -> "torch.dtype":
assert torch_dtype in supported_dtypes, f"torch_dtype must be one of {supported_dtypes} but got {torch_dtype}"
return resolve_torch_dtype(torch_dtype)

@staticmethod
def count_trainable_parameters(model) -> str:
"""Count and return the number of trainable parameters in a model."""
trainable_params = 0
all_param = 0
for param in model.parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
return (
f"trainable params: {trainable_params} || all params: {all_param} "
f"|| trainable%: {100 * trainable_params / all_param:.2f}"
)


class LoRA(LoRABase):
"""Run LoRA fine-tuning on a Hugging Face PyTorch model."""
@@ -863,41 +716,3 @@ def get_quant_model(
)

return new_model_handler, pytorch_model, bnb_quant_config, quantized_modules


DEFAULT_DEEPSPEED_CONFIG = {
"zero_optimization": {
"stage": 3,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": "auto",
"contiguous_gradients": True,
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto",
"offload_param": {
"device": "cpu",
},
"offload_optimizer": {
"device": "cpu",
},
},
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
}
