Fix hanging on prompt counts > max model context len (#74)
* fix cancelled request not awaited

* fix working

* compare against max_context_len

---------

Co-authored-by: Masahiro Masuda <[email protected]>
masahi and Masahiro Masuda committed Nov 21, 2023
1 parent 5f253c7 commit 21a9211
Showing 3 changed files with 8 additions and 4 deletions.
serve/mlc_serve/engine/async_connector.py (2 changes: 1 addition & 1 deletion)

@@ -85,7 +85,7 @@ async def generate(self, request: Request) -> AsyncIterator[RequestOutput]:
                 if output.is_finished:
                     return
         except asyncio.CancelledError:
-            asyncio.to_thread(self.engine.cancel, request.request_id)
+            await asyncio.to_thread(self.engine.cancel, request.request_id)
         finally:
             self.result_queues.pop(request.request_id, None)
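The async_connector change matters because asyncio.to_thread() only builds a coroutine; the wrapped function runs when that coroutine is awaited. A minimal sketch of the difference, assuming a hypothetical blocking_cancel stand-in for engine.cancel (Python 3.9+):

import asyncio

def blocking_cancel(request_id: str) -> None:
    # Hypothetical stand-in for the blocking engine.cancel() call.
    print(f"cancel issued for {request_id}")

async def main() -> None:
    # Without `await`, to_thread() just returns a coroutine object: blocking_cancel
    # never runs and Python emits a "coroutine ... was never awaited" warning.
    asyncio.to_thread(blocking_cancel, "req-1")

    # With `await`, the call runs on a worker thread and finishes before the
    # surrounding handler (here, the CancelledError handler) moves on.
    await asyncio.to_thread(blocking_cancel, "req-2")

asyncio.run(main())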
serve/mlc_serve/engine/staging_engine_worker.py (5 changes: 3 additions & 2 deletions)

@@ -58,7 +58,7 @@ def __init__(
         self.cache_manager = model_module.cache_manager
         self.tokenizer = model_module.tokenizer
         self.model_artifact_config = model_module.model_artifact_config
-
+        self.max_context_length = self.model_artifact_config.max_context_length
         self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
         self.max_decode_steps = min(
             self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps
@@ -81,10 +81,11 @@ def add(self, request_states: list[RequestState]):
             # cancel them instead.
             valid_states = []
             for request_state in request_states:
-                if request_state.validation_err is not None:
+                if request_state.validation_err is not None or request_state.prompt_len >= self.max_context_length:
                     self.cancelled_requests.append(request_state)
                 else:
                     valid_states.append(request_state)
+
             self.queue.extend(valid_states)
             self.has_new_requests.notify_all()
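The worker-side fix routes any request whose prompt already reaches the model's context limit to the cancellation path instead of the scheduling queue, so no client waits forever on a request that can never be scheduled. A self-contained sketch of that partitioning, with a simplified RequestState (field names beyond prompt_len and validation_err are assumptions, not the repository's exact types):

from dataclasses import dataclass
from typing import Optional

@dataclass
class RequestState:
    request_id: str
    prompt_len: int
    validation_err: Optional[str] = None

def partition(request_states: list[RequestState], max_context_length: int):
    # Invalid or over-long requests go to the cancellation path; the rest are
    # eligible for scheduling.
    valid, cancelled = [], []
    for state in request_states:
        if state.validation_err is not None or state.prompt_len >= max_context_length:
            cancelled.append(state)
        else:
            valid.append(state)
    return valid, cancelled

valid, cancelled = partition(
    [RequestState("ok", prompt_len=128), RequestState("too-long", prompt_len=4096)],
    max_context_length=4096,
)
assert [s.request_id for s in valid] == ["ok"]
assert [s.request_id for s in cancelled] == ["too-long"]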
serve/mlc_serve/engine/sync_engine.py (5 changes: 4 additions & 1 deletion)

@@ -41,7 +41,7 @@ def __init__(
         self.conversation_template = model_module.conversation_template
         self.cache_manager = model_module.cache_manager
         self.model_artifact_config = model_module.model_artifact_config
-
+        self.max_context_length = self.model_artifact_config.max_context_length
         self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
         self.max_decode_steps = min(
             self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps
@@ -69,6 +69,9 @@ def add(self, requests: list[Request]):
             state = self._get_new_request_state(req)
             new_request_states.append(state)
 
+            if state.prompt_len >= self.max_context_length:
+                self.cancel(req.request_id)
+
         with self.queue_lock:
             self.queue.extend(new_request_states)
             self.has_new_requests.notify_all()
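The same guard in the synchronous engine uses >= rather than >: a prompt that already fills the context window leaves no room for even one generated token, so the request could never make progress. A toy illustration with assumed numbers:

# Assumed numbers for illustration only.
max_context_length = 4096   # hypothetical model context limit
prompt_len = 4096           # prompt already fills the whole window

decode_budget = max_context_length - prompt_len
assert decode_budget == 0   # no token can be generated, so the request is cancelled up front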
