[Models] Implement support for gemma (#219)
* model compiles

* use embedding weight for lm_head

* support gemma style rms

* fix head_dim

* runs but output garbage

* use gelu activation

* fix

* Fix, use `hasattr` to check if head_dim is available

* Bugfix, use the weight_offset at all locations for LlamaRMSNorm

* Bug-workaround, cutlass.rms_norm requires weight_offset of zero

---------

Co-authored-by: Masahiro Masuda <[email protected]>
Lunderberg and masahi authored Feb 22, 2024
1 parent d66880c commit a0e680c
Showing 8 changed files with 147 additions and 31 deletions.
12 changes: 10 additions & 2 deletions mlc_llm/core.py
@@ -644,11 +644,18 @@ def mod_transform_before_build(

if max_seq_len:
num_key_value_heads = config.get_num_key_value_heads()
num_query_heads = config.num_attention_heads // args.num_shards
hidden_size = config.hidden_size // args.num_shards
if hasattr(config, "head_dim"):
head_dim = config.head_dim
else:
head_dim = hidden_size // num_query_heads
# pylint: disable=no-value-for-parameter
mod = fuse_split_rotary_embedding(
config.num_attention_heads // args.num_shards,
num_query_heads,
num_key_value_heads // args.num_shards,
config.hidden_size // args.num_shards,
hidden_size,
head_dim,
config.position_embedding_base,
batched=args.enable_batching,
)(mod)
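The head_dim resolution added above is one of the recurring fixes in this commit: Gemma publishes an explicit head_dim that is not hidden_size // num_attention_heads, so the old derivation breaks. A minimal sketch of the same fallback logic, with illustrative Gemma-7B-like numbers (assumed for this example, not taken from the diff):

class ExampleConfig:  # illustrative stand-in for the model config
    hidden_size = 3072
    num_attention_heads = 16
    head_dim = 256  # Gemma-style: not equal to hidden_size // num_attention_heads

def resolve_head_dim(config, num_shards=1):
    num_query_heads = config.num_attention_heads // num_shards
    hidden_size = config.hidden_size // num_shards
    if hasattr(config, "head_dim"):
        return config.head_dim
    return hidden_size // num_query_heads

print(resolve_head_dim(ExampleConfig()))  # 256, whereas 3072 // 16 would give 192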
@@ -892,6 +899,7 @@ def build_model_from_args(args: argparse.Namespace):
model_generators["llama"] = llama_batched_vllm
model_generators["mistral"] = llama_batched_vllm
model_generators["mixtral"] = llama_batched_vllm
model_generators["gemma"] = llama_batched_vllm

assert args.model_category in model_generators, f"Model {args.model} not supported"

93 changes: 80 additions & 13 deletions mlc_llm/relax_model/llama.py
@@ -17,6 +17,8 @@

@dataclass
class LlamaConfig:
rms_norm_weight_offset = 0.0

def __init__(
self,
dtype="float32",
@@ -96,6 +98,19 @@ def __init__(
self.quantization_scheme = kwargs["quantization_scheme"]


class GemmaConfig(LlamaConfig):
rms_norm_weight_offset = 1.0

head_dim: int

def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)
self.head_dim = kwargs["head_dim"]


class Linear(nn.Module):
def __init__(self, in_features, out_features, dtype: str, bias=True):
self.in_features = in_features
Expand Down Expand Up @@ -133,9 +148,10 @@ def forward(self, x: relax.Expr) -> relax.Var:


class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, dtype, eps=1e-6):
def __init__(self, hidden_size, dtype, eps=1e-6, weight_offset=0.0):
self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight")
self.variance_epsilon = tvm.tir.const(eps, dtype)
self.weight_offset = weight_offset

def forward(self, hidden_states):
from tvm import te, tir
@@ -173,9 +189,12 @@ def f_div_cast_3d(bsz, i, k):
name=x.op.name + "red_temp",
)

return te.compute(
output = te.compute(
x.shape,
lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)),
lambda i, k: f_mul_cast(
weight(k),
f_div_cast_2d(i, k),
),
name="rms_norm",
)
else:
@@ -185,13 +204,41 @@ def f_div_cast_3d(bsz, i, k):
name=x.op.name + "red_temp",
)

return te.compute(
output = te.compute(
x.shape,
lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)),
lambda bsz, i, k: f_mul_cast(
weight(k),
f_div_cast_3d(bsz, i, k),
),
name="rms_norm",
)

return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm")
return output

# Currently, the cutlass.rms_norm assumes that
# `cutlass::rmsnorm` can be used in place of any PrimFunc that
# is named `rms_norm`. As a result, non-zero `weight_offset`
# applied inside the TE kernel definition would produce
# incorrect results. Applying the `weight_offset` outside the
# `nn.emit_te` is required for correct results. (It's also
# preferable for performance, so that the `weight_offset` can
# be preprocessed.)
#
# TODO(Lunderberg): Change the "cutlass.rms_norm" pattern to
# verify the function that it calls.
if self.weight_offset == 0:
rms_weights = self.weight
else:
rms_weights = nn.emit(
self.weight + R.const(self.weight_offset, dtype=self.weight.struct_info.dtype),
name_hint="rms_weights",
)
return nn.emit_te(
f_rms_norm,
hidden_states,
rms_weights,
primfunc_name_hint="rms_norm",
)
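As the comment above explains, the (1 + weight) scaling that Gemma-style RMS norm uses has to stay out of the emitted rms_norm PrimFunc so the cutlass.rms_norm pattern can still substitute it. A small NumPy sketch (not part of the commit) of why folding the offset into the weights beforehand is equivalent:

import numpy as np

def rms_norm(x, weight, eps=1e-6, weight_offset=0.0):
    # Plain RMSNorm kernel: scale the normalized input by (weight + offset).
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * (weight + weight_offset)

x = np.random.randn(2, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)

# Gemma (offset 1.0): pre-offsetting the weights outside the kernel ...
outside = rms_norm(x, w + 1.0)
# ... matches a kernel that applies the offset internally, so the exported
# PrimFunc stays a plain rms_norm and remains cutlass-compatible.
inside = rms_norm(x, w, weight_offset=1.0)
assert np.allclose(outside, inside)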


class LlamaMLP(nn.Module):
@@ -213,6 +260,8 @@ def __init__(self, config: LlamaConfig):
self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False)
self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)

self.act = {"silu": relax.op.nn.silu, "gelu": relax.op.nn.gelu}[config.hidden_act]

def forward(self, x):
if self.combine_matmul:
gate_up_results = nn.emit(
@@ -228,7 +277,7 @@ def forward(self, x):
gate_result = self.gate_proj(x)
up_result = self.up_proj(x)

result = self.down_proj(relax.op.nn.silu(gate_result) * up_result)
result = self.down_proj(self.act(gate_result) * up_result)
return result
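The activation table added in LlamaMLP.__init__ is what lets Gemma reuse this MLP: the gate activation becomes GELU instead of Llama's SiLU, while the gated structure stays the same. A rough NumPy sketch of that structure (the tanh-approximate GELU here is an assumption for illustration, not necessarily the exact kernel variant):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def gelu(x):  # tanh approximation, illustration only
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def gated_mlp(x, w_gate, w_up, w_down, hidden_act="silu"):
    act = {"silu": silu, "gelu": gelu}[hidden_act]
    return (act(x @ w_gate) * (x @ w_up)) @ w_down

x = np.random.randn(1, 4)
w_gate, w_up, w_down = np.random.randn(4, 8), np.random.randn(4, 8), np.random.randn(8, 4)
y_llama = gated_mlp(x, w_gate, w_up, w_down, hidden_act="silu")
y_gemma = gated_mlp(x, w_gate, w_up, w_down, hidden_act="gelu")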


@@ -280,7 +329,12 @@ def __init__(self, config: LlamaConfig):
self.hidden_size = config.hidden_size
self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
self.num_query_heads = config.num_attention_heads // self.num_shards
self.head_dim = self.hidden_size // config.num_attention_heads

if hasattr(config, "head_dim"):
self.head_dim = config.head_dim
else:
self.head_dim = config.hidden_size // config.num_attention_heads

self.position_embedding_base = config.position_embedding_base

self.combine_matmul = config.combine_matmul
@@ -322,7 +376,10 @@ def __init__(self, config: LlamaConfig):
self.v_proj.weight.shard_dim = 0

self.o_proj = Linear(
self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=config.attention_bias
self.head_dim * self.num_query_heads,
self.hidden_size,
dtype=dtype,
bias=config.attention_bias,
)
self.o_proj.weight.shard_dim = 1
self.o_proj.weight.shard_strategy = "shard_o_proj_k"
@@ -598,10 +655,16 @@ def __init__(self, config: LlamaConfig, enable_batching: bool):
self.use_moe = False
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(
config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
config.hidden_size,
dtype=config.dtype,
eps=config.rms_norm_eps,
weight_offset=config.rms_norm_weight_offset,
)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
config.hidden_size,
dtype=config.dtype,
eps=config.rms_norm_eps,
weight_offset=config.rms_norm_weight_offset,
)

def post_self_attn(self, hidden_states, residual):
@@ -1331,6 +1394,11 @@ def quantize(experts, relax_pname):
assert relax_pname.endswith("scales")
return qscale

if hasattr(config, "head_dim"):
head_dim = config.head_dim
else:
head_dim = config.hidden_size // config.num_attention_heads

def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
# Expected to enter this function only for the combined linear matmul weights.
# Other weights are supposed to be loaded in `f_convert_param_bkwd` since
@@ -1365,8 +1433,8 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
"Matmul combination is not turned on, and the function "
"is not expected to be entered"
)

hidden_size = config.hidden_size
head_dim = config.hidden_size // config.num_attention_heads

if "query_key_value_proj" in relax_pname:
q_heads = config.num_attention_heads
@@ -1401,7 +1469,6 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
device = tvm.cpu()
param_list = [None] * param_manager.nparam_to_load

head_dim = config.hidden_size / config.num_attention_heads
inv_freq = 1.0 / (
config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)
)
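head_dim is now resolved once (as an int) from the config before this point, and it feeds the rotary-embedding frequency table here; the earlier in-place division, which produced a float, is dropped in favor of the value resolved above. A NumPy sketch of how that table turns into the cached cos/sin values, following the usual RoPE construction assumed by the surrounding code:

import numpy as np

head_dim = 256                     # illustrative; resolved from the config as above
position_embedding_base = 10000.0  # illustrative default
max_seq_len = 2048

inv_freq = 1.0 / (
    position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)
)
t = np.arange(max_seq_len, dtype="float32")
freqs = np.outer(t, inv_freq)                  # (max_seq_len, head_dim // 2)
emb = np.concatenate([freqs, freqs], axis=-1)  # (max_seq_len, head_dim)
cos_cached, sin_cached = np.cos(emb), np.sin(emb)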
42 changes: 37 additions & 5 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -19,6 +19,7 @@
from .llama import (
LlamaConfig,
MixtralConfig,
GemmaConfig,
Linear,
Embedding,
LlamaRMSNorm,
@@ -492,6 +493,7 @@ def __init__(
kv_type: KVCacheType,
sep_embed: bool = False,
):
self.config = config
self.padding_idx = config.pad_token_id
self.embed_tokens = None

@@ -501,7 +503,12 @@ def __init__(
self.layers = ModuleList(
[LlamaDecoderLayerBatched(config, kv_type) for _ in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps)
self.norm = LlamaRMSNorm(
config.hidden_size,
dtype=config.dtype,
eps=config.rms_norm_eps,
weight_offset=config.rms_norm_weight_offset,
)

def forward(
self,
@@ -519,6 +526,9 @@ def forward(

hidden_states = inputs_embeds

if isinstance(self.config, GemmaConfig):
hidden_states = nn.emit(hidden_states * relax.const(self.config.hidden_size**0.5, dtype="float16"))
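This is the Gemma-specific embedding normalization: the token embeddings are multiplied by sqrt(hidden_size) before entering the first decoder layer, which Llama-family models do not do. A one-line NumPy sketch (hidden_size value illustrative):

import numpy as np

hidden_size = 3072  # illustrative
inputs_embeds = np.random.randn(1, 5, hidden_size).astype("float16")
hidden_states = inputs_embeds * np.float16(hidden_size ** 0.5)  # Gemma only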

new_kvs = ()

for idx, decoder_layer in enumerate(self.layers):
@@ -551,11 +561,19 @@ def __init__(
self.num_shards = config.num_shards
self.cpu_device = cpu_device
self.model = LlamaModel(config, vocab_size_var, kv_type, sep_embed)
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

if isinstance(config, GemmaConfig):
assert self.model.embed_tokens is not None
self.lm_head = lambda hidden: nn.emit(relax.op.linear(hidden, self.model.embed_tokens.weight))
else:
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)
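For Gemma the output head is tied to the input embedding table rather than a separate Linear, which is why the lambda above calls relax.op.linear with embed_tokens.weight. A small NumPy sketch of what that computes (tiny illustrative shapes; the real Gemma vocabulary is far larger):

import numpy as np

vocab_size, hidden_size = 32, 8  # tiny illustrative sizes
embed_weight = np.random.randn(vocab_size, hidden_size).astype(np.float32)
hidden = np.random.randn(1, hidden_size).astype(np.float32)

# relax.op.linear(hidden, W) is hidden @ W.T, so reusing the embedding
# table yields vocabulary logits without a separate lm_head parameter.
logits = hidden @ embed_weight.T  # shape (1, vocab_size)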

############ Rotary embedding constants ############
assert config.hidden_size % config.num_attention_heads == 0
head_dim = config.hidden_size // config.num_attention_heads
if hasattr(config, "head_dim"):
head_dim = config.head_dim
else:
assert config.hidden_size % config.num_attention_heads == 0
head_dim = config.hidden_size // config.num_attention_heads

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
@@ -703,7 +721,11 @@ def get_inputs(
num_blocks = tvm.tir.Var("num_blocks", "int64")

num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
head_size = hidden_size // config.num_attention_heads

if hasattr(config, "head_dim"):
head_size = config.head_dim
else:
head_size = config.hidden_size // config.num_attention_heads

if kv_type == KVCacheType.VLLM:
block_size = VllmAttention.block_size
@@ -1042,6 +1064,16 @@ def get_model(args, hf_config):
build_model_only=args.build_model_only,
quantization_scheme=args.quantization,
)
elif "gemma" in args.model.lower():
config = GemmaConfig(
**hf_config,
dtype=dtype,
max_sequence_length=hf_config["max_position_embeddings"],
position_embedding_base=position_embedding_base,
combine_matmul=True,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
elif "max_sequence_length" in hf_config:
config = LlamaConfig(
**hf_config,
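For reference, the GemmaConfig branch above pulls its fields straight from the checkpoint's HF config. A hedged example of some of the fields it depends on; the values are illustrative Gemma-7B-like numbers, not copied from any particular config.json:

hf_config_example = {
    "model_type": "gemma",
    "hidden_size": 3072,
    "num_attention_heads": 16,
    "num_key_value_heads": 16,
    "head_dim": 256,                  # read by GemmaConfig.__init__
    "hidden_act": "gelu",             # selects the GELU branch in LlamaMLP
    "rms_norm_eps": 1e-6,
    "max_position_embeddings": 8192,  # becomes max_sequence_length above
    "vocab_size": 256000,
}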
4 changes: 1 addition & 3 deletions mlc_llm/transform/fuse_split_rotary_embedding.py
@@ -461,10 +461,8 @@ def apply_rewrite(mod, split_rotary, get_pattern_func):


def fuse_split_rotary_embedding(
num_query_heads, num_kv_heads, hidden_size, position_embedding_base, batched=False
num_query_heads, num_kv_heads, hidden_size, head_dim, position_embedding_base, batched=False
):
head_dim = hidden_size // num_query_heads

@tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding")
def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule:
split_rotary = get_dynamic_split_rotary()
1 change: 1 addition & 0 deletions mlc_llm/utils.py
@@ -27,6 +27,7 @@
"mistral",
"mixtral",
"stablelm_epoch",
"gemma",
]
)

14 changes: 11 additions & 3 deletions serve/mlc_serve/model/base.py
@@ -25,6 +25,7 @@ class ModelArtifactConfig:
num_attention_heads: Optional[int] = None
num_hidden_layers: Optional[int] = None
hidden_size: Optional[int] = None
head_dim: Optional[int] = None

@classmethod
def _from_json(config_cls, json_obj: dict):
@@ -40,14 +41,16 @@ def _from_json(config_cls, json_obj: dict):
class AssetNotFound(Exception):
def __init__(self, asset_path):
self.asset_path = asset_path
super().__init__(f"{self.asset_path} should exist. Did you build with `--enable-batching`?")
super().__init__(
f"{self.asset_path} should exist. Did you build with `--enable-batching`?"
)


def get_model_artifact_config(model_artifact_path):
json_object = {"model_artifact_path": model_artifact_path}
for config_file_name in [
"build_config.json",
"model/mlc-model-config.json"
"model/mlc-model-config.json",
]:
config_file_path = os.path.join(model_artifact_path, config_file_name)
if not os.path.exists(config_file_path):
@@ -59,7 +62,12 @@ def get_model_artifact_config(model_artifact_path):
if not "paged_kv_cache_type" in json_object:
json_object["paged_kv_cache_type"] = "vllm"

return ModelArtifactConfig._from_json(json_object)
config = ModelArtifactConfig._from_json(json_object)

if config.head_dim is None:
config.head_dim = config.hidden_size // config.num_attention_heads

return config
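The head_dim fallback added here keeps older serialized artifacts working: if the stored config has no head_dim, it is reconstructed from hidden_size and num_attention_heads. A trimmed sketch of that behavior (the dataclass below is a stand-in, not the full ModelArtifactConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ArtifactConfigStub:  # stand-in for ModelArtifactConfig
    num_attention_heads: Optional[int] = None
    hidden_size: Optional[int] = None
    head_dim: Optional[int] = None

cfg = ArtifactConfigStub(num_attention_heads=32, hidden_size=4096)  # no stored head_dim
if cfg.head_dim is None:
    cfg.head_dim = cfg.hidden_size // cfg.num_attention_heads
assert cfg.head_dim == 128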


def get_hf_config(model_path: Path) -> AutoConfig: