diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index 08ba75f7e3..f1950a7417 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -461,10 +461,23 @@ def init_tvm_model(
 ) -> Tuple[TextGenerator, CacheManager]:
     dev = tvm.device("cuda", 0)
 
+    num_kv_heads = (
+        model_artifact_config.num_key_value_heads // model_artifact_config.num_shards
+    )
+    head_size = (
+        model_artifact_config.hidden_size // model_artifact_config.num_attention_heads
+    )
+
     if model_artifact_config.paged_kv_cache_type == "flash-decoding":
         allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
         copy_blocks_func_name = "tvm.contrib.flash_attn.copy_blocks"
-        block_size = 256
+        # This needs to match with the model definition in llama_batched_vllm.py
+        if head_size <= 64:
+            block_size = 256
+        elif head_size <= 128:
+            block_size = 128
+        else:
+            block_size = 64
     else:
         allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
         copy_blocks_func_name = "tvm.contrib.vllm.copy_blocks"
@@ -475,13 +488,6 @@ def init_tvm_model(
     if model_artifact_config.num_shards > 1:
         model.disco_session.sync_worker_0()
 
-    num_kv_heads = (
-        model_artifact_config.num_key_value_heads // model_artifact_config.num_shards
-    )
-    head_size = (
-        model_artifact_config.hidden_size // model_artifact_config.num_attention_heads
-    )
-
     if engine_config.max_num_batched_tokens > 0:
         LOG.info("Running memory profiling.")
         num_blocks = get_num_cache_blocks(
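
For context, below is a minimal, standalone sketch of how the hoisted configuration arithmetic and the new head-size-dependent block-size selection for the flash-decoding KV cache fit together. The `ModelArtifactConfig` dataclass, the `select_block_size` helper, and the example field values are illustrative stand-ins, not the project's actual classes; in the real code these values come from `model_artifact_config`, and the chosen block size must stay in sync with the model definition in `llama_batched_vllm.py`.

```python
# Illustrative sketch only: names and example values are assumptions,
# not the actual mlc_serve config objects.
from dataclasses import dataclass


@dataclass
class ModelArtifactConfig:  # hypothetical stand-in for the real config
    num_key_value_heads: int
    num_attention_heads: int
    hidden_size: int
    num_shards: int
    paged_kv_cache_type: str


def select_block_size(head_size: int) -> int:
    """Pick the flash-decoding KV cache block size for a given head size.

    Mirrors the branch added in the diff above; the result must match the
    block size assumed by the model definition (llama_batched_vllm.py).
    """
    if head_size <= 64:
        return 256
    elif head_size <= 128:
        return 128
    return 64


# Example: an illustrative config sharded over 2 GPUs.
cfg = ModelArtifactConfig(
    num_key_value_heads=8,
    num_attention_heads=64,
    hidden_size=8192,
    num_shards=2,
    paged_kv_cache_type="flash-decoding",
)

num_kv_heads = cfg.num_key_value_heads // cfg.num_shards  # 8 // 2 = 4 KV heads per shard
head_size = cfg.hidden_size // cfg.num_attention_heads    # 8192 // 64 = 128
block_size = select_block_size(head_size)                 # 128 for 128-dim heads

print(num_kv_heads, head_size, block_size)  # 4 128 128
```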