From 21a9211491af7deeb8dff343a0141855573ac0d8 Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 21 Nov 2023 11:05:14 +0900
Subject: [PATCH] Fix hanging on prompt counts > max model context len (#74)

* fix cancelled request not awaited

* fix working

* compare against max_context_len

---------

Co-authored-by: Masahiro Masuda
---
 serve/mlc_serve/engine/async_connector.py       | 2 +-
 serve/mlc_serve/engine/staging_engine_worker.py | 5 +++--
 serve/mlc_serve/engine/sync_engine.py           | 5 ++++-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py
index 888cd81447..afc8068b37 100644
--- a/serve/mlc_serve/engine/async_connector.py
+++ b/serve/mlc_serve/engine/async_connector.py
@@ -85,7 +85,7 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]:
                 if output.is_finished:
                     return
         except asyncio.CancelledError:
-            asyncio.to_thread(self.engine.cancel, request.request_id)
+            await asyncio.to_thread(self.engine.cancel, request.request_id)
         finally:
             self.result_queues.pop(request.request_id, None)
 
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 82b0b6a8a8..6e12175184 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -58,7 +58,7 @@ def __init__(
         self.cache_manager = model_module.cache_manager
         self.tokenizer = model_module.tokenizer
         self.model_artifact_config = model_module.model_artifact_config
-
+        self.max_context_length = self.model_artifact_config.max_context_length
         self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
         self.max_decode_steps = min(
             self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps
@@ -81,10 +81,11 @@ def add(self, request_states: list[RequestState]):
             # cancel them instead.
             valid_states = []
             for request_state in request_states:
-                if request_state.validation_err is not None:
+                if request_state.validation_err is not None or request_state.prompt_len >= self.max_context_length:
                     self.cancelled_requests.append(request_state)
                 else:
                     valid_states.append(request_state)
+
             self.queue.extend(valid_states)
             self.has_new_requests.notify_all()
 
diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index 007b0c4678..1d52b300f0 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -41,7 +41,7 @@ def __init__(
         self.conversation_template = model_module.conversation_template
         self.cache_manager = model_module.cache_manager
         self.model_artifact_config = model_module.model_artifact_config
-
+        self.max_context_length = self.model_artifact_config.max_context_length
         self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
         self.max_decode_steps = min(
             self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps
@@ -69,6 +69,9 @@ def add(self, requests: list[Request]):
             state = self._get_new_request_state(req)
             new_request_states.append(state)
 
+            if state.prompt_len >= self.max_context_length:
+                self.cancel(req.request_id)
+
         with self.queue_lock:
             self.queue.extend(new_request_states)
             self.has_new_requests.notify_all()
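
Note on the async_connector.py hunk: without `await`, `asyncio.to_thread(...)` only creates a coroutine object that is never scheduled, so `engine.cancel` never runs and the engine keeps waiting on the cancelled request. Below is a minimal standalone sketch of the corrected pattern; it is not part of the patch, the names `blocking_cancel` and `stream` are invented for illustration, and it assumes Python 3.9+ for `asyncio.to_thread`.

```python
import asyncio
import time


def blocking_cancel(request_id: str) -> None:
    """Stand-in for a synchronous engine.cancel(request_id) call (hypothetical)."""
    time.sleep(0.1)  # pretend this takes an engine lock / signals the worker
    print(f"cancelled {request_id}")


async def stream(request_id: str) -> None:
    try:
        while True:
            await asyncio.sleep(0.1)  # pretend to wait for the next output
    except asyncio.CancelledError:
        # The bug fixed above: calling asyncio.to_thread(...) without `await`
        # merely builds a coroutine object, so the cancel never executes.
        # Awaiting it runs the blocking cancel in a worker thread without
        # blocking the event loop.
        await asyncio.to_thread(blocking_cancel, request_id)
        raise


async def main() -> None:
    task = asyncio.create_task(stream("req-0"))
    await asyncio.sleep(0.25)
    task.cancel()
    try:
        await task  # re-raises CancelledError after the cleanup above
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```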
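
Note on the staging_engine_worker.py and sync_engine.py hunks: a request whose prompt already fills the model's context window can never produce a token, so both engines now cancel it up front instead of queueing it, where it would hang. The sketch below mirrors that guard under assumed names; `RequestState` and `partition_requests` are hypothetical stand-ins, not the engine's real classes.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class RequestState:
    """Hypothetical stand-in for the engine's per-request state."""
    request_id: str
    prompt_len: int
    validation_err: Optional[str] = None


def partition_requests(
    request_states: List[RequestState], max_context_length: int
) -> Tuple[List[RequestState], List[RequestState]]:
    # Requests that failed validation, or whose prompt already reaches the
    # model's max context length, go to the cancelled list instead of the
    # work queue.
    valid, cancelled = [], []
    for state in request_states:
        if state.validation_err is not None or state.prompt_len >= max_context_length:
            cancelled.append(state)
        else:
            valid.append(state)
    return valid, cancelled


# With a 4096-token context window, a 5000-token prompt is cancelled up front.
valid, cancelled = partition_requests(
    [RequestState("ok", 1024), RequestState("too-long", 5000)],
    max_context_length=4096,
)
assert [s.request_id for s in valid] == ["ok"]
assert [s.request_id for s in cancelled] == ["too-long"]
```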