Commit fa424e2
Use a more accurate number of KV cache slots consumed for SWA (#114)
For models using sliding window attention (SWA), a request's KV cache footprint is bounded by the window size, so the free-space checks in check_prompt_too_long and try_grow_batch now subtract min(num_tokens, window_size) instead of the full token count.
masahi authored Dec 15, 2023
1 parent b497a2c commit fa424e2
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions serve/mlc_serve/engine/engine_common.py
@@ -262,6 +262,9 @@ class EngineBase:
     max_num_batched_tokens: int
     max_decode_steps: int
     min_decode_steps: int
+    kv_cache_size: int
+    max_prompt_len: int
+    model_context_window_size: int
     queue_lock: Lock
     queue: Deque[RequestState]
     has_new_requests: Condition
@@ -285,6 +288,13 @@ def __init__(self, model_module: ModelModule):
         self.min_decode_steps = min(
             self.max_decode_steps - 1, model_module.engine_config.min_decode_steps
         )
+        self.kv_cache_size = self.cache_manager.get_kv_cache_size()
+        self.max_prompt_len = min(self.max_context_length, self.max_num_batched_tokens)
+
+        if self.model_artifact_config.sliding_window is not None:
+            self.model_context_window_size = self.model_artifact_config.sliding_window
+        else:
+            self.model_context_window_size = self.max_context_length
 
         self.queue_lock = Lock()
         self.queue = deque[RequestState]()
@@ -293,14 +303,17 @@ def __init__(self, model_module: ModelModule):
         self.current_batch = dict[RequestId, RequestState]()
 
     def check_prompt_too_long(self, prompt_len: int, num_sequences: int = 1) -> bool:
-        kv_cache_size = self.cache_manager.get_kv_cache_size()
-        max_prompt_len = min(self.max_context_length, self.max_num_batched_tokens)
-
         # We make sure that the KV cache will have enough free space for this request to proceed
         # decoding for at least self.max_decode_steps steps.
+        #
+        # For models using SWA, the number of consumed cache slots is upper bounded by the window
+        # size. This assumes that the model implementation does not store past KV tensors beyond
+        # the window into the cache.
+        num_kv_slots_needed = min(prompt_len, self.model_context_window_size)
         return (
-            prompt_len > max_prompt_len
-            or (kv_cache_size - prompt_len) < self.max_decode_steps * num_sequences
+            prompt_len > self.max_prompt_len
+            or (self.kv_cache_size - num_kv_slots_needed)
+            < self.max_decode_steps * num_sequences
         )
 
     def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
@@ -397,7 +410,9 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
 
         # We make sure that the KV cache will have enough free space for this request to proceed
         # decoding for at least self.max_decode_steps steps.
-        if (self.cache_manager.get_free_space() - num_tokens) / (
+        # See the comment in check_prompt_too_long for the optimization involving the window size.
+        num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
+        if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
             len(self.current_batch) + 1
         ) < self.max_decode_steps * state.num_sequences:
             LOG.debug(
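For reference, below is a small standalone sketch of the slot accounting this commit introduces. It is not the mlc-serve code itself; the function signature and the numbers are illustrative. The point it demonstrates: for a sliding-window-attention model, a request's past KV tensors never occupy more than window_size cache slots, so the free-space check subtracts min(prompt_len, window_size) rather than the full prompt length.

def check_prompt_too_long_swa(
    prompt_len: int,
    kv_cache_size: int,
    max_prompt_len: int,
    max_decode_steps: int,
    context_window_size: int,
    num_sequences: int = 1,
) -> bool:
    # With SWA, at most `context_window_size` slots are ever occupied by this
    # request's past KV tensors, no matter how long the prompt is.
    num_kv_slots_needed = min(prompt_len, context_window_size)
    return (
        prompt_len > max_prompt_len
        or (kv_cache_size - num_kv_slots_needed) < max_decode_steps * num_sequences
    )


if __name__ == "__main__":
    # Hypothetical numbers: an 8160-token prompt, an 8192-slot cache, a
    # 4096-token sliding window, and 48 reserved decode steps per sequence.
    # Old accounting: 8192 - 8160 = 32 < 48, so the request would be rejected.
    # SWA-aware accounting: 8192 - min(8160, 4096) = 4096 >= 48, so it fits.
    print(check_prompt_too_long_swa(
        prompt_len=8160,
        kv_cache_size=8192,
        max_prompt_len=16384,
        max_decode_steps=48,
        context_window_size=4096,
    ))  # prints False: the prompt is not too long

Under the old accounting, a prompt approaching the cache size was rejected even though an SWA model caps its cache footprint at the window size; with the tighter bound, such prompts are admitted as long as the window plus the reserved decode slots fit in the cache.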
