
Commit

Fix QKV weight sharding for gemma (#222)
Fix gemma multi-gpu
masahi authored Feb 27, 2024
1 parent 941320c commit 0dfb756
Showing 4 changed files with 12 additions and 24 deletions.
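
Before this change, each call site carried its own hasattr(config, "head_dim") fallback, and the shard-strategy helper in commons.py derived head_dim as hidden_size // num_attention_heads unconditionally. For Gemma (notably the 7B config, where head_dim is 256 while hidden_size // num_attention_heads gives 192) that derived value breaks QKV weight sharding on multi-GPU runs. The diff consolidates the logic into a single config.get_head_dim() accessor that GemmaConfig overrides. A minimal sketch of the pattern, using illustrative stand-in classes and values rather than the repository's LlamaConfig/GemmaConfig:

class LlamaLikeConfig:
    hidden_size = 4096
    num_attention_heads = 32

    def get_head_dim(self):
        # Llama-style models imply head_dim from the hidden size.
        return self.hidden_size // self.num_attention_heads  # 128


class GemmaLikeConfig(LlamaLikeConfig):
    hidden_size = 3072
    num_attention_heads = 16
    head_dim = 256  # explicit: 16 * 256 != hidden_size

    def get_head_dim(self):
        # Gemma must report its explicit head_dim; deriving it would yield 192.
        return self.head_dim


for cfg in (LlamaLikeConfig(), GemmaLikeConfig()):
    derived = cfg.hidden_size // cfg.num_attention_heads
    print(type(cfg).__name__, "derived:", derived, "get_head_dim():", cfg.get_head_dim())
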
5 changes: 1 addition & 4 deletions mlc_llm/core.py
@@ -646,10 +646,7 @@ def mod_transform_before_build(
             num_key_value_heads = config.get_num_key_value_heads()
             num_query_heads = config.num_attention_heads // args.num_shards
             hidden_size = config.hidden_size // args.num_shards
-            if hasattr(config, "head_dim"):
-                head_dim = config.head_dim
-            else:
-                head_dim = hidden_size // num_query_heads
+            head_dim = config.get_head_dim()
             # pylint: disable=no-value-for-parameter
             mod = fuse_split_rotary_embedding(
                 num_query_heads,
2 changes: 1 addition & 1 deletion mlc_llm/relax_model/commons.py
@@ -32,7 +32,7 @@ def create_metadata_func(
 def _get_shard_strategies(
     model_config, num_shards: int, param_shape_is_already_sharded: bool
 ) -> Dict[str, tvm.tir.PrimFunc]:
-    head_dim = model_config.hidden_size // model_config.num_attention_heads
+    head_dim = model_config.get_head_dim()
     q_heads = model_config.num_attention_heads
     kv_heads = model_config.get_num_key_value_heads()

18 changes: 8 additions & 10 deletions mlc_llm/relax_model/llama.py
@@ -76,6 +76,9 @@ def get_num_key_value_heads(self):

         return self.num_key_value_heads

+    def get_head_dim(self):
+        return self.hidden_size // self.num_attention_heads
+

 class MixtralConfig(LlamaConfig):
     num_experts_per_tok: int
@@ -110,6 +113,9 @@ def __init__(
         super().__init__(**kwargs)
         self.head_dim = kwargs["head_dim"]

+    def get_head_dim(self):
+        return self.head_dim
+

 class Linear(nn.Module):
     def __init__(self, in_features, out_features, dtype: str, bias=True):
@@ -329,12 +335,7 @@ def __init__(self, config: LlamaConfig):
         self.hidden_size = config.hidden_size
         self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
         self.num_query_heads = config.num_attention_heads // self.num_shards
-
-        if hasattr(config, "head_dim"):
-            self.head_dim = config.head_dim
-        else:
-            self.head_dim = config.hidden_size // config.num_attention_heads
-
+        self.head_dim = config.get_head_dim()
         self.position_embedding_base = config.position_embedding_base

         self.combine_matmul = config.combine_matmul
@@ -1394,10 +1395,7 @@ def quantize(experts, relax_pname):
         assert relax_pname.endswith("scales")
         return qscale

-    if hasattr(config, "head_dim"):
-        head_dim = config.head_dim
-    else:
-        head_dim = config.hidden_size // config.num_attention_heads
+    head_dim = config.get_head_dim()

    def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
        # Expected to enter this function only for the combined linear matmul weights.
11 changes: 2 additions & 9 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -569,11 +569,7 @@ def __init__(
         self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

         ############ Rotary embedding constants ############
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        else:
-            assert config.hidden_size % config.num_attention_heads == 0
-            head_dim = config.hidden_size // config.num_attention_heads
+        head_dim = config.get_head_dim()

         # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.
@@ -722,10 +718,7 @@ def get_inputs(

     num_key_value_heads = config.get_num_key_value_heads() // config.num_shards

-    if hasattr(config, "head_dim"):
-        head_size = config.head_dim
-    else:
-        head_size = config.hidden_size // config.num_attention_heads
+    head_size = config.get_head_dim()

     if kv_type == KVCacheType.VLLM:
         block_size = VllmAttention.block_size
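
The shard strategy in commons.py slices the fused QKV projection weight by whole heads, so the head width it uses has to be the model's real head_dim. A hypothetical NumPy sketch of that row arithmetic, assuming Gemma-7B-like sizes (16 query and key/value heads, head_dim 256, hidden_size 3072) and a Q-then-K-then-V row layout; the repository expresses the actual split as tvm.tir.PrimFuncs returned by _get_shard_strategies:

import numpy as np

num_shards = 2
num_heads, hidden_size = 16, 3072   # Gemma 7B uses equal Q and KV head counts
head_dim = 256                      # Gemma's explicit head_dim
derived = hidden_size // num_heads  # 192: the value the old commons.py derived

# Fused QKV projection weight: one block of head_dim rows per Q, K and V head.
qkv_weight = np.zeros((3 * num_heads * head_dim, hidden_size), dtype="float16")

# View as (q/k/v, head, head_dim, hidden) and give each shard a contiguous slice of heads.
blocks = qkv_weight.reshape(3, num_heads, head_dim, hidden_size)
shards = np.split(blocks, num_shards, axis=1)
assert shards[0].shape == (3, num_heads // num_shards, head_dim, hidden_size)

# With the derived 192, per-head blocks would not even tile the rows evenly,
# so shard boundaries would land in the middle of some head's rows.
assert (3 * num_heads * head_dim) % (3 * num_heads * derived) != 0
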
