diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..3f0a1a5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 SqueezeAILab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..68e3a55
--- /dev/null
+++ b/README.md
@@ -0,0 +1,110 @@
+# SqueezeLLM: Dense-and-Sparse Quantization [[Paper](https://arxiv.org/abs/2306.07629)]
+
+![Thumbnail](figs/thumbnail.png)
+
+
+SqueezeLLM is a post-training quantization framework that incorporates a new method called Dense-and-Sparse Quantization to enable efficient LLM serving.
+
+TLDR:
+Deploying LLMs is difficult because of their large memory size. This can be addressed with reduced-precision quantization, but naive quantization methods degrade model performance. We address this with a new Dense-and-Sparse Quantization method.
+Dense-and-Sparse splits each weight matrix into two components: a dense component that can be heavily quantized without affecting model performance, and a sparse component that preserves the sensitive and outlier entries of the weight matrix.
+With this approach, we are able to serve larger models with a smaller memory footprint and the same latency, **yet with higher accuracy and quality**.
+For instance, the Squeeze variant of the Vicuna models can be served within 6 GB of memory and reach 2% higher MMLU than the FP16 baseline, which has a more than 2x larger memory footprint.
+For more details, please check out our [paper](https://arxiv.org/abs/2306.07629).
+
+
+---
+## Installation
+
+1. Create a conda environment
+```
+conda create --name sqllm python=3.9 -y
+conda activate sqllm
+```
+
+2. Clone and install the dependencies
+```
+git clone https://github.com/SqueezeAILab/SqueezeLLM
+cd SqueezeLLM
+pip install -e .
+cd squeezellm
+python setup_cuda.py install
+```
+
+---
+
+## Supported Models
+
+Currently, we support [LLaMA](https://arxiv.org/abs/2302.13971) 7B, 13B, and 30B, as well as the instruction-tuned [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) 7B and 13B.
+For each model, we provide 3-bit and 4-bit quantized checkpoints, with sparsity levels of 0% (dense-only), 0.05%, and 0.45%.
+See our [paper](https://arxiv.org/abs/2306.07629) for more detailed information on these configurations.
+Below are the links to download the models.
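+
+If you prefer to fetch a checkpoint from a script instead of the browser, the snippet below is a minimal sketch assuming the `huggingface_hub` package is available in your environment (it is not listed as a direct dependency of this repository):
+
+```
+from huggingface_hub import hf_hub_download
+
+# Download the 3-bit, dense-only LLaMA-7B checkpoint listed in the table below
+# and print its local path, which can then be passed to llama.py via --load.
+ckpt_path = hf_hub_download(
+    repo_id="squeeze-ai-lab/sq-llama-7b-w3-s0",
+    filename="sq-llama-7b-w3-s0.pt",
+)
+print(ckpt_path)
+```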
+
+### LLaMA
+
+| Model | Bitwidth | Dense-only (0%) |
+| -------- | -------- | -------- |
+| LLaMA-7B | 3 | [sq-llama-7b-w3-s0](https://huggingface.co/squeeze-ai-lab/sq-llama-7b-w3-s0/blob/main/sq-llama-7b-w3-s0.pt) |
+| LLaMA-7B | 4 | [sq-llama-7b-w4-s0](https://huggingface.co/squeeze-ai-lab/sq-llama-7b-w4-s0/blob/main/sq-llama-7b-w4-s0.pt) |
+| LLaMA-13B | 3 | [sq-llama-13b-w3-s0](https://huggingface.co/squeeze-ai-lab/sq-llama-13b-w3-s0/blob/main/sq-llama-13b-w3-s0.pt) |
+| LLaMA-13B | 4 | [sq-llama-13b-w4-s0](https://huggingface.co/squeeze-ai-lab/sq-llama-13b-w4-s0/blob/main/sq-llama-13b-w4-s0.pt) |
+| LLaMA-30B | 3 | sq-llama-30b-w3-s0 (coming soon) |
+| LLaMA-30B | 4 | sq-llama-30b-w4-s0 (coming soon) |
+
+### Vicuna
+
+| Model | Bitwidth | Dense-only (0%) |
+| -------- | -------- | -------- |
+| Vicuna-7B | 3 | sq-vicuna-7b-w3-s0 (coming soon) |
+| Vicuna-7B | 4 | sq-vicuna-7b-w4-s0 (coming soon) |
+| Vicuna-13B | 3 | sq-vicuna-13b-w3-s0 (coming soon) |
+| Vicuna-13B | 4 | sq-vicuna-13b-w4-s0 (coming soon) |
+
+**NOTE:** Sparsity levels of 0.05% and 0.45% are coming soon!
+
+The LLaMA model [license](https://github.com/facebookresearch/llama/blob/main/LICENSE) currently permits research use only. We direct everyone to carefully review the license before using the quantized models.
+Similar to other works on LLaMA, we only release the quantized portions of the model on the [Hugging Face Model Hub](https://huggingface.co/squeeze-ai-lab).
+To successfully run our code, you first need to obtain the original, pre-trained LLaMA model in the Hugging Face-compatible format locally and provide its path in the commands below.
+We have scripts that will substitute the necessary components, but you will need the original model for those scripts to run.
+
+
+### Benchmarking
+
+The following command runs and benchmarks the 3-bit quantized LLaMA-7B model on the C4 dataset. The `--torch_profile` argument can additionally be passed to replicate the runtime results from the paper.
+Download the quantized model (e.g. `sq-llama-7b-w3-s0.pt`) locally from the link above.
+You can follow the same procedure for the other quantized models.
+
+```
+CUDA_VISIBLE_DEVICES=0 python llama.py <path-to-llama-7b-hf> c4 --wbits 3 --load sq-llama-7b-w3-s0.pt --benchmark 128 --check
+```
+
+### Perplexity Evaluation
+
+The following command evaluates perplexity using the 3-bit quantized LLaMA-7B model on the C4 dataset, following the same evaluation methodology as [GPTQ](https://github.com/IST-DASLab/gptq) and [GPTQ-For-LLaMA](https://github.com/qwopqwop200/GPTQ-for-LLaMa/).
+Download the quantized model (e.g. `sq-llama-7b-w3-s0.pt`) locally from the link above.
+You can follow the same procedure for the other quantized models.
+```
+CUDA_VISIBLE_DEVICES=0 python llama.py <path-to-llama-7b-hf> c4 --wbits 3 --load sq-llama-7b-w3-s0.pt --eval
+```
+
+The code was tested on A5000 and A6000 GPUs with CUDA 11.3 and cuDNN 8.2.
+
+---
+## Acknowledgement
+
+This code reuses components from several libraries, including [GPTQ](https://github.com/IST-DASLab/gptq) and [GPTQ-For-LLaMA](https://github.com/qwopqwop200/GPTQ-for-LLaMa/).
+
+---
+
+## Citation
+
+SqueezeLLM has been developed as part of the following paper.
We appreciate it if you would please cite the following paper if you found the library useful for your work: + +``` +@article{kim2023squeezellm, + title={SqueezeLLM: Dense-and-Sparse Quantization}, + author={Kim, Sehoon and Hooper, Coleman and Gholami, Amir and Dong, Zhen and Li, Xiuyu and Shen, Sheng and Mahoney, Michael and Keutzer, Kurt}, + journal={arXiv}, + year={2023} +} +``` diff --git a/figs/thumbnail.png b/figs/thumbnail.png new file mode 100644 index 0000000..a8e46e7 Binary files /dev/null and b/figs/thumbnail.png differ diff --git a/llama.py b/llama.py new file mode 100644 index 0000000..5528f21 --- /dev/null +++ b/llama.py @@ -0,0 +1,273 @@ +import time + +import torch +import torch.nn as nn + +import transformers +from squeezellm.modelutils import * +from squeezellm.quant import * + +import pickle +import json + +def get_llama(model): + import torch + def skip(*args, **kwargs): + pass + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import LlamaForCausalLM + model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') + model.seqlen = 2048 + return model + +@torch.no_grad() +def llama_eval(model, testenc, dev): + print('Evaluating ...') + + testenc = testenc.input_ids + nsamples = testenc.numel() // model.seqlen + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.layers + + model.model.embed_tokens = model.model.embed_tokens.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] + raise ValueError + layers[0] = Catcher(layers[0]) + for i in range(nsamples): + batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) + try: + model(batch) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.embed_tokens = model.model.embed_tokens.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + position_ids = cache['position_ids'] + + for i in range(len(layers)): + print(i) + layer = layers[i].to(dev) + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids = position_ids)[0] + layers[i] = layer.cpu() + del layer + torch.cuda.empty_cache() + inps, outs = outs, inps + + if model.model.norm is not None: + model.model.norm = model.model.norm.to(dev) + model.lm_head = model.lm_head.to(dev) + + testenc = testenc.to(dev) + nlls = [] + for i in range(nsamples): + hidden_states = inps[i].unsqueeze(0) + if model.model.norm is not None: + hidden_states = model.model.norm(hidden_states) + lm_logits = model.lm_head(hidden_states) + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = testenc[ + :, (i * model.seqlen):((i + 1) * model.seqlen) + ][:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + neg_log_likelihood = loss.float() * model.seqlen + nlls.append(neg_log_likelihood) + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) + 
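+    # Perplexity is the exponential of the average per-token negative
+    # log-likelihood: each sample contributed loss * seqlen above, so dividing
+    # the summed NLLs by (nsamples * seqlen) recovers the mean before exp.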
print(ppl.item()) + + model.config.use_cache = use_cache + +# function for loading packed checkpoint +def load_quant(model, checkpoint, wbits): + from transformers import LlamaForCausalLM + model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') + model = model.eval() + layers = find_layers(model) + + for name in ['lm_head']: + if name in layers: + del layers[name] + make_quant_lut(model, layers, wbits) + del layers + + print('Loading model ...') + state_dict = torch.load(checkpoint) + model.load_state_dict(state_dict, strict = False) + model.seqlen = 2048 + print('Done.') + + return model + + +# function for benchmarking runtime +def benchmark(model, input_ids, check=False): + input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) + torch.cuda.synchronize() + + cache = {'past': None} + def clear_past(i): + def tmp(layer, inp, out): + if cache['past']: + cache['past'][i] = None + return tmp + for i, layer in enumerate(model.model.layers): + layer.register_forward_hook(clear_past(i)) + + print('Benchmarking ...') + + if check: + loss = nn.CrossEntropyLoss() + tot = 0. + + def sync(): + if hasattr(model, 'gpus'): + for gpu in model.gpus: + torch.cuda.synchronize(gpu) + else: + torch.cuda.synchronize() + max_memory = 0 + with torch.no_grad(): + attention_mask = torch.ones((1, input_ids.numel()), device=DEV) + times = [] + for i in range(input_ids.numel()): + tick = time.time() + out = model( + input_ids[:, i:i+1], + past_key_values=cache['past'], + attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)) + ) + sync() + times.append(time.time() - tick) + print(i, times[-1]) + max_memory = max(max_memory,torch.cuda.memory_allocated() / 1024 /1024) + if check and i != input_ids.numel() - 1: + tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() + cache['past'] = list(out.past_key_values) + del out + sync() + import numpy as np + print('Median:', np.median(times)) + if check: + print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) + print('max memory(MiB):',max_memory) + +if __name__ == '__main__': + import argparse + from squeezellm.datautils import * + + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, + help='llama model to load' + ) + parser.add_argument( + 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], + help='Which dataset to use for benchmarking.' + ) + parser.add_argument( + '--seed', + type=int, default=0, help='Seed for sampling the calibration data.' + ) + parser.add_argument( + '--wbits', type=int, default=16, choices=[3, 4, 16], + help='#bits to use for quantization; use 16 for evaluating base model.' + ) + parser.add_argument( + '--eval', action='store_true', + help='evaluate quantized model.' + ) + parser.add_argument( + '--save', type=str, default='', + help='Save quantized checkpoint under this name.' + ) + parser.add_argument( + '--load', type=str, default='', + help='Load quantized model.' + ) + parser.add_argument( + '--benchmark', type=int, default=0, + help='Number of tokens to use for benchmarking.' + ) + parser.add_argument( + '--check', action='store_true', + help='Whether to compute perplexity during benchmarking for verification.' + ) + parser.add_argument( + '--nsamples', type=int, default=128, + help='Number of calibration data samples.' + ) + parser.add_argument( + '--torch_profile', action='store_true', + help='Use CUDA profiling tool for timing runs.' 
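+        # Note: --torch_profile only takes effect together with --benchmark,
+        # since the profiler wraps the benchmarking loop below.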
+ ) + + DEV = torch.device('cuda:0') + + args = parser.parse_args() + + if type(args.load) is not str: + args.load = args.load.as_posix() + + if args.load: + model = load_quant(args.model, args.load, args.wbits) + else: + model = get_llama(args.model) + model.eval() + + dataloader, testloader = get_loaders( + args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen + ) + + if args.benchmark: + model = model.to(DEV) + if args.benchmark: + input_ids = next(iter(dataloader))[0][:, :args.benchmark] + + if args.torch_profile: + from torch.profiler import profile, record_function, ProfilerActivity + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + benchmark(model, input_ids, check=args.check) + print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + else: + benchmark(model, input_ids, check=args.check) + + if args.eval: + datasets = ['wikitext2', 'ptb', 'c4'] + for dataset in datasets: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen + ) + llama_eval(model, testloader, DEV) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..87df8cd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "squeezellm" +version = "0.1.0" +description = "Ultra Low-Precision LLM Quantization." +readme = "README.md" +requires-python = ">=3.9" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "accelerate", + "sentencepiece", + "tokenizers>=0.12.1", + "torch", + "transformers>=4.28.0", + "datasets" +] + +[tool.setuptools.packages.find] diff --git a/squeezellm/.DS_Store b/squeezellm/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/squeezellm/.DS_Store differ diff --git a/squeezellm/datautils.py b/squeezellm/datautils.py new file mode 100644 index 0000000..6937a23 --- /dev/null +++ b/squeezellm/datautils.py @@ -0,0 +1,172 @@ +import numpy as np +import torch + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + +def get_wikitext2(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_ptb(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') + testenc = 
tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_c4(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False + ) + valdata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False + ) + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + +def get_ptb_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + +def get_c4_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' + ) + valdata = load_dataset( + 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' + ) + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] 
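+        # (The target clone below masks every position except the last with
+        # -100, the default ignore_index skipped by PyTorch's CrossEntropyLoss.)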
+ tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + +def get_loaders( + name, nsamples=128, seed=0, seqlen=2048, model='' +): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, model) + if 'ptb' in name: + if 'new' in name: + return get_ptb_new(nsamples, seed, seqlen, model) + return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in name: + if 'new' in name: + return get_c4_new(nsamples, seed, seqlen, model) + return get_c4(nsamples, seed, seqlen, model) diff --git a/squeezellm/modelutils.py b/squeezellm/modelutils.py new file mode 100644 index 0000000..108f3d8 --- /dev/null +++ b/squeezellm/modelutils.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn + +#function to find layers in the network (either for packing or for replacement) +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res diff --git a/squeezellm/quant.py b/squeezellm/quant.py new file mode 100644 index 0000000..f720c1b --- /dev/null +++ b/squeezellm/quant.py @@ -0,0 +1,73 @@ +import numpy as np +import torch +import torch.nn as nn +import math +import quant_cuda + +# drop-in layer replacement class +class QuantLinearLUT(nn.Module): + def __init__(self, bits, infeatures, outfeatures, bias): + super().__init__() + if bits not in [3,4]: + raise NotImplementedError("Only 3 and 4 bits is supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + if bias: + self.include_bias = True + self.register_buffer('bias', torch.zeros((outfeatures))) + else: + self.include_bias = False + self.bias = None + + self.register_buffer('lookup_table', torch.zeros((outfeatures, 2**self.bits), dtype=torch.float32)) + + #replacement forward pass + def forward(self, x): + if x.shape[-1] == x.numel(): + outshape = list(x.shape) + if self.bias is not None: + y = self.bias.clone() + outshape[-1] = self.bias.numel() + else: + y = torch.zeros((self.outfeatures), device='cuda', dtype=torch.float32) + outshape[-1] = self.outfeatures + dtype = x.dtype + if self.bits == 3: + x = x.float() + quant_cuda.vecquant3matmul_nuq_perchannel(x, self.qweight, y, self.lookup_table) + elif self.bits == 4: + x = x.float() + quant_cuda.vecquant4matmul_nuq_perchannel(x, self.qweight, y, self.lookup_table) + y = y.to(dtype) + return y.reshape(outshape) + else: + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + out = torch.zeros((x.shape[0], self.outfeatures), device='cuda', dtype=torch.float32) + dtype = x.dtype + if self.bits == 3: + x = x.float() + quant_cuda.vecquant3matmul_nuq_perchannel_batched(x, self.qweight, out, self.lookup_table) + elif self.bits == 4: + x = x.float() + quant_cuda.vecquant4matmul_nuq_perchannel_batched(x, self.qweight, out, self.lookup_table) + out = out.to(dtype) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + +# function to iterate through model 
+# layers and replace them with our LUT-based layer
+def make_quant_lut(module, names, bits, name=''):
+    if isinstance(module, QuantLinearLUT):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            delattr(module, attr)
+            setattr(module, attr, QuantLinearLUT(bits, tmp.in_features, tmp.out_features, tmp.bias is not None))
+    for name1, child in module.named_children():
+        make_quant_lut(child, names, bits, name + '.' + name1 if name != '' else name1)
diff --git a/squeezellm/quant_cuda.cpp b/squeezellm/quant_cuda.cpp
new file mode 100644
index 0000000..dafb1eb
--- /dev/null
+++ b/squeezellm/quant_cuda.cpp
@@ -0,0 +1,56 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+void vecquant3matmul_nuq_perchannel_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+);
+void vecquant4matmul_nuq_perchannel_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+);
+void vecquant3matmul_nuq_perchannel_batched_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+);
+void vecquant4matmul_nuq_perchannel_batched_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+);
+
+void vecquant3matmul_nuq_perchannel(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_nuq_perchannel_cuda(vec, mat, mul, lookup_table);
+}
+void vecquant4matmul_nuq_perchannel(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_nuq_perchannel_cuda(vec, mat, mul, lookup_table);
+}
+void vecquant3matmul_nuq_perchannel_batched(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_nuq_perchannel_batched_cuda(vec, mat, mul, lookup_table);
+}
+void vecquant4matmul_nuq_perchannel_batched(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_nuq_perchannel_batched_cuda(vec, mat, mul, lookup_table);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant3matmul_nuq_perchannel", &vecquant3matmul_nuq_perchannel, "Non-Uniform Vector 3-bit Quantized Matrix Multiplication w/ Per-Channel LUT (CUDA)");
+  m.def("vecquant4matmul_nuq_perchannel", &vecquant4matmul_nuq_perchannel, "Non-Uniform Vector 4-bit Quantized Matrix Multiplication w/ Per-Channel LUT (CUDA)");
+  m.def("vecquant3matmul_nuq_perchannel_batched", &vecquant3matmul_nuq_perchannel_batched, "Batched Non-Uniform Vector 3-bit Quantized Matrix Multiplication w/ Per-Channel LUT (CUDA)");
+  m.def("vecquant4matmul_nuq_perchannel_batched", &vecquant4matmul_nuq_perchannel_batched, "Batched Non-Uniform Vector 4-bit Quantized Matrix Multiplication w/ Per-Channel LUT (CUDA)");
+}
diff --git a/squeezellm/quant_cuda_kernel.cu b/squeezellm/quant_cuda_kernel.cu
new file mode 100644
index 0000000..f09b355
--- /dev/null
+++ b/squeezellm/quant_cuda_kernel.cu
@@ -0,0 +1,491 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// half-tensor
+#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDATensorMethods.cuh>
+
+// atomicAdd for double-precision floating-point numbers on hardware with
+// compute capability < 6.0 from:
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+__device__ double atomicAdd(
+    double* address,
+    double val
+) {
+  unsigned long long int* address_as_ull = (unsigned long long int*)address;
+  unsigned long long int old = *address_as_ull, assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(
+      address_as_ull,
+      assumed,
+      __double_as_longlong(val + __longlong_as_double(assumed))
+    );
+
+  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+  } while (assumed != old);
+
+  return __longlong_as_double(old);
+}
+#endif
+
+const int BLOCKWIDTH = 128;
+const int BLOCKHEIGHT3 = 12;
+const int BLOCKHEIGHT4 = 16;
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+__device__ inline int as_int(int i) {
+  return *reinterpret_cast<int*>(&i);
+}
+
+__global__ void VecQuant3MatMulKernelNUQPerChannel(
+    const float* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ lookup_table,
+    int height,
+    int width
+);
+
+__global__ void VecQuant4MatMulKernelNUQPerChannel(
+    const float* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ lookup_table,
+    int height,
+    int width
+);
+
+__global__ void VecQuant3MatMulKernelNUQPerChannelBatched(
+    const float* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ lookup_table,
+    int height,
+    int width,
+    int batch,
+    int vec_height
+);
+
+__global__ void VecQuant4MatMulKernelNUQPerChannelBatched(
+    const float* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ lookup_table,
+    int height,
+    int width,
+    int batch,
+    int vec_height
+);
+
+// 3-bit matvec kernel (LUT-based)
+void vecquant3matmul_nuq_perchannel_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant3MatMulKernelNUQPerChannel<<<blocks, threads>>>(
+    vec.data_ptr<float>(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    lookup_table.data_ptr<float>(),
+    height, width
+  );
+}
+
+// 4-bit matvec kernel (LUT-based)
+void vecquant4matmul_nuq_perchannel_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant4MatMulKernelNUQPerChannel<<<blocks, threads>>>(
+    vec.data_ptr<float>(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    lookup_table.data_ptr<float>(),
+    height, width
+  );
+}
+
+// 3-bit batched matvec kernel (LUT-based)
+void vecquant3matmul_nuq_perchannel_batched_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant3MatMulKernelNUQPerChannelBatched<<<blocks, threads>>>(
+    vec.data_ptr<float>(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    lookup_table.data_ptr<float>(),
+    height, width, batch, vec_height
+  );
+}
+
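+// Shared launch geometry for the launchers in this file: blockIdx.x walks
+// tiles of packed weight rows (BLOCKHEIGHT3 = 12 and BLOCKHEIGHT4 = 16 int32
+// rows hold 12*32/3 = 16*32/4 = 128 = BLOCKWIDTH unpacked weights), while
+// BLOCKWIDTH * blockIdx.y + threadIdx.x selects the output column; each thread
+// accumulates a partial dot product for one output element and commits it
+// with atomicAdd in the kernels below.
+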
+// 4-bit batched matvec kernel (LUT-based)
+void vecquant4matmul_nuq_perchannel_batched_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant4MatMulKernelNUQPerChannelBatched<<<blocks, threads>>>(
+    vec.data_ptr<float>(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    lookup_table.data_ptr<float>(),
+    height, width, batch, vec_height
+  );
+}
+
+__global__ void VecQuant3MatMulKernelNUQPerChannel(
+    const float* __restrict__ vec,
+    const int* __restrict__ mat,
+    float* __restrict__ mul,
+    const float* __restrict__ lookup_table,
+    int height,
+    int width
+) {
+
+  int row = BLOCKHEIGHT3 * blockIdx.x;
+  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ float blockvec[BLOCKWIDTH];
+  blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x];
+
+  //Modified dequant block
+  __shared__ float deq2[8][BLOCKWIDTH];
+  int off = threadIdx.x;
+  int column_offset = col * 8;
+  for (int val = 0; val < 8; val += 1) {
+    int lut_index = column_offset + val;
+    deq2[val][off] = lookup_table[lut_index];
+  }
+
+  int i = width * row + col;
+  int k = 0;
+
+  float res = 0;
+
+  unsigned int tmp1;
+  unsigned int tmp2;
+  unsigned int tmp;
+
+  __syncthreads();
+
+  while (k < BLOCKWIDTH) {
+    tmp1 = as_unsigned(mat[i]);
+
+    res += deq2[(tmp1 >> 0) & 0x7][off] * blockvec[k + 0];
+    res += deq2[(tmp1 >> 3) & 0x7][off] * blockvec[k + 1];
+    res += deq2[(tmp1 >> 6) & 0x7][off] * blockvec[k + 2];
+    res += deq2[(tmp1 >> 9) & 0x7][off] * blockvec[k + 3];
+    res += deq2[(tmp1 >> 12) & 0x7][off] * blockvec[k + 4];
+    res += deq2[(tmp1 >> 15) & 0x7][off] * blockvec[k + 5];
+    res += deq2[(tmp1 >> 18) & 0x7][off] * blockvec[k + 6];
+    res += deq2[(tmp1 >> 21) & 0x7][off] * blockvec[k + 7];
+    res += deq2[(tmp1 >> 24) & 0x7][off] * blockvec[k + 8];
+    res += deq2[(tmp1 >> 27) & 0x7][off] * blockvec[k + 9];
+
+    i += width;
+    tmp2 = as_unsigned(mat[i]);
+    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
+    tmp2 >>= 1;
+    res += deq2[(tmp >> 0) & 0x7][off] * blockvec[k + 10];
+    k += 11;
+    res += deq2[(tmp2 >> 0) & 0x7][off] * blockvec[k + 0];
+    res += deq2[(tmp2 >> 3) & 0x7][off] * blockvec[k + 1];
+    res += deq2[(tmp2 >> 6) & 0x7][off] * blockvec[k + 2];
+    res += deq2[(tmp2 >> 9) & 0x7][off] * blockvec[k + 3];
+    res += deq2[(tmp2 >> 12) & 0x7][off] * blockvec[k + 4];
+    res += deq2[(tmp2 >> 15) & 0x7][off] * blockvec[k + 5];
+    res += deq2[(tmp2 >> 18) & 0x7][off] * blockvec[k + 6];
+    res += deq2[(tmp2 >> 21) & 0x7][off] * blockvec[k + 7];
+    res += deq2[(tmp2 >> 24) & 0x7][off] * blockvec[k + 8];
+    res += deq2[(tmp2 >> 27) & 0x7][off] * blockvec[k + 9];
+
+    i += width;
+    tmp1 = as_unsigned(mat[i]);
+    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
+    tmp1 >>= 2;
+    res += deq2[(tmp >> 0) & 0x7][off] * blockvec[k + 10];
+    k += 11;
+    res += deq2[(tmp1 >> 0) & 0x7][off] * blockvec[k + 0];
+    res += deq2[(tmp1 >> 3) & 0x7][off] * blockvec[k + 1];
+    res += deq2[(tmp1 >> 6) & 0x7][off] * blockvec[k + 2];
+    res += deq2[(tmp1 >> 9) & 0x7][off] * blockvec[k + 3];
+    res += deq2[(tmp1 >> 12) & 0x7][off] * blockvec[k + 4];
+    res += deq2[(tmp1 >> 15) & 0x7][off] * blockvec[k + 5];
+    res += deq2[(tmp1 >> 18) & 0x7][off] * blockvec[k + 6];
+    res += deq2[(tmp1 >> 21) & 0x7][off] * blockvec[k + 7];
+    res += deq2[(tmp1 >> 24) & 0x7][off] * blockvec[k + 8];
+    res += deq2[(tmp1 >> 27) & 0x7][off] * blockvec[k + 9];
+    i += width;
+    k
+= 10; + } + + atomicAdd(&mul[col], res); +} + +//4-bit per-channel +__global__ void VecQuant4MatMulKernelNUQPerChannel( + const float* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ lookup_table, + int height, + int width +) { + + int row = BLOCKHEIGHT4 * blockIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ float blockvec[BLOCKWIDTH]; + blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]; + + //Modified dequant block + __shared__ float deq2[16][BLOCKWIDTH]; + int off = threadIdx.x; + int column_offset = col * 16; + for (int val = 0; val < 16; val += 1) { + int lut_index = column_offset + val; + deq2[val][off] = lookup_table[lut_index]; + } + + __syncthreads(); + + float res = 0; + int i = width * row + col; + int k = 0; + + unsigned int tmp; + + while (k < BLOCKWIDTH) { + tmp = as_unsigned(mat[i]); + + res += deq2[(tmp >> 0) & 0xf][off] * blockvec[k + 0]; + res += deq2[(tmp >> 4) & 0xf][off] * blockvec[k + 1]; + res += deq2[(tmp >> 8) & 0xf][off] * blockvec[k + 2]; + res += deq2[(tmp >> 12) & 0xf][off] * blockvec[k + 3]; + res += deq2[(tmp >> 16) & 0xf][off] * blockvec[k + 4]; + res += deq2[(tmp >> 20) & 0xf][off] * blockvec[k + 5]; + res += deq2[(tmp >> 24) & 0xf][off] * blockvec[k + 6]; + res += deq2[(tmp >> 28) & 0xf][off] * blockvec[k + 7]; + + i += width; + k += 8; + } + + atomicAdd(&mul[col], res); +} + + +//batched version (3-bit) +__global__ void VecQuant3MatMulKernelNUQPerChannelBatched( + const float* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ lookup_table, + int height, + int width, + int batch, + int vec_height +) { + + int row = BLOCKHEIGHT3 * blockIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ float blockvec[BLOCKWIDTH]; + + __shared__ float deq2[8][BLOCKWIDTH]; + int off = threadIdx.x; + int column_offset = col * 8; + for (int val = 0; val < 8; val += 1) { + int lut_index = column_offset + val; + deq2[val][off] = lookup_table[lut_index]; + } + + int i; + float res; + int k; + + unsigned int tmp1; + unsigned int tmp2; + unsigned int tmp; + + for (int b = 0; b < batch; ++b){ + //initialize vars + i = width * row + col; + res = 0; + k = 0; + + __syncthreads(); + blockvec[threadIdx.x] = vec[b * vec_height + (row / BLOCKHEIGHT3) * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + while (k < BLOCKWIDTH) { + tmp1 = as_unsigned(mat[i]); + + res += deq2[(tmp1 >> 0) & 0x7][off] * blockvec[k + 0]; + res += deq2[(tmp1 >> 3) & 0x7][off] * blockvec[k + 1]; + res += deq2[(tmp1 >> 6) & 0x7][off] * blockvec[k + 2]; + res += deq2[(tmp1 >> 9) & 0x7][off] * blockvec[k + 3]; + res += deq2[(tmp1 >> 12) & 0x7][off] * blockvec[k + 4]; + res += deq2[(tmp1 >> 15) & 0x7][off] * blockvec[k + 5]; + res += deq2[(tmp1 >> 18) & 0x7][off] * blockvec[k + 6]; + res += deq2[(tmp1 >> 21) & 0x7][off] * blockvec[k + 7]; + res += deq2[(tmp1 >> 24) & 0x7][off] * blockvec[k + 8]; + res += deq2[(tmp1 >> 27) & 0x7][off] * blockvec[k + 9]; + + i += width; + tmp2 = as_unsigned(mat[i]); + tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); + tmp2 >>= 1; + res += deq2[(tmp >> 0) & 0x7][off] * blockvec[k + 10]; + k += 11; + res += deq2[(tmp2 >> 0) & 0x7][off] * blockvec[k + 0]; + res += deq2[(tmp2 >> 3) & 0x7][off] * blockvec[k + 1]; + res += deq2[(tmp2 >> 6) & 0x7][off] * blockvec[k + 2]; + res += deq2[(tmp2 >> 9) & 0x7][off] * blockvec[k + 3]; + res += deq2[(tmp2 >> 12) & 0x7][off] * blockvec[k + 4]; + res += deq2[(tmp2 >> 15) & 0x7][off] 
* blockvec[k + 5]; + res += deq2[(tmp2 >> 18) & 0x7][off] * blockvec[k + 6]; + res += deq2[(tmp2 >> 21) & 0x7][off] * blockvec[k + 7]; + res += deq2[(tmp2 >> 24) & 0x7][off] * blockvec[k + 8]; + res += deq2[(tmp2 >> 27) & 0x7][off] * blockvec[k + 9]; + + i += width; + tmp1 = as_unsigned(mat[i]); + tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); + tmp1 >>= 2; + res += deq2[(tmp >> 0) & 0x7][off] * blockvec[k + 10]; + k += 11; + res += deq2[(tmp1 >> 0) & 0x7][off] * blockvec[k + 0]; + res += deq2[(tmp1 >> 3) & 0x7][off] * blockvec[k + 1]; + res += deq2[(tmp1 >> 6) & 0x7][off] * blockvec[k + 2]; + res += deq2[(tmp1 >> 9) & 0x7][off] * blockvec[k + 3]; + res += deq2[(tmp1 >> 12) & 0x7][off] * blockvec[k + 4]; + res += deq2[(tmp1 >> 15) & 0x7][off] * blockvec[k + 5]; + res += deq2[(tmp1 >> 18) & 0x7][off] * blockvec[k + 6]; + res += deq2[(tmp1 >> 21) & 0x7][off] * blockvec[k + 7]; + res += deq2[(tmp1 >> 24) & 0x7][off] * blockvec[k + 8]; + res += deq2[(tmp1 >> 27) & 0x7][off] * blockvec[k + 9]; + i += width; + k += 10; + } + + atomicAdd(&mul[b * width + col], res); + } +} + +//batched version (4-bit) +__global__ void VecQuant4MatMulKernelNUQPerChannelBatched( + const float* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ lookup_table, + int height, + int width, + int batch, + int vec_height +) { + + int row = BLOCKHEIGHT4 * blockIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + __shared__ float blockvec[BLOCKWIDTH]; + + //Modified dequant block + __shared__ float deq2[16][BLOCKWIDTH]; + int off = threadIdx.x; + int column_offset = col * 16; + for (int val = 0; val < 16; val += 1) { + int lut_index = column_offset + (val & 0xf); + deq2[val][off] = lookup_table[lut_index]; + } + + int i; + float res; + int k; + unsigned int tmp; + + for (int b = 0; b < batch; ++b){ + i = width * row + col; + res = 0; + k = 0; + + __syncthreads(); + blockvec[threadIdx.x] = vec[b * vec_height + (row / BLOCKHEIGHT4) * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + + while (k < BLOCKWIDTH) { + tmp = as_unsigned(mat[i]); + + res += deq2[(tmp >> 0) & 0xf][off] * blockvec[k + 0]; + res += deq2[(tmp >> 4) & 0xf][off] * blockvec[k + 1]; + res += deq2[(tmp >> 8) & 0xf][off] * blockvec[k + 2]; + res += deq2[(tmp >> 12) & 0xf][off] * blockvec[k + 3]; + res += deq2[(tmp >> 16) & 0xf][off] * blockvec[k + 4]; + res += deq2[(tmp >> 20) & 0xf][off] * blockvec[k + 5]; + res += deq2[(tmp >> 24) & 0xf][off] * blockvec[k + 6]; + res += deq2[(tmp >> 28) & 0xf][off] * blockvec[k + 7]; + + i += width; + k += 8; + } + + atomicAdd(&mul[b * width + col], res); + } +} diff --git a/squeezellm/setup_cuda.py b/squeezellm/setup_cuda.py new file mode 100644 index 0000000..6f05634 --- /dev/null +++ b/squeezellm/setup_cuda.py @@ -0,0 +1,10 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup( + name='quant_cuda', + ext_modules=[cpp_extension.CUDAExtension( + 'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu'] + )], + cmdclass={'build_ext': cpp_extension.BuildExtension} +)