Commit

Update

[ghstack-poisoned]
HDCharles committed Jun 28, 2024
1 parent e04ba6c commit 8325534
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions torchbenchmark/util/backends/torchdynamo.py
@@ -184,18 +184,16 @@ def apply_torchdynamo_args(
     if args.quantization:
         import torchao
         from torchao.quantization import (
-            change_linear_weights_to_int4_woqtensors,
-            change_linear_weights_to_int8_dqtensors,
-            change_linear_weights_to_int8_woqtensors,
+            quantize, int8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight
         )
-        torch._dynamo.epilogue_fusion = False
+        from torchao.utils import unwrap_tensor_subclass
         torch._dynamo.config.automatic_dynamic_shapes = False
         torch._dynamo.config.force_parameter_static_shapes = False
-        torch._inductor.config.force_fuse_int_mm_with_mul = True
-        torch._inductor.config.use_mixed_mm = True
         torch._dynamo.config.cache_size_limit = 10000
         assert "cuda" in model.device
         module, example_inputs = model.get_module()
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.use_mixed_mm = True
         if isinstance(example_inputs, tuple([tuple, list])):
             example_inputs = tuple([
                 x.to(torch.bfloat16)
@@ -209,22 +207,22 @@ def apply_torchdynamo_args(
             module(**example_inputs)
         else:
             module(*example_inputs)
 
         if args.quantization == "int8dynamic":
-            change_linear_weights_to_int8_dqtensors(module)
+            quantize(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         elif args.quantization == "int8weightonly":
-            change_linear_weights_to_int8_woqtensors(module)
+            quantize(module, int8_weight_only(), set_inductor_config=False)
         elif args.quantization == "int4weightonly":
-            change_linear_weights_to_int4_woqtensors(module)
-        elif args.quantization == "autoquant":
-            torchao.autoquant(module, error_on_unseen=False)
+            quantize(module, int4_weight_only(), set_inductor_config=False)
+        if args.quantization == "autoquant":
+            torchao.autoquant(module, error_on_unseen=False, mode=["interpolate", .85], set_inductor_config=False)
             if isinstance(example_inputs, dict):
                 module(**example_inputs)
             else:
                 module(*example_inputs)
             from torchao.quantization.autoquant import AUTOQUANT_CACHE
             assert len(AUTOQUANT_CACHE)>0, f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
 
+        else:
+            unwrap_tensor_subclass(module)
 
     if args.freeze_prepack_weights:
         torch._inductor.config.freezing = True
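
For reference, a minimal sketch of the new-style torchao flow this commit migrates to (assumes a torchao version exposing the quantize API used above; the toy model, shapes, and compile call are illustrative stand-ins for the harness's module and example_inputs, not part of the benchmark code):

    import torch
    from torchao.quantization import quantize, int8_weight_only
    from torchao.utils import unwrap_tensor_subclass

    # Toy stand-ins for the harness's `module` and `example_inputs` (hypothetical).
    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
    example_input = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

    # New-style API: replace each Linear weight with an int8 weight-only
    # quantized tensor subclass. set_inductor_config=False matches the diff,
    # which sets the relevant inductor flags explicitly instead.
    quantize(model, int8_weight_only(), set_inductor_config=False)

    # Unwrap the tensor subclasses so torch.compile can trace the module;
    # this is the else-branch the diff adds for the non-autoquant paths.
    unwrap_tensor_subclass(model)

    compiled = torch.compile(model)
    out = compiled(example_input)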

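The autoquant path works differently: kernels are chosen per layer from observed shapes, so a calibration forward pass must run before the cache check. A sketch under the same assumptions (toy model and input; the mode=["interpolate", .85] argument is taken verbatim from the diff):

    import torch
    import torchao
    from torchao.quantization.autoquant import AUTOQUANT_CACHE

    # Toy stand-ins again (hypothetical model and input).
    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)
    example_input = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

    # Wrap eligible layers with autoquantizable subclasses; error_on_unseen=False
    # skips layers autoquant cannot handle rather than raising.
    torchao.autoquant(model, error_on_unseen=False, mode=["interpolate", .85],
                      set_inductor_config=False)

    # The first real forward pass records shapes and picks a kernel per layer,
    # populating AUTOQUANT_CACHE.
    model(example_input)

    # Mirror the harness's sanity check that something was actually quantized.
    assert len(AUTOQUANT_CACHE) > 0, "found no autoquantizable layers"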