From 0d861591f7ac1eb03fa4c7557cebb793735eba78 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 6 Dec 2023 10:28:02 +0000
Subject: [PATCH] Cancel hung requests

---
 .../mlc_serve/engine/staging_engine_worker.py | 34 +++++++++++++++----
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index f5e4968dde..5df50994fd 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -235,13 +235,12 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         self.cancelled_requests.clear()
 
-        self._adjust_batch()
+        request_ids_to_cancel = self._adjust_batch()
+
+        for request_id in request_ids_to_cancel:
+            self.cancel_request(request_id)
 
         if not self.current_batch:
-            if len(self.queue) > 0:
-                LOG.warn(
-                    f"The engine has {len(self.queue)} requests to be processed in the queue, but none of them were added to the current batch during the execution of StagingEngine._adjust_batch"
-                )
             return result
 
         requests, is_prompt_batch = self._get_requests_to_process()
@@ -294,7 +293,8 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         return result
 
-    def _adjust_batch(self):
+    def _adjust_batch(self) -> List[RequestId]:
+        """Form a new batch and return a list of request IDs that should be cancelled, if any."""
         with self.queue_lock:
             while self.cache_manager.get_max_new_tokens() < 1:
                 self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc()
@@ -313,7 +313,7 @@ def _adjust_batch(self):
                     "Skip growing the batch due to max_decode_steps. Decode steps: %s",
                     self.cache_manager.get_max_new_tokens(),
                 )
-                return
+                return []
 
             num_new_batched_tokens = len(self.current_batch)
             while self.queue:
@@ -363,6 +363,26 @@ def _adjust_batch(self):
                 self.cache_manager.allocate(state.request_id, num_tokens)
                 self.current_batch[state.request_id] = state
 
+            if len(self.current_batch) == 0 and len(self.queue) > 0:
+                LOG.warn(
+                    f"The engine has {len(self.queue)} requests to be processed in the queue, but"
+                    " none of them were added to the current batch during the execution of"
+                    " StagingEngine._adjust_batch"
+                )
+
+                hung_request_ids = []
+
+                for state in self.queue:
+                    hung_request_ids.append(state.request_id)
+                    # TODO(masahi): Proper error enum?
+                    state.validation_err = ValidationError(
+                        "Internal Server Error: Canceled due to a hang in the server."
+                    )
+
+                return hung_request_ids
+
+            return []
+
     def _remove_request_from_batch(self, request_id: RequestId):
         del self.current_batch[request_id]
         self.cache_manager.free(SequenceId(request_id, 0))
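
For illustration, the pattern this patch introduces, _adjust_batch reporting the
requests it could not schedule so that step can cancel them through the normal
cancellation path, can be sketched in isolation. Everything below is a
hypothetical stand-in (Request, Worker, a fixed token capacity), not the real
GenerationLoopWorker with its GenerationState, KV-cache manager, queue_lock,
and Prometheus metrics:

from dataclasses import dataclass, field
from typing import Dict, List, Optional

RequestId = str


@dataclass
class Request:
    request_id: RequestId
    num_tokens: int
    validation_err: Optional[str] = None


@dataclass
class Worker:
    capacity: int  # stand-in for the KV-cache budget
    queue: List[Request] = field(default_factory=list)
    current_batch: Dict[RequestId, Request] = field(default_factory=dict)
    cancelled: List[RequestId] = field(default_factory=list)

    def _adjust_batch(self) -> List[RequestId]:
        """Grow the batch; return IDs of requests that can never be scheduled."""
        used = sum(r.num_tokens for r in self.current_batch.values())
        while self.queue and used + self.queue[0].num_tokens <= self.capacity:
            request = self.queue.pop(0)
            self.current_batch[request.request_id] = request
            used += request.num_tokens

        # Mirrors the patch: nothing is running and nothing fits, so the
        # queued requests would otherwise sit in the queue forever. Mark
        # them with an error and hand their IDs back to the caller.
        if not self.current_batch and self.queue:
            for request in self.queue:
                request.validation_err = (
                    "Internal Server Error: Canceled due to a hang in the server."
                )
            return [r.request_id for r in self.queue]
        return []

    def cancel_request(self, request_id: RequestId) -> None:
        self.queue = [r for r in self.queue if r.request_id != request_id]
        self.cancelled.append(request_id)

    def step(self) -> None:
        # Same control flow as the patched GenerationLoopWorker.step.
        for request_id in self._adjust_batch():
            self.cancel_request(request_id)


worker = Worker(capacity=8)
worker.queue.append(Request("too-big", num_tokens=16))  # can never fit
worker.step()
assert worker.cancelled == ["too-big"]
assert not worker.current_batch and not worker.queue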
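
Returning the hung IDs from _adjust_batch and cancelling them back in step,
rather than calling cancel_request inside _adjust_batch, presumably keeps the
cancellation work out of the queue_lock critical section and reuses the
worker's existing cancellation path, which can then surface the attached
ValidationError to the client instead of only logging a warning.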