Commit

undo sync engine change
masahi committed Nov 30, 2023
1 parent ab5e4d2 commit c7e4c60
Showing 1 changed file with 8 additions and 24 deletions.
32 changes: 8 additions & 24 deletions serve/mlc_serve/engine/sync_engine.py
@@ -5,28 +5,25 @@
import logging
from typing import Deque, Set, Dict
from collections import deque
from sre_parse import Tokenizer
from threading import Condition, Lock
from uuid import uuid4

from .base import (
    DebugOptions,
    FinishReason,
    InferenceEngine,
    InferenceStepResult,
    Request,
    RequestId,
    RequestOutput,
    RequestState,
    SamplingParams,
    SequenceOutput,
    StoppingCriteria,
    check_stopping_sequences,
    ValidationError,
)
from .model_module import (
    DecodeRequest,
    ModelModule,
    PrefillRequest,
    SequenceId,
    TextGenerator,
    Tokenizer as TokenizerP,
)
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
from ..model.base import ModelArtifactConfig

logger = logging.getLogger(__name__)
@@ -37,7 +34,6 @@ class SynchronousInferenceEngine(InferenceEngine):
    An implementation of InferenceEngine that does inference synchronously in the current thread
    when `step` is called.
    """

    text_generator: TextGenerator
    tokenizer: TokenizerP
    model_artifact_config: ModelArtifactConfig
@@ -61,9 +57,7 @@ def __init__(
        self.conversation_template = model_module.conversation_template
        self.cache_manager = model_module.cache_manager
        self.model_artifact_config = model_module.model_artifact_config
        assert (
            self.model_artifact_config.max_context_length
        ), "max_context_length must not be zero"
        assert self.model_artifact_config.max_context_length, "max_context_length must not be zero"
        self.max_context_length = self.model_artifact_config.max_context_length
        self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
        self.max_decode_steps = min(
@@ -159,23 +153,13 @@ def step(self) -> InferenceStepResult:

        previous_requests_to_be_cancelled = set(self.requests_to_be_cancelled)
        self._adjust_batch()

        if not self.current_batch:
            if len(self.queue) > 0:
                logger.warning(
                    f"The engine has {len(self.queue)} requests to be processed in the queue, but none of them were added to the current batch during the execution of SyncEngine._adjust_batch"
                )
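                # Queued requests could not be scheduled; mark them as failed and cancel them below.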

                hung_request_ids = []

                for state in self.queue:
                    hung_request_ids.append(state.request_id)
                    # TODO(masahi): Proper error enum?
                    state.validation_err = ValidationError("Canceled due to a hang")

                for request_id in hung_request_ids:
                    self.cancel(request_id)

        for request_id in previous_requests_to_be_cancelled:
            if request_id not in self.requests_to_be_cancelled:
                outputs.append(