Quantization code added (#40)
* quantization code added

* minor fix
kssteven418 authored Sep 30, 2023
1 parent d42d7cb commit 7e11ad9
Showing 9 changed files with 532 additions and 10 deletions.
16 changes: 6 additions & 10 deletions README.md
@@ -12,16 +12,7 @@ With this approach, we are able to serve larger models with smaller memory footprint
For instance, the Squeeze variant of the Vicuna models can be served within 6 GB of memory and reach 2% higher MMLU than the FP16 baseline, even though the baseline has a 2x larger memory footprint.
For more details please check out our [paper](https://arxiv.org/abs/2306.07629).


**Updates (7/25):** LLaMA-2 7B and 13B are uploaded.

**Updates (7/21):** Vicuna v1.3 7B and 13B are uploaded.

**Updates (7/10):** All models other than LLaMA and Vicuna v1.1 can be run and evaluated without downloading the original checkpoints.

**Updates (7/5):** Salesforce's XGen models (both [Base](https://huggingface.co/Salesforce/xgen-7b-8k-base) and [Inst](https://huggingface.co/Salesforce/xgen-7b-8k-inst)) with 8k sequence length and OPT models are supported.


**Updates (9/30):** The code for quantizing custom models is now available ([link](https://github.com/kssteven418/squeezellm-private/tree/sk/working#from-scratch-quantization)).

---
## Installation
@@ -43,6 +34,11 @@ python setup_cuda.py install

---

## From-scratch Quantization

To quantize your own models, follow the procedure described at this [link](https://github.com/kssteven418/squeezellm-private/tree/sk/working/quantization).


## Supported Models

Currently, we support [LLaMA](https://arxiv.org/abs/2302.13971) 7B, 13B, 30B and 65B, [LLaMA-2](https://arxiv.org/abs/2307.09288) 7B and 13B, instruction-tuned [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) 7B and 13B, [XGen](https://blog.salesforceairesearch.com/xgen/) 7B with 8k sequence length, and OPT 1.3B to 30B.
Binary file modified figs/thumbnail.png
52 changes: 52 additions & 0 deletions quantization/README.md
@@ -0,0 +1,52 @@
## From-scratch Quantization

Here we provide the code for quantizing your custom models from scratch. Follow the steps outlined below.

**NOTE:** Only dense-only quantization is currently supported. We are working to release the code for Dense-and-Sparse quantization soon.

### 0. Prerequisite

In addition to the dependencies required for the inference code, you will need to install scikit-learn by running the following commands:
```
conda activate sqllm
pip install scikit-learn==1.3.1
```
Additionally, make sure you have your own LLaMA Huggingface checkpoint saved at `[MODEL_PATH]`.


### 1. Compute gradients (Fisher-based sensitivity score)
SqueezeLLM employs the Fisher Information matrix as a sensitivity metric.
To compute this, we provide a [separate framework](https://github.com/kssteven418/SqueezeLLM-gradients) that computes the squared gradients for your target model.
This framework produces the squared gradients in the same format as the original Huggingface model checkpoint, with the only difference being that the weight values are replaced by their squared gradients.
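
Below is a minimal, illustrative sketch of what such a gradient checkpoint contains: squared gradients (a diagonal Fisher approximation) accumulated over a few calibration samples and written back in the original Hugging Face layout. The model path, calibration text, and loss setup here are placeholders, and the linked framework remains the supported way to produce these files.

```
# Illustrative sketch only -- the linked SqueezeLLM-gradients framework is the
# supported way to produce the gradient checkpoint. Paths and the calibration
# text below are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("[MODEL_PATH]", torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained("[MODEL_PATH]")
model.train()

# accumulate squared gradients (diagonal Fisher approximation) over calibration samples
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
calibration_texts = ["placeholder calibration sentence"]
for text in calibration_texts:
    batch = tokenizer(text, return_tensors="pt")
    loss = model(**batch, labels=batch["input_ids"]).loss
    model.zero_grad()
    loss.backward()
    for n, p in model.named_parameters():
        if p.grad is not None:
            fisher[n] += p.grad.detach() ** 2

# overwrite each weight with its squared gradient so the saved checkpoint keeps
# the original Hugging Face layout, as described above
with torch.no_grad():
    for n, p in model.named_parameters():
        p.copy_(fisher[n])
model.save_pretrained("[GRADIENT_PATH]")
```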

### 2. Chunk model weights and gradients
You should now have the model checkpoint at `[MODEL_PATH]` and the gradient checkpoint computed in the previous step at `[GRADIENT_PATH]`.
Our framework requires both checkpoints to be chunked at layer granularity in order to reduce the model loading overhead.
Run the following code to chunk both your model and gradient checkpoints:
```
python chunk_models.py --model [MODEL_PATH] --output [MODEL_CHUNKS_PATH] --model_type llama
python chunk_models.py --model [GRADIENT_PATH] --output [GRADIENT_CHUNKS_PATH] --model_type llama
```

This will save the model weights and gradients at layer granularity under `[MODEL_CHUNKS_PATH]` and `[GRADIENT_CHUNKS_PATH]`, respectively.
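
As an optional sanity check, here is a small sketch (assuming the layout produced by the chunking script above) that loads one chunk, which is a dictionary mapping module names to weight tensors:

```
# Sketch: inspect one chunked layer file written by chunk_models.py.
import torch

chunk = torch.load("[MODEL_CHUNKS_PATH]/layer_0.pt")  # dict: module name -> weight tensor
for name, weight in chunk.items():
    print(name, tuple(weight.shape))
```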

### 3. K-means clustering
Run the following code to perform K-means clustering, which will yield the non-uniform quantization look-up table (LUT):
```
python nuq.py --bit 4 --model_type llama --model [MODEL_CHUNKS_PATH] --gradient [GRADIENT_CHUNKS_PATH] --output [LUT_PATH]
```
The `--bit` argument sets the bit precision and can be either 3 or 4.
The `--model` and `--gradient` arguments should point to the chunked model weights and gradients obtained in the previous step.
The resulting LUT entries will be stored in `[LUT_PATH]/lut`.

To only quantize a specific range of layers, you can use the `--range` option. For instance, assigning `--range 0,10` will only compute LUT entries for layers 0 to 9.

Please note that this process is highly CPU-intensive, so we recommend running it on machines with many fast CPU cores for faster computation.
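
For reference, each resulting `l{i}.pkl` file produced by `nuq.py` maps a module name to a list of per-row entries, where each entry holds the `2**bit` centroids and the per-weight labels. The following sketch (paths are placeholders) reads one file and reconstructs a quantized row by nearest-centroid lookup:

```
# Sketch: read one LUT file produced by nuq.py and reconstruct a quantized row.
import pickle
import numpy as np

with open("[LUT_PATH]/lut/l0.pkl", "rb") as f:
    lut = pickle.load(f)  # dict: module name -> list of per-row entries

name = next(iter(lut))               # first module in this layer
centroids, labels = lut[name][0][0]  # row 0: (2**bit centroids, per-weight labels)
row = centroids[np.asarray(labels, dtype=np.int64)]  # nearest-centroid reconstruction
print(name, centroids.shape, row.shape)
```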

### 4. Packing
Finally, use the LUT obtained in the previous step to save your model in a packed format. Run the following command:
```
python pack.py --model [MODEL_PATH] --wbits 4 --folder [LUT_PATH] --save [PACKED_CKPT_PATH]
```
`[MODEL_PATH]` is the original model checkpoint, and `[LUT_PATH]` is the location where the LUT is stored from the previous step.
The packed checkpoint will be saved at `[PACKED_CKPT_PATH]`, which can now be immediately used in your inference code.
Binary file not shown.
46 changes: 46 additions & 0 deletions quantization/chunk_models.py
@@ -0,0 +1,46 @@
import argparse
import os
import torch
from tqdm import tqdm

from squeezellm.model_parse import (
    parse_model,
    get_layers,
    get_modules,
    get_module_names,
    load_model,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    '--output_path', type=str, default=None,
    help="path to store the chunked model"
)
parser.add_argument(
    '--model', type=str,
    help='model to load'
)
parser.add_argument(
    '--model_type', type=str, default=None,
    help='model type', choices=['llama', 'opt']
)

args = parser.parse_args()
# if model type is not explicitly given, infer from the model name
model_type = args.model_type or parse_model(args.model)

# Chunk the model and store it layer by layer under '--output_path'.
print(f"chunking the model: {args.model} and storing in {args.output_path}")
if not os.path.exists(args.output_path):
    os.makedirs(args.output_path)

model = load_model(args.model, model_type)
layers = get_layers(model, model_type)
for i, layer in tqdm(enumerate(layers)):
    data = {}
    modules = get_modules(layer, model_type)
    module_names = get_module_names(model_type)

    for lin, name in zip(modules, module_names):
        data[name] = lin.weight.data
    torch.save(data, os.path.join(args.output_path, f"layer_{i}.pt"))
130 changes: 130 additions & 0 deletions quantization/nuq.py
@@ -0,0 +1,130 @@
import os
import torch
import pickle
import argparse
import numpy as np
from sklearn.cluster import KMeans

from tqdm import tqdm
from transformers import LlamaForCausalLM

from squeezellm.model_parse import parse_model, get_module_names

parser = argparse.ArgumentParser()

parser.add_argument(
    '--model', type=str,
    help='model weights to load', required=True
)
parser.add_argument(
    '--model_type', type=str, default=None,
    help='model type', choices=['llama', 'opt']
)
parser.add_argument(
    '--gradient', type=str,
    help='model gradients to load', required=True
)
parser.add_argument(
    '--bit', type=int, default=3,
    help='bitwidth', choices=[3, 4],
)
parser.add_argument(
    '--range', type=str, default=None,
    help='range of layers to quantize'
)
parser.add_argument(
    '--output_folder', type=str, required=True,
    help='path to dump the output'
)

if __name__ == "__main__":
    args = parser.parse_args()

    # if model type is not explicitly given, infer from the model name
    model_type = args.model_type or parse_model(args.model)

    lut_folder = f"{args.output_folder}/lut"
    if not os.path.exists(lut_folder):
        os.makedirs(lut_folder)

    if args.range:
        ranges = args.range.split(",")
        ranges = [int(r) for r in ranges]
        ran = list(range(ranges[0], ranges[1]))
    else:
        # Count the number of layers based on the chunk item count in the model folder.
        # You should not add/delete anything in the folder to make this work.
        nlayers = len(os.listdir(args.model))
        ran = list(range(nlayers))

    print(f"Quantizing layers {ran}")

    for l in ran:
        lut_file_name = f"{lut_folder}/l{l}.pkl"

        if os.path.exists(lut_file_name):
            print(f"Skipping layer {l}, file already exists at {lut_file_name}")
            continue

        print(f"Quantizing layer {l}")

        try:
            gradient_layer = torch.load(f"{args.gradient}/layer_{l}.pt")
        except Exception as e:
            raise Exception(
                f"Needs chunked gradient file at {args.gradient}/layer_{l}.pt"
            ) from e

        try:
            model_layer = torch.load(f"{args.model}/layer_{l}.pt")
        except Exception as e:
            raise Exception(
                f"Needs chunked model weight file at {args.model}/layer_{l}.pt"
            ) from e

        config_per_layer = {}

        for name in tqdm(get_module_names(model_type)):
            g = gradient_layer[name].float().numpy()

            config_per_row = []
            module_weight = model_layer[name]
            _weights_np = module_weight.numpy()

            n_cluster = 2 ** args.bit

            # iterate over rows
            for i in range(module_weight.shape[0]):
                config_per_group = []
                weights_np_temp = _weights_np[i, :]
                weights_np = weights_np_temp.reshape(-1, 1)

                weight_mask = weights_np_temp != 0
                sample_weight = g[i, :]
                sample_weight = sample_weight * weight_mask

                if np.sum(sample_weight) == 0:
                    sample_weight = np.ones_like(sample_weight)

                kmeans = KMeans(
                    n_clusters=n_cluster,
                    random_state=0,
                    n_init="auto",
                    max_iter=50,
                ).fit(
                    weights_np,
                    sample_weight=sample_weight,
                )
                config_per_group.append(
                    (kmeans.cluster_centers_.reshape(-1), kmeans.labels_.astype(np.int8))
                )
                config_per_row.append(config_per_group)

            config_per_layer[name] = config_per_row

        # save the per-layer LUT
        with open(lut_file_name, "wb") as f:
            print(f"Saving layer lut to {lut_file_name}")
            pickle.dump(config_per_layer, f)
