Add AI Edge Torch TAP project #190

Open · wants to merge 1 commit into main

50 changes: 50 additions & 0 deletions ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -71,6 +71,56 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    return self.lm_head(x)


class ToySingleLayerModelWeightSharing(torch.nn.Module):

  def __init__(self, config: cfg.ModelConfig) -> None:
    super().__init__()
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
    self.lm_head = nn.Linear(
        config.embedding_dim,
        config.vocab_size,
        bias=config.lm_head_use_bias,
    )
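    # Tie the LM head weights to the token embedding weights (weight sharing).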
    self.lm_head.weight.data = self.tok_embedding.weight.data
    self.transformer_block = TransformerBlock(config)
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.max_seq_len,
        dim=int(
            config.attn_config.rotary_percentage * config.attn_config.head_dim
        ),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device('cpu'),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
    )
    self.config = config

  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    x = self.tok_embedding(idx)
    cos, sin = self.rope_cache

    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.max_seq_len]

    x = self.transformer_block(x, (cos, sin), mask, input_pos)
    x = self.final_norm(x)
    res = self.lm_head(x)
    return res


def get_model_config() -> cfg.ModelConfig:
  attn_config = cfg.AttentionConfig(
      num_heads=32,
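
For context, a minimal conversion sketch for the new weight-sharing toy model, not part of this diff. It assumes the same example inputs as the toy model tests further down (a [1, 100] token tensor and a [100] position tensor) and that the converted model exposes an export() method for writing the flatbuffer; file paths are placeholders.

import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model
from ai_edge_torch.generative.quantize import quant_recipes

config = toy_model.get_model_config()
model = toy_model.ToySingleLayerModelWeightSharing(config)
idx = torch.unsqueeze(torch.arange(0, 100), 0)
input_pos = torch.arange(0, 100)

# A generative recipe routes conversion through the AI Edge Quantizer path
# added in ai_edge_torch/lowertools/odml_torch_utils.py below.
quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    model, (idx, input_pos), quant_config=quant_config
)
edge_model.export('/tmp/toy_weight_sharing_int8.tflite')
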
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
@@ -17,7 +17,8 @@
from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import quant_recipe

-_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_ComputePrecision = quantizer.qtyping.ComputePrecision
+_QuantGranularity = quantizer.qtyping.QuantGranularity
_OpName = quantizer.qtyping.TFLOperationName
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
    return quantizer.qtyping.TensorDataType.INT


-def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+def _get_compute_precision_from_mode(
+    mode: quant_attrs.Mode,
+) -> _ComputePrecision:
  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
-    return _OpExecutionMode.DRQ
+    return _ComputePrecision.INTEGER
  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
-    return _OpExecutionMode.WEIGHT_ONLY
+    return _ComputePrecision.FLOAT
  raise ValueError('Unimplemented execution mode')


-def _get_channelwise_from_granularity(
+def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return False
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return True
+  raise ValueError('Unimplemented execution mode')


+def _get_granularity(
    granularity: quant_attrs.Granularity,
) -> bool:
  if granularity == quant_attrs.Granularity.CHANNELWISE:
-    return True
-  elif granularity == quant_attrs.Granularity.NONE:
-    return False
+    return _QuantGranularity.CHANNELWISE
+  if granularity == quant_attrs.Granularity.NONE:
+    return _QuantGranularity.TENSORWISE
  raise ValueError('Unimplemented granularity')


@@ -88,12 +99,13 @@ def _set_quant_config(
          weight_tensor_config=_TensorQuantConfig(
              num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
              symmetric=True,
-              channel_wise=_get_channelwise_from_granularity(
-                  layer_recipe.granularity
-              ),
+              granularity=_get_granularity(layer_recipe.granularity),
              dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
          ),
-          execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+          compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+          explicit_dequantize=_get_explicit_dequant_from_mode(
+              layer_recipe.mode
+          ),
      ),
      algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
  )
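
A small sketch of the new recipe translation behavior, assuming these helpers remain private module-level functions in translate_recipe (the module imported by odml_torch_utils.py further down); it only restates what the diff above encodes.

from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# DYNAMIC_RANGE now maps to integer compute with no explicit dequantize op,
# while WEIGHT_ONLY maps to float compute with an explicit dequantize op.
mode = quant_attrs.Mode.DYNAMIC_RANGE
print(translate_recipe._get_compute_precision_from_mode(mode))  # ComputePrecision.INTEGER
print(translate_recipe._get_explicit_dequant_from_mode(mode))  # False

# Granularity is now expressed with the quantizer enum instead of a bool.
print(translate_recipe._get_granularity(quant_attrs.Granularity.CHANNELWISE))  # CHANNELWISE
print(translate_recipe._get_granularity(quant_attrs.Granularity.NONE))  # TENSORWISE
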
33 changes: 23 additions & 10 deletions ai_edge_torch/generative/test/test_quantize.py
@@ -14,6 +14,7 @@
# ==============================================================================

import ai_edge_torch
+from absl.testing import parameterized
from ai_edge_torch import config
from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
from ai_edge_torch.generative.quantize import quant_recipe
@@ -25,16 +26,15 @@
from ai_edge_torch.generative.quantize.quant_attrs import Mode
from ai_edge_torch.quantize import quant_config
from ai_edge_torch.testing import model_coverage
-from parameterized import parameterized
import torch

from absl.testing import absltest as googletest


-class TestVerifyRecipes(googletest.TestCase):
+class TestVerifyRecipes(parameterized.TestCase):
  """Unit tests that check for model quantization recipes."""

-  @parameterized.expand([
+  @parameterized.parameters([
      (Dtype.FP32, Dtype.FP32),
      (Dtype.INT8, Dtype.INT8),
      (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ def test_verify_invalid_recipes(
    with self.assertRaises(ValueError):
      quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()

-  @parameterized.expand([
+  @parameterized.parameters([
      (
          Dtype.FP32,
          Dtype.INT8,
@@ -88,7 +88,7 @@ def test_verify_valid_recipes(
      ).verify()


-class TestQuantizeConvert(googletest.TestCase):
+class TestQuantizeConvert(parameterized.TestCase):
  """Test conversion with quantization."""

  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
        )
    )

-  @parameterized.expand([
+  @parameterized.parameters([
      (quant_recipes.full_fp16_recipe()),
      (quant_recipes.full_int8_dynamic_recipe()),
      (quant_recipes.full_int8_weight_only_recipe()),
      (_attention_int8_dynamic_recipe()),
      (_feedforward_int8_dynamic_recipe()),
  ])
-  @googletest.skipIf(
-      not config.Config.use_torch_xla,
-      reason="Not working with odml_torch at the moment.",
-  )
  def test_quantize_convert_toy_sizes(self, quant_config):
    config = toy_model.get_model_config()
    pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ def test_quantize_convert_toy_sizes(self, quant_config):
"Quantized model isn't smaller than F32 model.",
)

def test_quantize_convert_toy_weight_sharing(self):
config = toy_model.get_model_config()
pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
idx = torch.unsqueeze(torch.arange(0, 100), 0)
input_pos = torch.arange(0, 100)

quant_config = quant_recipes.full_int8_dynamic_recipe()
quantized_model = ai_edge_torch.convert(
pytorch_model, (idx, input_pos), quant_config=quant_config
)
float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
self.assertLess(
len(quantized_model._tflite_model),
len(float_model._tflite_model),
"Quantized model isn't smaller than F32 model.",
)

def test_quantize_convert_compare_toy(self):
self.skipTest("b/338288901")
config = toy_model_with_kv_cache.get_model_config()
20 changes: 20 additions & 0 deletions ai_edge_torch/lowertools/odml_torch_utils.py
@@ -29,6 +29,7 @@

from tensorflow.compiler.tf2xla.python import xla as tfxla
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
+from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA

MlirBundle = odml_torch.export.MlirLowered

@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
  converter._experimental_enable_composite_direct_lowering = True
  converter.model_origin_framework = "PYTORCH"

  conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+        quant_config.generative_recipe
+    )

  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)

  tflite_model = converter.convert()

+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    tflite_model = translate_recipe.quantize_model(
+        tflite_model, translated_recipe
+    )

  return tflite_model


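As context for the two branches added above, a standalone sketch of applying a generative recipe to an already-converted model, assuming translate_recipe.quantize_model accepts and returns serialized TFLite bytes the way merged_bundle_to_tfl_model uses it; file paths are placeholders.

from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# Translate the generative recipe into an AI Edge Quantizer recipe.
recipe = quant_recipes.full_int8_dynamic_recipe().generative_recipe
ai_edge_recipe = translate_recipe.translate_to_ai_edge_recipe(recipe)

# Quantize a previously converted float flatbuffer post-conversion.
with open('/tmp/toy_float.tflite', 'rb') as f:
  float_tflite = f.read()
quantized_tflite = translate_recipe.quantize_model(float_tflite, ai_edge_recipe)
with open('/tmp/toy_int8.tflite', 'wb') as f:
  f.write(quantized_tflite)
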
2 changes: 1 addition & 1 deletion odmltorch-requirements.txt
@@ -7,7 +7,7 @@ torchaudio==2.4.0+cpu
--pre
tf-nightly>=2.18.0.dev20240722
torch_xla2[odml]>=0.0.1.dev20240801
-ai-edge-quantizer-nightly==0.0.1.dev20240718
+ai-edge-quantizer-nightly
jax[cpu]
scipy
numpy
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ torchaudio==2.4.0+cpu
torch_xla==2.4.0
--pre
tf-nightly>=2.18.0.dev20240722
-ai-edge-quantizer-nightly==0.0.1.dev20240718
+ai-edge-quantizer-nightly
scipy
numpy
tabulate
2 changes: 1 addition & 1 deletion setup.py
@@ -88,6 +88,6 @@
"torch>=2.4.0",
"torch_xla>=2.4.0",
"tf-nightly>=2.18.0.dev20240722",
"ai-edge-quantizer-nightly==0.0.1.dev20240718",
"ai-edge-quantizer-nightly",
],
)