diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index bee83665a6..93bfcb378d 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -219,7 +219,10 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         self.cancelled_requests.clear()
 
-        self._adjust_batch()
+        request_ids_to_cancel = self._adjust_batch()
+
+        for request_id in request_ids_to_cancel:
+            self.cancel_request(request_id)
 
         if not self.current_batch:
             return result
@@ -266,7 +269,8 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         return result
 
-    def _adjust_batch(self):
+    def _adjust_batch(self) -> List[RequestId]:
+        """Form a new batch and return a list of request IDs that should be cancelled, if any."""
         with self.queue_lock:
             while self.cache_manager.get_max_new_tokens() < 1:
                 request_to_remove = min(
@@ -284,7 +288,7 @@ def _adjust_batch(self):
                     "Skip growing the batch due to max_decode_steps. Decode steps: %s",
                     self.cache_manager.get_max_new_tokens(),
                 )
-                return
+                return []
 
             num_new_batched_tokens = len(self.current_batch)
             while self.queue:
@@ -335,8 +339,9 @@ def _adjust_batch(self):
                 # TODO(masahi): Proper error enum?
                 state.validation_err = ValidationError("Canceled due to a hang")
 
-            for request_id in hung_request_ids:
-                self.cancel_request(request_id)
+            return hung_request_ids
+
+        return []
 
     def _remove_request_from_batch(self, request_id: RequestId):
         del self.current_batch[request_id]
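
A minimal, self-contained sketch of the collect-then-cancel pattern this patch adopts. It assumes, as the hunks suggest, that cancel_request itself acquires queue_lock, so calling it from inside _adjust_batch (which already holds that lock) would self-deadlock on a non-reentrant Lock; the Worker class and its fields below are hypothetical stand-ins for the engine's real types, not its actual API:

import threading
from typing import List

RequestId = str  # stand-in for the engine's RequestId type


class Worker:
    """Toy worker reproducing the locking structure relevant to the patch."""

    def __init__(self) -> None:
        self.queue_lock = threading.Lock()
        self.current_batch = {"r0": object(), "r1": object()}
        self.hung_request_ids: List[RequestId] = ["r1"]

    def cancel_request(self, request_id: RequestId) -> None:
        # Needs the queue lock. Invoking this while _adjust_batch still
        # holds queue_lock would block forever on a non-reentrant Lock.
        with self.queue_lock:
            self.current_batch.pop(request_id, None)

    def _adjust_batch(self) -> List[RequestId]:
        # Only collect the IDs under the lock; defer the actual cancellation
        # to the caller, mirroring the patched _adjust_batch.
        with self.queue_lock:
            if self.hung_request_ids:
                return list(self.hung_request_ids)
        return []

    def step(self) -> None:
        # By the time we cancel, the lock has been released, so
        # cancel_request can safely re-acquire it.
        for request_id in self._adjust_batch():
            self.cancel_request(request_id)


if __name__ == "__main__":
    w = Worker()
    w.step()
    print(sorted(w.current_batch))  # ['r0']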