diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 347d678745..02c4e8d5ab 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -128,6 +128,33 @@ def cancel_request(self, request_id: RequestId):
     def stop_request(self, request_id: RequestId):
         self._cacnel_or_stop_request(request_id, self.stopped_requests)
 
+    def create_aborted_outputs(
+        self,
+        cancelled_or_stopped_requests: List[RequestState],
+        finish_reason: FinishReason,
+    ):
+        outputs = []
+        for state in cancelled_or_stopped_requests:
+            err = None
+            if state.validation_err:
+                err = state.validation_err
+
+            for gen_seq in state.generation_sequences:
+                outputs.append(
+                    SequenceGenerationOutput(
+                        id=gen_seq.seq_id,
+                        new_tokens=[],
+                        finish_reason=finish_reason,
+                        error=err,
+                    )
+                )
+
+            if state.request_id in self.current_batch:
+                self.remove_request_from_batch(state.request_id)
+
+        cancelled_or_stopped_requests.clear()
+        return outputs
+
     def wait_for_request(self, timeout_seconds=None):
         with self.queue_lock:
             self.has_new_requests.wait_for(
@@ -166,45 +193,24 @@ def step(self) -> GenerationLoopWorkerOutput:
                 duration = time.time() - state.arrival_timestamp
                 self.prom_metrics.histogram(E2E_LATENCY).observe(duration)
 
-        for state in self.stopped_requests:
-            for gen_seq in state.generation_sequences:
-                outputs.append(
-                    SequenceGenerationOutput(
-                        id=gen_seq.seq_id,
-                        new_tokens=[],
-                        finish_reason=FinishReason.Stop,
-                    )
-                )
-
-            if state.request_id in self.current_batch:
-                self.remove_request_from_batch(state.request_id)
-
-        self.stopped_requests.clear()
+        outputs += self.create_aborted_outputs(
+            self.stopped_requests, finish_reason=FinishReason.Stop
+        )
 
         with self.queue_lock:
            # Hold the lock here since self.cancelled_requests is modified in add(...) as well.
-            for state in self.cancelled_requests:
-                err = None
-                if state.validation_err:
-                    err = state.validation_err
-
-                for gen_seq in state.generation_sequences:
-                    outputs.append(
-                        SequenceGenerationOutput(
-                            id=gen_seq.seq_id,
-                            new_tokens=[],
-                            finish_reason=FinishReason.Cancelled,
-                            error=err,
-                        )
-                    )
-
-                if state.request_id in self.current_batch:
-                    self.remove_request_from_batch(state.request_id)
-
-            self.cancelled_requests.clear()
+            outputs += self.create_aborted_outputs(
+                self.cancelled_requests, finish_reason=FinishReason.Cancelled
+            )
 
         self._adjust_batch()
 
+        with self.queue_lock:
+            # _adjust_batch also adds to self.cancelled_requests
+            outputs += self.create_aborted_outputs(
+                self.cancelled_requests, finish_reason=FinishReason.Cancelled
+            )
+
         if not self.current_batch:
             if len(self.queue) > 0:
                 LOG.warn(
@@ -223,10 +229,9 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         for res in results:
             request_id = res.sequence_id.request_id
-            if res.error is not None:
-                if request_id not in failed_requests:
-                    failed_requests.add(request_id)
-                    self.remove_request_from_batch(request_id)
+            if res.error is not None and request_id not in failed_requests:
+                failed_requests.add(request_id)
+                self.remove_request_from_batch(request_id)
 
             outputs.append(
                 SequenceGenerationOutput(
@@ -275,7 +280,11 @@ def step(self) -> GenerationLoopWorkerOutput:
 
     def _adjust_batch(self):
         with self.queue_lock:
-            num_eviction = self.evict_request()
+            num_eviction = self.evict_request(
+                cancell_callback=lambda request_id: self.cancelled_requests.append(
+                    self.current_batch[request_id]
+                )
+            )
             self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc(num_eviction)
 
             if self.cache_manager.get_max_new_tokens() <= self.max_decode_steps:
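For context, a minimal sketch of the eviction path this diff wires up, using made-up stand-in classes (FakeCacheManager, FakeWorker) rather than the real mlc_serve types. It assumes evict_request invokes cancell_callback once per evicted request id, so the worker can collect the affected states in self.cancelled_requests and turn them into Cancelled outputs on the second create_aborted_outputs drain that now follows _adjust_batch(); this is an illustration of the contract implied by the diff, not the actual implementation.

# Minimal sketch (not the real mlc_serve classes) of the eviction flow the diff
# relies on: evict_request is assumed to call cancell_callback once per evicted
# request id, and the worker drains the collected states as Cancelled outputs.
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class FakeCacheManager:
    # Request ids that no longer fit in the KV cache, chosen elsewhere.
    to_evict: List[str] = field(default_factory=list)

    def evict_request(self, cancell_callback: Callable[[str], None]) -> int:
        for request_id in self.to_evict:
            cancell_callback(request_id)  # notify the worker before dropping the request
        return len(self.to_evict)


@dataclass
class FakeWorker:
    cache_manager: FakeCacheManager
    current_batch: Dict[str, str] = field(default_factory=dict)  # request_id -> state
    cancelled_requests: List[str] = field(default_factory=list)

    def _adjust_batch(self) -> None:
        # Mirrors the diff: evicted requests are queued as cancelled, not silently dropped.
        self.cache_manager.evict_request(
            cancell_callback=lambda request_id: self.cancelled_requests.append(
                self.current_batch[request_id]
            )
        )

    def drain_cancelled(self) -> List[str]:
        # Stand-in for create_aborted_outputs(..., finish_reason=FinishReason.Cancelled).
        outputs = [f"{state}: cancelled" for state in self.cancelled_requests]
        self.cancelled_requests.clear()
        return outputs


worker = FakeWorker(
    cache_manager=FakeCacheManager(to_evict=["req-1"]),
    current_batch={"req-1": "state-of-req-1"},
)
worker._adjust_batch()
print(worker.drain_cancelled())  # ['state-of-req-1: cancelled']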