## From-scratch Quantization

Here we provide the code for quantizing your custom models from scratch. Follow the steps outlined below.

**NOTE:** Only dense-only quantization is supported for now. We are working to release the code for Dense-and-Sparse quantization soon.

### 0. Prerequisites

In addition to the dependencies required for the inference code, you will need to install an additional dependency (scikit-learn) by running the following commands:
```
conda activate sqllm
pip install scikit-learn==1.3.1
```
Additionally, make sure you have your own LLaMA Hugging Face checkpoint saved at `[MODEL_PATH]`.

### 1. Compute gradients (Fisher-based sensitivity score)

SqueezeLLM employs the Fisher Information matrix as a sensitivity metric.
To compute it, we provide a [separate framework](https://github.com/kssteven418/SqueezeLLM-gradients) for computing the squared gradients of your target model.
The framework produces the squared gradients in the same format as the original Hugging Face checkpoint of your target model, with the only difference being that the weight values are replaced by the squared gradients.
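
For intuition, the diagonal Fisher information is approximated by the squared gradients of the loss, accumulated over calibration batches. Below is a minimal sketch of that idea, assuming a hypothetical `calibration_loader` that yields tokenized batches; the linked framework above is the supported way to produce these checkpoints.
```
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("[MODEL_PATH]")
model.eval()  # disable dropout for a cleaner estimate

# One accumulator per parameter, same shapes as the model weights.
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}

for batch in calibration_loader:  # hypothetical user-supplied DataLoader
    model.zero_grad()
    # Standard causal LM loss, using the inputs as labels.
    loss = model(**batch, labels=batch["input_ids"]).loss
    loss.backward()
    for n, p in model.named_parameters():
        if p.grad is not None:
            fisher[n] += p.grad.detach() ** 2
```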

### 2. Chunk model weights and gradients

You should now have the model checkpoint at `[MODEL_PATH]` and the gradient checkpoint computed in the previous step at `[GRADIENT_PATH]`.
Our framework requires both checkpoints to be chunked at layer granularity in order to reduce the model loading overhead.
Run the following commands to chunk both your model and gradient checkpoints:
```
python chunk_models.py --model [MODEL_PATH] --output_path [MODEL_CHUNKS_PATH] --model_type llama
python chunk_models.py --model [GRADIENT_PATH] --output_path [GRADIENT_CHUNKS_PATH] --model_type llama
```

This will save the model weights and gradients at layer granularity under `[MODEL_CHUNKS_PATH]` and `[GRADIENT_CHUNKS_PATH]`.
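
Each chunk written by `chunk_models.py` (shown later in this commit) is simply a `torch.save`d dictionary mapping module names to weight tensors, one file per decoder layer, so you can sanity-check the output directly:
```
import torch

# Each layer_{i}.pt chunk is a dict: module name -> weight tensor.
layer0 = torch.load("[MODEL_CHUNKS_PATH]/layer_0.pt")
for name, weight in layer0.items():
    print(name, tuple(weight.shape))
```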

### 3. K-means clustering

Run the following command to perform K-means clustering, which yields the non-uniform quantization look-up table (LUT):
```
python nuq.py --bit 4 --model_type llama --model [MODEL_CHUNKS_PATH] --gradient [GRADIENT_CHUNKS_PATH] --output_folder [LUT_PATH]
```
The `--bit` argument is the bit precision and can be set to either 3 or 4.
The `--model` and `--gradient` arguments should point to the chunked model weights and gradients obtained in the previous step.
The resulting LUT entries will be stored in `[LUT_PATH]/lut`.
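
As a sanity check, each `l{i}.pkl` file under `[LUT_PATH]/lut` maps a module name to a per-row list of `(centroids, labels)` pairs (see `nuq.py` later in this commit), so a quantized row can be reconstructed by indexing the centroids with the labels. The `"q_proj"` module name below is illustrative:
```
import pickle

with open("[LUT_PATH]/lut/l0.pkl", "rb") as f:
    lut = pickle.load(f)

# lut[module_name][row][0] is a (centroids, labels) pair for that row.
centroids, labels = lut["q_proj"][0][0]  # "q_proj" is an illustrative name
dequantized_row = centroids[labels]      # reconstruct row 0 of the weight matrix
```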

To quantize only a specific range of layers, use the `--range` option. For instance, passing `--range 0,10` will only compute LUT entries for layers 0 through 9.

Please note that this process is highly CPU-intensive, so we recommend running it in an environment with many strong CPU cores for faster computation.

### 4. Packing

Finally, use the LUT obtained in the previous step to save your model in a packed format. Run the following command:
```
python pack.py --model [MODEL_PATH] --wbits 4 --folder [LUT_PATH] --save [PACKED_CKPT_PATH]
```
`[MODEL_PATH]` is the original model checkpoint, and `[LUT_PATH]` is the location where the LUT from the previous step is stored.
The packed checkpoint will be saved at `[PACKED_CKPT_PATH]` and can then be used directly by your inference code.
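
For intuition, packing amounts to storing the n-bit LUT indices densely instead of as full-width integers: with 4-bit indices, eight labels fit in one 32-bit word. The sketch below only illustrates the idea; the actual bit layout used by `pack.py` may differ.
```
def pack_row(labels, bits=4):
    # Pack `bits`-wide LUT indices into 32-bit words (illustrative layout:
    # the first label occupies the lowest-order bits of each word).
    per_word = 32 // bits
    words = []
    for i in range(0, len(labels), per_word):
        word = 0
        for j, idx in enumerate(labels[i:i + per_word]):
            word |= int(idx) << (bits * j)
        words.append(word)
    return words

# Eight 4-bit labels -> one 32-bit word.
print(pack_row([3, 15, 0, 7, 9, 1, 12, 4]))
```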
`chunk_models.py` (new file):
import argparse
import os
import torch
from tqdm import tqdm

from squeezellm.model_parse import (
    parse_model,
    get_layers,
    get_modules,
    get_module_names,
    load_model,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    '--output_path', type=str, default=None,
    help='directory to store the chunked model'
)
parser.add_argument(
    '--model', type=str,
    help='model to load'
)
parser.add_argument(
    '--model_type', type=str, default=None,
    help='model type', choices=['llama', 'opt']
)

args = parser.parse_args()
# If the model type is not explicitly given, infer it from the model name.
model_type = args.model_type or parse_model(args.model)

print(f"chunking the model: {args.model} and storing in {args.output_path}")
if not os.path.exists(args.output_path):
    os.makedirs(args.output_path)
model = load_model(args.model, model_type)
layers = get_layers(model, model_type)
for i, layer in tqdm(enumerate(layers), total=len(layers)):
    # Gather the weight tensor of every linear module in this layer and
    # save them together as a single per-layer chunk.
    data = {}
    modules = get_modules(layer, model_type)
    module_names = get_module_names(model_type)

    for lin, name in zip(modules, module_names):
        data[name] = lin.weight.data
    torch.save(data, os.path.join(args.output_path, f"layer_{i}.pt"))
`nuq.py` (new file):
import os
import torch
import pickle
import argparse
import numpy as np
from sklearn.cluster import KMeans

from tqdm import tqdm

from squeezellm.model_parse import parse_model, get_module_names

parser = argparse.ArgumentParser()

parser.add_argument(
    '--model', type=str,
    help='model weights to load', required=True
)
parser.add_argument(
    '--model_type', type=str, default=None,
    help='model type', choices=['llama', 'opt']
)
parser.add_argument(
    '--gradient', type=str,
    help='model gradients to load', required=True
)
parser.add_argument(
    '--bit', type=int, default=3,
    help='bitwidth', choices=[3, 4],
)
parser.add_argument(
    '--range', type=str, default=None,
    help='range of layers to quantize, e.g. "0,10"'
)
parser.add_argument(
    '--output_folder', type=str, required=True,
    help='path to dump the output'
)

if __name__ == "__main__":
    args = parser.parse_args()

    # If the model type is not explicitly given, infer it from the model name.
    model_type = args.model_type or parse_model(args.model)

    lut_folder = f"{args.output_folder}/lut"
    if not os.path.exists(lut_folder):
        os.makedirs(lut_folder)

    if args.range:
        # e.g. "--range 0,10" quantizes layers 0 through 9.
        ranges = [int(r) for r in args.range.split(",")]
        ran = list(range(ranges[0], ranges[1]))
    else:
        # Count the number of layers from the chunk file count in the model
        # folder; do not add or delete files there, or this count breaks.
        nlayers = len(os.listdir(args.model))
        ran = list(range(nlayers))

    print(f"Quantizing layers {ran}")

    for l in ran:
        lut_file_name = f"{lut_folder}/l{l}.pkl"
        print(lut_file_name)

        if os.path.exists(lut_file_name):
            print(f"Skipping layer {l}, file already exists at {lut_file_name}")
            continue

        print(f"Quantizing layer {l}")

        try:
            gradient_layer = torch.load(f"{args.gradient}/layer_{l}.pt")
        except FileNotFoundError:
            raise Exception(f"Needs chunked gradient file at {args.gradient}/layer_{l}.pt")

        try:
            model_layer = torch.load(f"{args.model}/layer_{l}.pt")
        except FileNotFoundError:
            raise Exception(f"Needs chunked model weight file at {args.model}/layer_{l}.pt")

        config_per_layer = {}

        for name in tqdm(get_module_names(model_type)):
            g = gradient_layer[name].float().numpy()

            config_per_row = []
            module_weight = model_layer[name]
            _weights_np = module_weight.numpy()

            n_cluster = 2 ** args.bit

            # Iterate over the rows of the weight matrix, fitting one
            # sensitivity-weighted k-means per row.
            for i in range(module_weight.shape[0]):
                config_per_group = []
                weights_np_temp = _weights_np[i, :]
                weights_np = weights_np_temp.reshape(-1, 1)

                # Use the squared gradients as per-sample weights, masking
                # out entries whose weight is exactly zero.
                weight_mask = weights_np_temp != 0
                sample_weight = g[i, :]
                sample_weight = sample_weight * weight_mask

                # Fall back to uniform weighting if all sensitivities are zero.
                if np.sum(sample_weight) == 0:
                    sample_weight = np.ones_like(sample_weight)

                kmeans = KMeans(
                    n_clusters=n_cluster,
                    random_state=0,
                    n_init="auto",
                    max_iter=50,
                ).fit(
                    weights_np,
                    sample_weight=sample_weight,
                )
                # Store (centroids, labels) for this row; labels fit in int8
                # since there are at most 2**4 clusters.
                config_per_group.append(
                    (kmeans.cluster_centers_.reshape(-1), kmeans.labels_.astype(np.int8))
                )
                config_per_row.append(config_per_group)

            config_per_layer[name] = config_per_row

        # Save the per-layer LUT.
        with open(lut_file_name, "wb") as f:
            print(f"Saving layer lut to {lut_file_name}")
            pickle.dump(config_per_layer, f)