Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 14, 2023
1 parent 49efaf6 commit e04199d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 40 deletions.
38 changes: 8 additions & 30 deletions serve/mlc_serve/model/paged_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
from collections import defaultdict
from typing import List, Optional

# TODO: remove this
import tvm

from ..engine import (
RequestId,
SequenceId,
Expand Down Expand Up @@ -106,25 +103,11 @@ def replace_head_prompt_block_with(self, new_block):

class KVCache:
def __init__(
self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session
self,
cache_blocks,
block_size,
):
if disco_session:
init_cache_func = disco_session.get_global_func(
"tvm.contrib.vllm.allocate_kv_cache"
)
self.copy_cache_blocks_func = disco_session.get_global_func(
"tvm.contrib.vllm.copy_blocks"
)
else:
init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
self.copy_cache_blocks_func = tvm.get_global_func(
"tvm.contrib.vllm.copy_blocks"
)

self.cache = init_cache_func(
head_size, num_layers, num_heads, block_size, num_blocks
)

self.cache_blocks = cache_blocks
self.block_size = block_size

# SequenceId -> list[int]
Expand Down Expand Up @@ -152,18 +135,13 @@ def get_cache_block_size(num_layers, num_heads, head_size):

def __init__(
self,
num_blocks,
num_layers,
num_heads,
head_size,
disco_session=None,
sliding_window=None,
cache_blocks, # This can be any type
num_blocks: int,
sliding_window: Optional[int] = None,
):
self.num_blocks = num_blocks
self.free_blocks = list(range(num_blocks))
self.kv_cache = KVCache(
num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session
)
self.kv_cache = KVCache(cache_blocks, self.block_size)
self.allocated_prompt_tokens = dict[SequenceId, int]()
self.allocated_decode_tokens = dict[SequenceId, int]()

Expand Down
44 changes: 34 additions & 10 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,15 @@ def __init__(
else:
self.block_sliding_window = None

if self.disco_session:
self.copy_cache_blocks_func = self.disco_session.get_global_func(
"tvm.contrib.vllm.copy_blocks"
)
else:
self.copy_cache_blocks_func = tvm.get_global_func(
"tvm.contrib.vllm.copy_blocks"
)

def get_used_memory(self):
if self.disco_session:
params = self.params.debug_get_from_remote(0)
Expand Down Expand Up @@ -376,8 +385,6 @@ def generate(
seq_lens = copy_to_worker_0(self.disco_session, seq_lens)
slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping)

kv_cache = cache.cache

if is_prefill:
torch.cuda.nvtx.range_push(f"forward prefill {input_shape}")

Expand All @@ -391,14 +398,19 @@ def generate(
input_ids,
positions,
seq_lens,
kv_cache,
cache.cache_blocks,
slot_mapping,
indices_within_window,
self.params,
)
else:
out = self.mod["prefill"](
input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params
input_ids,
positions,
seq_lens,
cache.cache_blocks,
slot_mapping,
self.params,
)

if self.disco_session:
Expand All @@ -417,7 +429,7 @@ def generate(
input_ids,
positions,
seq_lens,
kv_cache,
cache.cache_blocks,
slot_mapping,
block_tables,
self.params,
Expand Down Expand Up @@ -445,7 +457,7 @@ def generate(
"int64",
)

cache.copy_cache_blocks_func(kv_cache, block_mapping)
self.copy_cache_blocks_func(cache.cache_blocks, block_mapping)
cache.pending_copy_from_to = []

try:
Expand Down Expand Up @@ -633,12 +645,24 @@ def __init__(

LOG.info(f"Using {num_blocks} cache blocks.")

cache_manager = CacheManager(
num_blocks,
if model.disco_session:
init_cache_func = model.disco_session.get_global_func(
"tvm.contrib.vllm.allocate_kv_cache"
)
else:
init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")

cache_blocks = init_cache_func(
head_size,
model_artifact_config.num_hidden_layers,
num_kv_heads,
head_size,
model.disco_session,
CacheManager.block_size,
num_blocks,
)

cache_manager = CacheManager(
cache_blocks,
num_blocks,
model_artifact_config.sliding_window,
)

Expand Down

0 comments on commit e04199d

Please sign in to comment.