diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index dabc99fe2c..bee83665a6 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -222,23 +222,6 @@ def step(self) -> GenerationLoopWorkerOutput: self._adjust_batch() if not self.current_batch: - if len(self.queue) > 0: - LOG.warn( - f"The engine has {len(self.queue)} requests to be processed in the queue, but" - " none of them were added to the current batch during the execution of" - " StagingEngine._adjust_batch" - ) - - hung_request_ids = [] - - for state in self.queue: - hung_request_ids.append(state.request_id) - # TODO(masahi): Proper error enum? - state.validation_err = ValidationError("Canceled due to a hang") - - for request_id in hung_request_ids: - self.cancel_request(request_id) - return result requests = self._get_requests_to_process() @@ -337,6 +320,24 @@ def _adjust_batch(self): self.cache_manager.allocate(state.request_id, num_tokens) self.current_batch[state.request_id] = state + if not self.current_batch: + if len(self.queue) > 0: + LOG.warn( + f"The engine has {len(self.queue)} requests to be processed in the queue, but" + " none of them were added to the current batch during the execution of" + " StagingEngine._adjust_batch" + ) + + hung_request_ids = [] + + for state in self.queue: + hung_request_ids.append(state.request_id) + # TODO(masahi): Proper error enum? + state.validation_err = ValidationError("Canceled due to a hang") + + for request_id in hung_request_ids: + self.cancel_request(request_id) + def _remove_request_from_batch(self, request_id: RequestId): del self.current_batch[request_id] self.cache_manager.free(SequenceId(request_id, 0))