diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 43e12290c0..488cfc1d1c 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -262,6 +262,9 @@ class EngineBase:
     max_num_batched_tokens: int
     max_decode_steps: int
     min_decode_steps: int
+    kv_cache_size: int
+    max_prompt_len: int
+    model_context_window_size: int
     queue_lock: Lock
     queue: Deque[RequestState]
     has_new_requests: Condition
@@ -285,6 +288,13 @@ def __init__(self, model_module: ModelModule):
         self.min_decode_steps = min(
             self.max_decode_steps - 1, model_module.engine_config.min_decode_steps
         )
+        self.kv_cache_size = self.cache_manager.get_kv_cache_size()
+        self.max_prompt_len = min(self.max_context_length, self.max_num_batched_tokens)
+
+        if self.model_artifact_config.sliding_window is not None:
+            self.model_context_window_size = self.model_artifact_config.sliding_window
+        else:
+            self.model_context_window_size = self.max_context_length

         self.queue_lock = Lock()
         self.queue = deque[RequestState]()
@@ -293,14 +303,17 @@ def __init__(self, model_module: ModelModule):
         self.current_batch = dict[RequestId, RequestState]()

     def check_prompt_too_long(self, prompt_len: int, num_sequences: int = 1) -> bool:
-        kv_cache_size = self.cache_manager.get_kv_cache_size()
-        max_prompt_len = min(self.max_context_length, self.max_num_batched_tokens)
-
         # We make sure that the KV cache will have enough free space for this request to proceed
         # decoding for at least self.max_decode_steps steps.
+        #
+        # For models using SWA, the number of consumed cache slots is upper bounded by the window
+        # size. This assumes that the model implementation does not store past KV tensors beyond
+        # the window into the cache.
+        num_kv_slots_needed = min(prompt_len, self.model_context_window_size)
         return (
-            prompt_len > max_prompt_len
-            or (kv_cache_size - prompt_len) < self.max_decode_steps * num_sequences
+            prompt_len > self.max_prompt_len
+            or (self.kv_cache_size - num_kv_slots_needed)
+            < self.max_decode_steps * num_sequences
         )

     def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
@@ -397,7 +410,9 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:

         # We make sure that the KV cache will have enough free space for this request to proceed
         # decoding for at least self.max_decode_steps steps.
-        if (self.cache_manager.get_free_space() - num_tokens) / (
+        # See the comment in check_prompt_too_long for the optimization involving the window size.
+        num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
+        if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
             len(self.current_batch) + 1
         ) < self.max_decode_steps * state.num_sequences:
             LOG.debug(
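
To make the slot-accounting change concrete, below is a minimal, self-contained sketch (not part of the patch) of the admission check with the sliding-window cap applied. The free-standing function and all numeric values are illustrative assumptions; in the engine the equivalent quantities live on `EngineBase` and come from the cache manager and model artifact config.

```python
def prompt_too_long(
    prompt_len: int,
    num_sequences: int,
    kv_cache_size: int,
    max_prompt_len: int,
    max_decode_steps: int,
    model_context_window_size: int,
) -> bool:
    # With sliding-window attention, a sequence can never occupy more KV cache
    # slots than the window size, so the prompt's slot demand is capped here.
    num_kv_slots_needed = min(prompt_len, model_context_window_size)
    return (
        prompt_len > max_prompt_len
        or (kv_cache_size - num_kv_slots_needed) < max_decode_steps * num_sequences
    )


# Illustrative numbers only: a 32k-token prompt on a model with a 4096-token
# sliding window reserves at most 4096 slots, so it fits in an 8192-slot cache.
# The pre-patch accounting (kv_cache_size - prompt_len) would have rejected it.
print(prompt_too_long(
    prompt_len=32_000,
    num_sequences=1,
    kv_cache_size=8_192,
    max_prompt_len=32_768,
    max_decode_steps=48,
    model_context_window_size=4_096,
))  # -> False (admitted)
```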