diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 86b9c76322..d34eafd2d5 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -2,9 +2,6 @@ from collections import defaultdict from typing import List, Optional -# TODO: remove this -import tvm - from ..engine import ( RequestId, SequenceId, @@ -106,25 +103,11 @@ def replace_head_prompt_block_with(self, new_block): class KVCache: def __init__( - self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session + self, + cache_blocks, + block_size, ): - if disco_session: - init_cache_func = disco_session.get_global_func( - "tvm.contrib.vllm.allocate_kv_cache" - ) - self.copy_cache_blocks_func = disco_session.get_global_func( - "tvm.contrib.vllm.copy_blocks" - ) - else: - init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - self.copy_cache_blocks_func = tvm.get_global_func( - "tvm.contrib.vllm.copy_blocks" - ) - - self.cache = init_cache_func( - head_size, num_layers, num_heads, block_size, num_blocks - ) - + self.cache_blocks = cache_blocks self.block_size = block_size # SequenceId -> list[int] @@ -152,18 +135,13 @@ def get_cache_block_size(num_layers, num_heads, head_size): def __init__( self, - num_blocks, - num_layers, - num_heads, - head_size, - disco_session=None, - sliding_window=None, + cache_blocks, # This can be any type + num_blocks: int, + sliding_window: Optional[int] = None, ): self.num_blocks = num_blocks self.free_blocks = list(range(num_blocks)) - self.kv_cache = KVCache( - num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session - ) + self.kv_cache = KVCache(cache_blocks, self.block_size) self.allocated_prompt_tokens = dict[SequenceId, int]() self.allocated_decode_tokens = dict[SequenceId, int]() diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 5d218ba583..c0e4794289 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -275,6 +275,15 @@ def __init__( else: self.block_sliding_window = None + if self.disco_session: + self.copy_cache_blocks_func = self.disco_session.get_global_func( + "tvm.contrib.vllm.copy_blocks" + ) + else: + self.copy_cache_blocks_func = tvm.get_global_func( + "tvm.contrib.vllm.copy_blocks" + ) + def get_used_memory(self): if self.disco_session: params = self.params.debug_get_from_remote(0) @@ -376,8 +385,6 @@ def generate( seq_lens = copy_to_worker_0(self.disco_session, seq_lens) slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) - kv_cache = cache.cache - if is_prefill: torch.cuda.nvtx.range_push(f"forward prefill {input_shape}") @@ -391,14 +398,19 @@ def generate( input_ids, positions, seq_lens, - kv_cache, + cache.cache_blocks, slot_mapping, indices_within_window, self.params, ) else: out = self.mod["prefill"]( - input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params + input_ids, + positions, + seq_lens, + cache.cache_blocks, + slot_mapping, + self.params, ) if self.disco_session: @@ -417,7 +429,7 @@ def generate( input_ids, positions, seq_lens, - kv_cache, + cache.cache_blocks, slot_mapping, block_tables, self.params, @@ -445,7 +457,7 @@ def generate( "int64", ) - cache.copy_cache_blocks_func(kv_cache, block_mapping) + self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) cache.pending_copy_from_to = [] try: @@ -633,12 +645,24 @@ def __init__( LOG.info(f"Using {num_blocks} cache blocks.") - cache_manager = CacheManager( - num_blocks, + if model.disco_session: + init_cache_func = model.disco_session.get_global_func( + "tvm.contrib.vllm.allocate_kv_cache" + ) + else: + init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + + cache_blocks = init_cache_func( + head_size, model_artifact_config.num_hidden_layers, num_kv_heads, - head_size, - model.disco_session, + CacheManager.block_size, + num_blocks, + ) + + cache_manager = CacheManager( + cache_blocks, + num_blocks, model_artifact_config.sliding_window, )