From 70e5918f1af225df979dd20d9deb1e6789efd27f Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 14 Mar 2024 21:08:42 +0000
Subject: [PATCH 1/5] Protect against invalid request format

---
 serve/mlc_serve/engine/staging_engine.py | 39 ++++++++++++++++++------
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index a457ff8385..f6ee03a751 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -57,6 +57,9 @@ def __init__(
         self.next_generation_output = None
         self.requests_lock = Lock()
         self.requests = dict[RequestId, RequestState]()
+        self.requests_to_be_cancelled_lock = Lock()
+        # Error message for each request that fails to be added to the engine
+        self.requests_to_be_cancelled = dict[RequestId, str]()
 
         # TODO(@team): This is a temporary solution to expose model config to higher API layer.
         # Follow-up with the proper solution
@@ -119,13 +122,17 @@ def add(self, requests: list[Request]):
                 assert isinstance(req.stopping_criteria.stop_sequences, list)
 
             # If the request violates the tokenization, this returns None, so skip.
-            state = get_new_request_state(
-                req,
-                self.conversation_template,
-                self.tokenizer,
-                self.model_artifact_config.vocab_size,
-            )
-            new_request_states.append(state)
+            try:
+                state = get_new_request_state(
+                    req,
+                    self.conversation_template,
+                    self.tokenizer,
+                    self.model_artifact_config.vocab_size,
+                )
+                new_request_states.append(state)
+            except Exception as e:
+                with self.requests_to_be_cancelled_lock:
+                    self.requests_to_be_cancelled[req.request_id] = str(e)
 
         self.command_queue.put(AddRequestsCommand(request_states=new_request_states))
 
@@ -171,11 +178,25 @@ def step(self) -> InferenceStepResult:
             has_pending_requests=self.has_pending_requests(),
         )
 
+        outputs = list[RequestOutput]()
+
+        with self.requests_to_be_cancelled_lock:
+            if len(self.requests_to_be_cancelled) > 0:
+                for req_id, err_msg in self.requests_to_be_cancelled.items():
+                    outputs.append(
+                        RequestOutput(
+                            req_id,
+                            sequences=[],
+                            error=err_msg,
+                        )
+                    )
+                self.requests_to_be_cancelled.clear()
+
         if not self._is_ready_to_serve():
             raise RuntimeError("GenerationLoopWorker process is not running")
 
         if not self.has_pending_requests():
-            return InferenceStepResult([])
+            return InferenceStepResult(outputs)
 
         if self.next_generation_output is None:
             generation_output = self.result_queue.get()
@@ -188,8 +209,6 @@
                 f"Error from GenerationLoopWorker process: {generation_output.error}"
             ) from generation_output.error
 
-        outputs = list[RequestOutput]()
-
         with self.requests_lock:
             LOG.debug(
                 "StagingInferenceEngine.step obtained requests_lock",

From 23a76a72f5c7b64ea612ced42d66b823a953e56b Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 14 Mar 2024 21:41:37 +0000
Subject: [PATCH 2/5] add warning when a request fails to be added

---
 serve/mlc_serve/engine/staging_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index f6ee03a751..df4bd91566 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -131,6 +131,7 @@ def add(self, requests: list[Request]):
                 )
                 new_request_states.append(state)
             except Exception as e:
+                LOG.warn("Failed to add a request", request_id=req.request_id)
                 with self.requests_to_be_cancelled_lock:
                     self.requests_to_be_cancelled[req.request_id] = str(e)
 

From 3ddfbd84a28c6b1d7f30ccf122a53526e7097285 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 15 Mar 2024 11:09:03 +0000
Subject: [PATCH 3/5] revert

---
 serve/mlc_serve/engine/staging_engine.py | 40 ++++++------------------
 1 file changed, 10 insertions(+), 30 deletions(-)

diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index df4bd91566..a457ff8385 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -57,9 +57,6 @@ def __init__(
         self.next_generation_output = None
         self.requests_lock = Lock()
         self.requests = dict[RequestId, RequestState]()
-        self.requests_to_be_cancelled_lock = Lock()
-        # Error message for each request that fails to be added to the engine
-        self.requests_to_be_cancelled = dict[RequestId, str]()
 
         # TODO(@team): This is a temporary solution to expose model config to higher API layer.
         # Follow-up with the proper solution
@@ -119,18 +119,13 @@ def add(self, requests: list[Request]):
                 assert isinstance(req.stopping_criteria.stop_sequences, list)
 
             # If the request violates the tokenization, this returns None, so skip.
-            try:
-                state = get_new_request_state(
-                    req,
-                    self.conversation_template,
-                    self.tokenizer,
-                    self.model_artifact_config.vocab_size,
-                )
-                new_request_states.append(state)
-            except Exception as e:
-                LOG.warn("Failed to add a request", request_id=req.request_id)
-                with self.requests_to_be_cancelled_lock:
-                    self.requests_to_be_cancelled[req.request_id] = str(e)
+            state = get_new_request_state(
+                req,
+                self.conversation_template,
+                self.tokenizer,
+                self.model_artifact_config.vocab_size,
+            )
+            new_request_states.append(state)
 
         self.command_queue.put(AddRequestsCommand(request_states=new_request_states))
 
@@ -179,25 +171,11 @@ def step(self) -> InferenceStepResult:
             has_pending_requests=self.has_pending_requests(),
        )
 
-        outputs = list[RequestOutput]()
-
-        with self.requests_to_be_cancelled_lock:
-            if len(self.requests_to_be_cancelled) > 0:
-                for req_id, err_msg in self.requests_to_be_cancelled.items():
-                    outputs.append(
-                        RequestOutput(
-                            req_id,
-                            sequences=[],
-                            error=err_msg,
-                        )
-                    )
-                self.requests_to_be_cancelled.clear()
-
         if not self._is_ready_to_serve():
             raise RuntimeError("GenerationLoopWorker process is not running")
 
         if not self.has_pending_requests():
-            return InferenceStepResult(outputs)
+            return InferenceStepResult([])
 
         if self.next_generation_output is None:
             generation_output = self.result_queue.get()
@@ -210,6 +188,8 @@
                 f"Error from GenerationLoopWorker process: {generation_output.error}"
             ) from generation_output.error
 
+        outputs = list[RequestOutput]()
+
         with self.requests_lock:
             LOG.debug(
                 "StagingInferenceEngine.step obtained requests_lock",

From ad482bf346a6c44bb8b0f85d0888c141c831ebd8 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 15 Mar 2024 11:21:28 +0000
Subject: [PATCH 4/5] alternative suggested by lite

---
 serve/mlc_serve/engine/async_connector.py |  5 ++++-
 serve/mlc_serve/engine/staging_engine.py  | 20 ++++++++++++--------
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py
index 907e840314..b255e3378d 100644
--- a/serve/mlc_serve/engine/async_connector.py
+++ b/serve/mlc_serve/engine/async_connector.py
@@ -150,7 +150,10 @@ async def _add_request(self, request: Request) -> ResultQueue:
         queue = asyncio.Queue()
         self.result_queues[request.request_id] = queue
 
-        await asyncio.to_thread(self.engine.add, [request])
+        try:
+            await asyncio.to_thread(self.engine.add, [request])
+        except TextGenerationError as e:
+            raise asyncio.CancelledError(e)
 
         return queue
 
diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index a457ff8385..91ed95c26a 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -21,9 +21,9 @@
     ScopedInferenceEngine,
     SequenceOutput,
 )
+from .error import TextGenerationError
 from .engine_common import get_new_request_state, prepare_output
 from .model_module import ModelModule, TokenizerModule
-from ..model.base import get_model_artifact_config
 from .staging_engine_worker import (
     AddRequestsCommand,
     CancelRequestCommand,
@@ -119,13 +119,17 @@ def add(self, requests: list[Request]):
                 assert isinstance(req.stopping_criteria.stop_sequences, list)
 
             # If the request violates the tokenization, this returns None, so skip.
-            state = get_new_request_state(
-                req,
-                self.conversation_template,
-                self.tokenizer,
-                self.model_artifact_config.vocab_size,
-            )
-            new_request_states.append(state)
+            try:
+                state = get_new_request_state(
+                    req,
+                    self.conversation_template,
+                    self.tokenizer,
+                    self.model_artifact_config.vocab_size,
+                )
+                new_request_states.append(state)
+            except Exception as e:
+                LOG.warn("Failed to add a request", request_id=req.request_id)
+                raise TextGenerationError(str(e))
 
         self.command_queue.put(AddRequestsCommand(request_states=new_request_states))
 

From 39065c2ae572296f9c188456b1213a675ecace2d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 15 Mar 2024 19:49:28 +0000
Subject: [PATCH 5/5] revert async_connector change

---
 serve/mlc_serve/engine/async_connector.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py
index b255e3378d..907e840314 100644
--- a/serve/mlc_serve/engine/async_connector.py
+++ b/serve/mlc_serve/engine/async_connector.py
@@ -150,10 +150,7 @@ async def _add_request(self, request: Request) -> ResultQueue:
         queue = asyncio.Queue()
         self.result_queues[request.request_id] = queue
 
-        try:
-            await asyncio.to_thread(self.engine.add, [request])
-        except TextGenerationError as e:
-            raise asyncio.CancelledError(e)
+        await asyncio.to_thread(self.engine.add, [request])
 
         return queue
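
End state of the series: StagingInferenceEngine.add raises TextGenerationError as soon as a request fails validation or tokenization (patch 4), and the translation to asyncio.CancelledError in async_connector is reverted (patch 5), so the exception simply propagates out of the asyncio.to_thread(self.engine.add, [request]) call. Below is a minimal sketch of how a caller could surface that failure to the client; the submit_request coroutine, the engine wiring, and the error-response shape are illustrative assumptions, not part of these patches.

import asyncio

# Import path assumed from the `from .error import TextGenerationError` line in patch 4.
from mlc_serve.engine.error import TextGenerationError


async def submit_request(engine, request) -> dict:
    """Hypothetical caller: add a request and report validation failures to the client."""
    try:
        # Per patch 4, engine.add raises TextGenerationError for requests that fail
        # tokenization/validation; run the blocking call off the event loop.
        await asyncio.to_thread(engine.add, [request])
    except TextGenerationError as e:
        # Invalid input is a client error, not an engine failure.
        return {"error": {"type": "invalid_request_error", "message": str(e)}}

    # Otherwise the request was queued; results arrive via engine.step() as usual.
    return {"status": "queued", "request_id": request.request_id}

Compared with patches 1-2 (queueing an error RequestOutput to be returned from step), raising from add reports the problem on the submitting thread immediately and keeps the worker from ever seeing a request that was never valid.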