
Commit

Cancel hung requests
masahi committed Dec 6, 2023
1 parent f55e6f6 commit 0d86159
Showing 1 changed file with 27 additions and 7 deletions.
serve/mlc_serve/engine/staging_engine_worker.py (34 changes: 27 additions & 7 deletions)
@@ -235,13 +235,12 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         self.cancelled_requests.clear()
 
-        self._adjust_batch()
+        request_ids_to_cancel = self._adjust_batch()
+
+        for request_id in request_ids_to_cancel:
+            self.cancel_request(request_id)
 
         if not self.current_batch:
-            if len(self.queue) > 0:
-                LOG.warn(
-                    f"The engine has {len(self.queue)} requests to be processed in the queue, but none of them were added to the current batch during the execution of StagingEngine._adjust_batch"
-                )
             return result
 
         requests, is_prompt_batch = self._get_requests_to_process()
@@ -294,7 +293,8 @@ def step(self) -> GenerationLoopWorkerOutput:
 
         return result
 
-    def _adjust_batch(self):
+    def _adjust_batch(self) -> List[RequestId]:
+        """Form a new batch and return a list of request IDs that should be cancelled, if any."""
         with self.queue_lock:
             while self.cache_manager.get_max_new_tokens() < 1:
                 self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc()
@@ -313,7 +313,7 @@ def _adjust_batch(self):
                     "Skip growing the batch due to max_decode_steps. Decode steps: %s",
                     self.cache_manager.get_max_new_tokens(),
                 )
-                return
+                return []
 
             num_new_batched_tokens = len(self.current_batch)
             while self.queue:
@@ -363,6 +363,26 @@ def _adjust_batch(self):
                 self.cache_manager.allocate(state.request_id, num_tokens)
                 self.current_batch[state.request_id] = state
 
+            if len(self.current_batch) == 0 and len(self.queue) > 0:
+                LOG.warn(
+                    f"The engine has {len(self.queue)} requests to be processed in the queue, but"
+                    " none of them were added to the current batch during the execution of"
+                    " StagingEngine._adjust_batch"
+                )
+
+                hung_request_ids = []
+
+                for state in self.queue:
+                    hung_request_ids.append(state.request_id)
+                    # TODO(masahi): Proper error enum?
+                    state.validation_err = ValidationError(
+                        "Internal Server Error: Canceled due to a hang in the server."
+                    )
+
+                return hung_request_ids
+
+            return []
+
     def _remove_request_from_batch(self, request_id: RequestId):
         del self.current_batch[request_id]
         self.cache_manager.free(SequenceId(request_id, 0))

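For readers skimming the diff, below is a minimal, self-contained sketch of the control flow this commit introduces: _adjust_batch() flags queued requests that could not be placed into an empty batch and returns their IDs, and step() cancels them instead of leaving clients waiting. This is not the actual mlc_serve worker code; RequestState, ValidationError, and the Worker class here are simplified stand-ins, and the real _adjust_batch also grows the batch subject to KV-cache capacity, which is omitted.

# Simplified sketch under the assumptions stated above.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ValidationError:
    msg: str


@dataclass
class RequestState:
    request_id: str
    validation_err: Optional[ValidationError] = None


class Worker:
    def __init__(self) -> None:
        self.queue: List[RequestState] = []
        self.current_batch: Dict[str, RequestState] = {}
        self.cancelled: List[str] = []

    def _adjust_batch(self) -> List[str]:
        """Return IDs of queued requests that look hung (nothing could be batched)."""
        if len(self.current_batch) == 0 and len(self.queue) > 0:
            hung_request_ids = []
            for state in self.queue:
                hung_request_ids.append(state.request_id)
                state.validation_err = ValidationError(
                    "Internal Server Error: Canceled due to a hang in the server."
                )
            return hung_request_ids
        return []

    def cancel_request(self, request_id: str) -> None:
        # Drop the request from the queue and record the cancellation.
        self.queue = [s for s in self.queue if s.request_id != request_id]
        self.cancelled.append(request_id)

    def step(self) -> None:
        # Mirrors the change to step(): cancel whatever _adjust_batch reports as hung.
        request_ids_to_cancel = self._adjust_batch()
        for request_id in request_ids_to_cancel:
            self.cancel_request(request_id)


if __name__ == "__main__":
    w = Worker()
    w.queue = [RequestState("req-0"), RequestState("req-1")]
    w.step()
    print(w.cancelled)  # ['req-0', 'req-1']: both queued requests were cancelled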