Fix QKV weight sharding for gemma #222

Merged 1 commit on Feb 27, 2024
5 changes: 1 addition & 4 deletions mlc_llm/core.py
@@ -646,10 +646,7 @@ def mod_transform_before_build(
         num_key_value_heads = config.get_num_key_value_heads()
         num_query_heads = config.num_attention_heads // args.num_shards
         hidden_size = config.hidden_size // args.num_shards
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        else:
-            head_dim = hidden_size // num_query_heads
+        head_dim = config.get_head_dim()
         # pylint: disable=no-value-for-parameter
         mod = fuse_split_rotary_embedding(
             num_query_heads,
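The head_dim passed to fuse_split_rotary_embedding matters because rotary position embeddings rotate each head's query/key vector pairwise across its head_dim entries, so a wrongly derived value builds the wrong frequency table. A minimal NumPy sketch of that dependency (simplified rotate-half RoPE with an assumed base of 10000, not the fused TIR kernel the pass generates):

import numpy as np

# Minimal RoPE sketch: rotates a single head's vector of length head_dim at
# position `pos`. Simplified illustration only, not the fused rotary kernel
# produced by fuse_split_rotary_embedding.
def rotary_embed(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    head_dim = x.shape[-1]
    half = head_dim // 2
    freqs = base ** (-np.arange(half) / half)   # (head_dim // 2,) rotation frequencies
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# For Gemma, head_dim is 256 even though hidden_size // num_heads is 192,
# so deriving it from hidden_size would build the wrong frequency table.
q_head = np.random.randn(256).astype(np.float32)
assert rotary_embed(q_head, pos=5).shape == (256,)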
2 changes: 1 addition & 1 deletion mlc_llm/relax_model/commons.py
@@ -32,7 +32,7 @@ def create_metadata_func(
 def _get_shard_strategies(
     model_config, num_shards: int, param_shape_is_already_sharded: bool
 ) -> Dict[str, tvm.tir.PrimFunc]:
-    head_dim = model_config.hidden_size // model_config.num_attention_heads
+    head_dim = model_config.get_head_dim()
     q_heads = model_config.num_attention_heads
     kv_heads = model_config.get_num_key_value_heads()
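_get_shard_strategies needs head_dim to locate the Q, K, and V row blocks inside the fused QKV weight before splitting them across shards. A simplified NumPy sketch of that row-wise split (illustrative only; the real shard functions are generated as TVM TIR, and the fused layout here is an assumption):

import numpy as np

# Simplified sketch of sharding a fused QKV weight row-wise across
# tensor-parallel workers; not the actual TIR shard functions in commons.py.
def shard_qkv(qkv_weight: np.ndarray, q_heads: int, kv_heads: int,
              head_dim: int, num_shards: int) -> list:
    hidden_in = qkv_weight.shape[1]
    # Rows are laid out as all Q heads, then all K heads, then all V heads.
    w = qkv_weight.reshape(q_heads + 2 * kv_heads, head_dim, hidden_in)
    q, k, v = np.split(w, [q_heads, q_heads + kv_heads], axis=0)
    shards = []
    for i in range(num_shards):
        q_i = q[i * q_heads // num_shards:(i + 1) * q_heads // num_shards]
        k_i = k[i * kv_heads // num_shards:(i + 1) * kv_heads // num_shards]
        v_i = v[i * kv_heads // num_shards:(i + 1) * kv_heads // num_shards]
        shards.append(np.concatenate([q_i, k_i, v_i]).reshape(-1, hidden_in))
    return shards

# With Gemma-7B-like numbers the derived head_dim (3072 // 16 = 192) would
# make the reshape above fail or mis-slice rows; the real head_dim is 256.
w = np.zeros(((16 + 2 * 16) * 256, 3072), dtype=np.float16)
parts = shard_qkv(w, q_heads=16, kv_heads=16, head_dim=256, num_shards=2)
assert parts[0].shape == ((8 + 2 * 8) * 256, 3072)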
18 changes: 8 additions & 10 deletions mlc_llm/relax_model/llama.py
@@ -76,6 +76,9 @@ def get_num_key_value_heads(self):

         return self.num_key_value_heads

+    def get_head_dim(self):
+        return self.hidden_size // self.num_attention_heads
+

 class MixtralConfig(LlamaConfig):
     num_experts_per_tok: int

@@ -110,6 +113,9 @@ def __init__(
         super().__init__(**kwargs)
         self.head_dim = kwargs["head_dim"]

+    def get_head_dim(self):
+        return self.head_dim
+

 class Linear(nn.Module):
     def __init__(self, in_features, out_features, dtype: str, bias=True):
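These two get_head_dim implementations are the heart of the fix: Llama-family configs can derive head_dim from hidden_size, while Gemma's config carries an explicit head_dim that differs from that derivation. A quick check with Gemma-7B-style numbers (taken from the published config, but treat them as illustrative):

# Gemma-7B-style numbers (assumptions for illustration).
hidden_size = 3072
num_attention_heads = 16
head_dim = 256  # explicit in the Gemma config

derived = hidden_size // num_attention_heads
assert derived == 192
assert derived != head_dim  # this mismatch is what broke QKV sharding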
@@ -329,12 +335,7 @@ def __init__(self, config: LlamaConfig):
         self.hidden_size = config.hidden_size
         self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
         self.num_query_heads = config.num_attention_heads // self.num_shards
-
-        if hasattr(config, "head_dim"):
-            self.head_dim = config.head_dim
-        else:
-            self.head_dim = config.hidden_size // config.num_attention_heads
-
+        self.head_dim = config.get_head_dim()
         self.position_embedding_base = config.position_embedding_base

         self.combine_matmul = config.combine_matmul

@@ -1394,10 +1395,7 @@ def quantize(experts, relax_pname):
             assert relax_pname.endswith("scales")
             return qscale

-    if hasattr(config, "head_dim"):
-        head_dim = config.head_dim
-    else:
-        head_dim = config.hidden_size // config.num_attention_heads
+    head_dim = config.get_head_dim()

     def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
         # Expected to enter this function only for the combined linear matmul weights.
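Both of these hunks feed head_dim into per-shard QKV bookkeeping: the attention constructor's head counts and the combined-weight loader both rely on (q_heads + 2 * kv_heads) * head_dim rows per shard. A small arithmetic check with Gemma-7B-style numbers and 2 shards (all values are assumptions for illustration):

# Per-shard attention dimensions, Gemma-7B-style numbers with 2 shards
# (assumptions for illustration).
num_shards = 2
num_attention_heads = 16
num_key_value_heads = 16
head_dim = 256  # explicit; hidden_size // num_attention_heads would give 192

q_heads_per_shard = num_attention_heads // num_shards    # 8
kv_heads_per_shard = num_key_value_heads // num_shards   # 8
qkv_rows_per_shard = (q_heads_per_shard + 2 * kv_heads_per_shard) * head_dim
assert qkv_rows_per_shard == 6144   # with head_dim=192 this would be 4608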
11 changes: 2 additions & 9 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -569,11 +569,7 @@ def __init__(
         self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

         ############ Rotary embedding constants ############
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        else:
-            assert config.hidden_size % config.num_attention_heads == 0
-            head_dim = config.hidden_size // config.num_attention_heads
+        head_dim = config.get_head_dim()

         # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.

@@ -722,10 +718,7 @@ def get_inputs(

     num_key_value_heads = config.get_num_key_value_heads() // config.num_shards

-    if hasattr(config, "head_dim"):
-        head_size = config.head_dim
-    else:
-        head_size = config.hidden_size // config.num_attention_heads
+    head_size = config.get_head_dim()

     if kv_type == KVCacheType.VLLM:
         block_size = VllmAttention.block_size
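In the batched/vLLM path, head_size also sets the KV cache geometry that get_inputs declares, so a derived value would allocate blocks of the wrong width for Gemma. A rough sketch of a paged KV cache block (the layout is an assumption for illustration, not vLLM's exact memory format):

import numpy as np

# Rough sketch of allocating one paged KV cache block; the layout here is an
# assumption for illustration, not vLLM's exact memory format.
def alloc_kv_block(block_size: int, num_kv_heads: int, head_size: int,
                   dtype=np.float16) -> tuple:
    k_block = np.zeros((block_size, num_kv_heads, head_size), dtype=dtype)
    v_block = np.zeros((block_size, num_kv_heads, head_size), dtype=dtype)
    return k_block, v_block

# Gemma-7B-style per-shard numbers (assumed): 8 KV heads, head_size 256.
k, v = alloc_kv_block(block_size=16, num_kv_heads=8, head_size=256)
assert k.shape == (16, 8, 256)  # head_size 192 here would corrupt attention reads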