use block size 128 or 64 when possible
masahi committed Feb 1, 2024
Parent: 99af3fb · Commit: a003965
1 changed file: serve/mlc_serve/model/tvm_model.py (14 additions, 8 deletions)
@@ -461,10 +461,23 @@ def init_tvm_model(
 ) -> Tuple[TextGenerator, CacheManager]:
     dev = tvm.device("cuda", 0)
 
+    num_kv_heads = (
+        model_artifact_config.num_key_value_heads // model_artifact_config.num_shards
+    )
+    head_size = (
+        model_artifact_config.hidden_size // model_artifact_config.num_attention_heads
+    )
+
     if model_artifact_config.paged_kv_cache_type == "flash-decoding":
         allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
         copy_blocks_func_name = "tvm.contrib.flash_attn.copy_blocks"
-        block_size = 256
+        # This needs to match with the model definition in llama_batched_vllm.py
+        if head_size <= 64:
+            block_size = 256
+        elif head_size <= 128:
+            block_size = 128
+        else:
+            block_size = 64
     else:
         allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
         copy_blocks_func_name = "tvm.contrib.vllm.copy_blocks"
@@ -475,13 +488,6 @@ def init_tvm_model(
     if model_artifact_config.num_shards > 1:
         model.disco_session.sync_worker_0()
 
-    num_kv_heads = (
-        model_artifact_config.num_key_value_heads // model_artifact_config.num_shards
-    )
-    head_size = (
-        model_artifact_config.hidden_size // model_artifact_config.num_attention_heads
-    )
-
     if engine_config.max_num_batched_tokens > 0:
        LOG.info("Running memory profiling.")
        num_blocks = get_num_cache_blocks(
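For reference, a minimal runnable sketch of the selection logic this commit introduces. Per the in-code comment, the block size must match the model definition in llama_batched_vllm.py, and the old hard-coded 256 presumably only suits head sizes up to 64. The helper name `choose_flash_decoding_block_size` and the example config numbers (`hidden_size=4096`, `num_attention_heads=32`, a Llama-2-7B-style shape) are illustrative assumptions, not identifiers from the commit:

```python
# Sketch of the head_size -> block_size mapping added in this commit.
# The helper name and config values below are illustrative assumptions.

def choose_flash_decoding_block_size(head_size: int) -> int:
    """Pick the paged KV cache block size for the flash-decoding backend.

    Mirrors the mapping in init_tvm_model: larger head sizes get smaller
    block sizes, and the result must agree with the model definition in
    llama_batched_vllm.py.
    """
    if head_size <= 64:
        return 256
    elif head_size <= 128:
        return 128
    return 64


# Hypothetical Llama-2-7B-style shape, not taken from the commit.
hidden_size = 4096
num_attention_heads = 32
head_size = hidden_size // num_attention_heads  # 128 for this config

print(choose_flash_decoding_block_size(head_size))  # 128
print(choose_flash_decoding_block_size(64))         # 256
print(choose_flash_decoding_block_size(160))        # 64
```

Under this mapping, a 128-dim head (typical for Llama-family models) now gets block size 128 rather than 256, and only head sizes above 128 fall back to 64, which is what the commit title means by "use block size 128 or 64 when possible".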
