Skip to content

Commit

Permalink
cancel works for staging engine too
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2023
1 parent c4279b9 commit 2ed26e0
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions serve/mlc_serve/engine/staging_engine_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,33 @@ def cancel_request(self, request_id: RequestId):
# Mark the request identified by `request_id` as stopped.
# Delegates to the shared cancel/stop helper, passing self.stopped_requests as
# the target collection; step() later turns the entries of that list into
# outputs with FinishReason.Stop.
def stop_request(self, request_id: RequestId):
# NOTE(review): the helper name is spelled "_cacnel_..." at its definition
# (not visible here); keep the spelling in sync if it is ever renamed.
self._cacnel_or_stop_request(request_id, self.stopped_requests)

def create_aborted_outputs(
    self,
    cancelled_or_stopped_requests: List[RequestState],
    finish_reason: FinishReason,
):
    """Build the terminal outputs for cancelled or stopped requests.

    For every request state in ``cancelled_or_stopped_requests`` this emits
    one ``SequenceGenerationOutput`` per generation sequence, with no new
    tokens, the given ``finish_reason`` and the state's validation error
    (``None`` when the state has no truthy ``validation_err``).  A request
    that is still in ``self.current_batch`` is removed from the batch.

    The list passed in is cleared in place before returning, so the caller's
    list is empty afterwards.

    Returns:
        The list of generated ``SequenceGenerationOutput`` objects.
    """
    aborted = []
    for req_state in cancelled_or_stopped_requests:
        # A falsy validation error is reported as "no error".
        error = req_state.validation_err if req_state.validation_err else None

        aborted.extend(
            SequenceGenerationOutput(
                id=seq.seq_id,
                new_tokens=[],
                finish_reason=finish_reason,
                error=error,
            )
            for seq in req_state.generation_sequences
        )

        # The request may never have been scheduled, so only drop it from
        # the batch when it is actually there.
        if req_state.request_id in self.current_batch:
            self.remove_request_from_batch(req_state.request_id)

    # Empty the caller's list so the same requests are not reported twice.
    cancelled_or_stopped_requests.clear()
    return aborted

def wait_for_request(self, timeout_seconds=None):
with self.queue_lock:
self.has_new_requests.wait_for(
Expand Down Expand Up @@ -166,45 +193,24 @@ def step(self) -> GenerationLoopWorkerOutput:
duration = time.time() - state.arrival_timestamp
self.prom_metrics.histogram(E2E_LATENCY).observe(duration)

for state in self.stopped_requests:
for gen_seq in state.generation_sequences:
outputs.append(
SequenceGenerationOutput(
id=gen_seq.seq_id,
new_tokens=[],
finish_reason=FinishReason.Stop,
)
)

if state.request_id in self.current_batch:
self.remove_request_from_batch(state.request_id)

self.stopped_requests.clear()
outputs += self.create_aborted_outputs(
self.stopped_requests, finish_reason=FinishReason.Stop
)

with self.queue_lock:
# Hold the lock here since self.cancelled_requests is modified in add(...) as well.
for state in self.cancelled_requests:
err = None
if state.validation_err:
err = state.validation_err

for gen_seq in state.generation_sequences:
outputs.append(
SequenceGenerationOutput(
id=gen_seq.seq_id,
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error=err,
)
)

if state.request_id in self.current_batch:
self.remove_request_from_batch(state.request_id)

self.cancelled_requests.clear()
outputs += self.create_aborted_outputs(
self.cancelled_requests, finish_reason=FinishReason.Cancelled
)

self._adjust_batch()

with self.queue_lock:
# _adjust_batch also adds to self.cancelled_requests
outputs += self.create_aborted_outputs(
self.cancelled_requests, finish_reason=FinishReason.Cancelled
)

if not self.current_batch:
if len(self.queue) > 0:
LOG.warn(
Expand All @@ -223,10 +229,9 @@ def step(self) -> GenerationLoopWorkerOutput:
for res in results:
request_id = res.sequence_id.request_id

if res.error is not None:
if request_id not in failed_requests:
failed_requests.add(request_id)
self.remove_request_from_batch(request_id)
if res.error is not None and request_id not in failed_requests:
failed_requests.add(request_id)
self.remove_request_from_batch(request_id)

outputs.append(
SequenceGenerationOutput(
Expand Down Expand Up @@ -275,7 +280,11 @@ def step(self) -> GenerationLoopWorkerOutput:

def _adjust_batch(self):
with self.queue_lock:
num_eviction = self.evict_request()
num_eviction = self.evict_request(
cancell_callback=lambda request_id: self.cancelled_requests.append(
self.current_batch[request_id]
)
)
self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc(num_eviction)

if self.cache_manager.get_max_new_tokens() <= self.max_decode_steps:
Expand Down

0 comments on commit 2ed26e0

Please sign in to comment.