diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index df655d3e59..8793797edf 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -5,9 +5,12 @@
 import logging
 from typing import Deque, Set, Dict
 from collections import deque
+from sre_parse import Tokenizer
 from threading import Condition, Lock
+from uuid import uuid4
 
 from .base import (
+    DebugOptions,
     FinishReason,
     InferenceEngine,
     InferenceStepResult,
@@ -15,18 +18,12 @@
     RequestId,
     RequestOutput,
     RequestState,
+    SamplingParams,
     SequenceOutput,
+    StoppingCriteria,
     check_stopping_sequences,
-    ValidationError,
-)
-from .model_module import (
-    DecodeRequest,
-    ModelModule,
-    PrefillRequest,
-    SequenceId,
-    TextGenerator,
-    Tokenizer as TokenizerP,
 )
+from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
 from ..model.base import ModelArtifactConfig
 
 logger = logging.getLogger(__name__)
@@ -37,7 +34,6 @@ class SynchronousInferenceEngine(InferenceEngine):
     A implementation of InferenceEngine that does inference synchronously in the current thread
     when `step` is called.
     """
-
     text_generator: TextGenerator
     tokenizer: TokenizerP
     model_artifact_config: ModelArtifactConfig
@@ -61,9 +57,7 @@ def __init__(
         self.conversation_template = model_module.conversation_template
         self.cache_manager = model_module.cache_manager
         self.model_artifact_config = model_module.model_artifact_config
-        assert (
-            self.model_artifact_config.max_context_length
-        ), "max_context_length must not be zero"
+        assert self.model_artifact_config.max_context_length, "max_context_length must not be zero"
        self.max_context_length = self.model_artifact_config.max_context_length
         self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
         self.max_decode_steps = min(
@@ -159,23 +153,13 @@ def step(self) -> InferenceStepResult:
         previous_requests_to_be_cancelled = set(self.requests_to_be_cancelled)
         self._adjust_batch()
-
+
         if not self.current_batch:
             if len(self.queue) > 0:
                 logger.warning(
                     f"The engine has {len(self.queue)} requests to be processed in the queue, but none of them were added to the current batch during the execution of SyncEngine._adjust_batch"
                 )
-                hung_request_ids = []
-
-                for state in self.queue:
-                    hung_request_ids.append(state.request_id)
-                    # TODO(masahi): Proper error enum?
-                    state.validation_err = ValidationError("Canceled due to a hang")
-
-                for request_id in hung_request_ids:
-                    self.cancel(request_id)
-
             for request_id in previous_requests_to_be_cancelled:
                 if request_id not in self.requests_to_be_cancelled:
                     outputs.append(