Autoquant v2 initial version (pytorch#1240)
* Autoquant v2 initial version

Summary:
We refactored v1 to benchmark subgraphs of the form (prev_op -> linear -> post_op) in order to get a more accurate timing estimate. One issue is that we now need to account for the batch size of the subgraph, so the batch-size dimension would need to use a symbolic shape, and torch.compile does not seem to support that well right now.
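
To make the timing idea concrete, here is a minimal, hypothetical sketch (not the actual autoquant_v2 internals) of benchmarking one (prev_op -> linear -> post_op) subgraph with torch.utils.benchmark; the module choices, shapes, and batch sizes are illustrative only, and a CUDA device is assumed:

```
import torch
import torch.nn as nn
from torch.utils.benchmark import Timer

# Illustrative (prev_op -> linear -> post_op) subgraph; the real
# subgraphs are extracted from the model's graph.
subgraph = nn.Sequential(
    nn.LayerNorm(4096),      # prev_op
    nn.Linear(4096, 4096),   # linear under consideration
    nn.GELU(),               # post_op
).cuda().half()

def bench(batch_size: int) -> float:
    # Timing depends on the batch size, which is why the subgraph's
    # batch dimension would ideally be symbolic under torch.compile.
    x = torch.randn(batch_size, 4096, device="cuda", dtype=torch.half)
    t = Timer(stmt="subgraph(x)", globals={"subgraph": subgraph, "x": x})
    return t.blocked_autorange().median

for bs in (1, 8, 64):
    print(bs, bench(bs))
```

The measured median shifts with batch size, which is why benchmarking at a fixed calibration batch size can misestimate the latency seen in generation.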

More improvements:
* the current batch size adjustment code is hardcoded for the llama model; we need a way to generalize it
* use the canonicalized subgraph as the cache key, to reduce how often we need to benchmark (see the sketch after this list)
* add accuracy sanity checks
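
As a rough illustration of the cache-key item above (a hypothetical sketch, not the planned implementation), two subgraphs that agree on op types and parameter shapes can share one benchmark result:

```
import torch.nn as nn

def canonical_key(subgraph: nn.Module, batch_size: int) -> tuple:
    # Hypothetical canonical form: op class names plus parameter shapes.
    # Subgraphs that agree on this key should have nearly identical
    # runtime, so one benchmark result can be cached and reused.
    ops = tuple(
        (type(m).__name__,
         tuple(tuple(p.shape) for p in m.parameters(recurse=False)))
        for m in subgraph.modules()
        if not isinstance(m, nn.Sequential)
    )
    return (ops, batch_size)

benchmark_cache: dict[tuple, float] = {}

def cached_bench(subgraph, batch_size, run_benchmark):
    key = canonical_key(subgraph, batch_size)
    if key not in benchmark_cache:
        benchmark_cache[key] = run_benchmark(subgraph, batch_size)
    return benchmark_cache[key]
```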

Test Plan:
Testing with torchao/_models/llama/generate.py

```
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant_v2-int4
```

* tested on llama2 and sam

* ruff

* ruff

* import

* cleanup

* more ruff

* ruff

* ruff format

* rename autoquant v2

* cleanup

* ruff

* move to prototype folder

* remove prototype import

* calibration_seq_length
jerryzh168 authored Nov 21, 2024 · 1 parent ca52cdc · commit 7446433
Showing 5 changed files with 2,086 additions and 18 deletions.
102 changes: 86 additions & 16 deletions torchao/_models/llama/generate.py
@@ -205,28 +205,31 @@ def main(


if quantization:
    from torchao.quantization import (
        quantize_,
        autoquant,
        int8_weight_only,
        int8_dynamic_activation_int8_weight,
        int4_weight_only,
        int8_dynamic_activation_int4_weight,
        fpx_weight_only,
        uintx_weight_only,
        float8_weight_only,
        float8_dynamic_activation_float8_weight,
    )
    from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
    from torchao.quantization.granularity import PerTensor, PerRow
    from torchao.utils import unwrap_tensor_subclass
if "spinquant" in quantization:
from torchao.prototype.spinquant import apply_spinquant
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
elif "int8dq" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
elif "int4wo" in quantization:
if "hqq" in quantization:
use_hqq=True
else:
@@ -246,14 +249,14 @@ def main(
                layout=MarlinQQQLayout(),
            ),
        )
        else:
            from torchao.dtypes import MarlinSparseLayout
            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
    if "fp6" in quantization:
        quantize_(model, fpx_weight_only(3, 2))
    elif "embed-int8wo" in quantization:
        quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
    elif quantization.startswith("awq"):
        from torchao._models._eval import TransformerEvalWrapper
        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
        from torchao.prototype.awq.example import get_calib_dataset
@@ -274,13 +277,13 @@ def main(
            input_prep_func=prepare_inputs_for_model,
            device=device,
        ).run_eval(
            tasks=['wikitext'],
            limit=1,
        )
        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
        use_hqq = "hqq" in quantization
        quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq), is_observed_linear)
    elif "uintx" in quantization:
        # uintx-nbits-group_size, e.g. "uintx-2-64"
        if "hqq" in quantization:
            # uintx-nbits-group_size-hqq
@@ -294,9 +297,9 @@ def main(
        dtype = _NBITS_TO_DTYPE[nbits]
        group_size = int(_quant_args[2])
        quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
    elif "float8wo" in quantization:
        quantize_(model, float8_weight_only())
    elif "float8dq" in quantization:
        granularity = str(quantization.split("-")[-1])
        if granularity == "tensor":
            granularity = PerTensor()
@@ -305,13 +308,79 @@ def main(
        else:
            granularity = PerTensor()
        quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
    elif "autoquant_v2" in quantization:
        from torchao._models._eval import InputRecorder
        from torchao._models.llama.model import prepare_inputs_for_model

        calibration_seq_length = 256
        calibration_limit = 1
        inputs = InputRecorder(
            tokenizer,
            calibration_seq_length,
            prepare_inputs_for_model,
            False,  # pad_calibration_inputs
            model.config.vocab_size,
            device="cuda",
        ).record_inputs(
            ["wikitext"],
            1,
        ).get_inputs()[0].values[0]
        inputs = prepare_inputs_for_model(inputs)
        with torch.device("cuda"):
            model.setup_caches(
                max_batch_size=1, max_seq_length=calibration_seq_length
            )

        if "autoquant_v2-int4" == quantization:
            model = autoquant_v2(model, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
        elif "autoquant_v2-float8" == quantization:
            model = autoquant_v2(model, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
        else:
            model = autoquant_v2(model, manual=True, example_input=inputs)

        print("running generate")
        generate(
            model,
            encode_tokens(tokenizer, prompt, bos=True, device=device),
            max_new_tokens,
            batch_size,
            interactive=False,
            temperature=temperature,
            top_k=top_k,
        )

        print("running finalize autoquant")
        # do autoquantization
        model.finalize_autoquant()
elif "autoquant" in quantization:
from torchao._models._eval import InputRecorder
from torchao._models.llama.model import prepare_inputs_for_model

calibration_seq_length = 256
calibration_limit = 1
inputs = InputRecorder(
tokenizer,
calibration_seq_length,
prepare_inputs_for_model,
False, # pad_calibration_inputs
model.config.vocab_size,
device="cuda"
).record_inputs(
["wikitext"],
1,
).get_inputs()[0].values[0]
inputs = prepare_inputs_for_model(inputs)
with torch.device("cuda"):
model.setup_caches(
max_batch_size=1, max_seq_length=calibration_seq_length
)

if "autoquant-int4" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
elif "autoquant-float8" == quantization:
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
else:
model = autoquant(model, manual=True)
model = autoquant(model, manual=True, example_input=inputs)

generate(
model,
@@ -325,6 +394,7 @@ def main(

        # do autoquantization
        model.finalize_autoquant()

    else:
        if not TORCH_VERSION_AT_LEAST_2_5:
            unwrap_tensor_subclass(model)
@@ -489,7 +559,7 @@ def callback(x):
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
    parser.add_argument('-q', '--quantization', type=str,
        help=(
            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
            + 'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
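
Distilled from the generate.py diff above, the manual autoquant_v2 flow is: wrap the model with an example input, run one real forward pass so the subgraphs get benchmarked, then finalize. A minimal sketch, assuming the prototype accepts a plain module (the toy model below is a stand-in for the llama Transformer the commit actually exercises):

```
import torch
import torch.nn as nn

from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

# Toy stand-in for the llama model used in the diff.
model = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
).cuda().to(torch.bfloat16)
example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)

# manual=True defers kernel selection until finalize_autoquant() is called.
model = autoquant_v2(model, manual=True, example_input=example_input)

# One real forward pass records shapes and benchmarks candidate kernels...
model(example_input)

# ...then each linear is swapped for its best-performing quantized variant.
model.finalize_autoquant()
```
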
4 changes: 3 additions & 1 deletion torchao/_models/sam/README.md
@@ -17,5 +17,7 @@ sh setup.sh

Finally, you can run benchmarks with
```
sh benchmark.sh
```

You can check out the results in results.csv
32 changes: 31 additions & 1 deletion torchao/_models/sam/eval_combo.py
@@ -9,7 +9,14 @@
import time
import resource

import torchao
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    autoquant,
)
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
from torchao.utils import unwrap_tensor_subclass
@@ -336,6 +343,29 @@ def mlp_only(mod, name):
                     mlp_lin2_only)
        if not TORCH_VERSION_AT_LEAST_2_5:
            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

    elif compress is not None and "autoquant_v2" in compress:
        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
        if "autoquant_v2-int4" == compress:
            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
        elif "autoquant_v2-float8" == compress:
            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST)
        else:
            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)

        predictor.model.image_encoder(example_input)
        predictor.model.image_encoder.finalize_autoquant()

    elif compress is not None and "autoquant" in compress:
        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
        if "autoquant-int4" == compress:
            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
        elif "autoquant-float8" == compress:
            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
        else:
            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True)
        predictor.model.image_encoder(example_input)
        predictor.model.image_encoder.finalize_autoquant()
    else:
        assert compress is None, f"Unsupported compress mode {compress}"
