Revert "Support Emoji (#102)"
This reverts commit 3774516.
sunggg authored Dec 11, 2023
1 parent 3774516 commit 94d151b
Showing 6 changed files with 112 additions and 371 deletions.
30 changes: 7 additions & 23 deletions serve/mlc_serve/engine/base.py
@@ -6,12 +6,11 @@

from typing import List, Callable, Any, Optional, Dict
import inspect
from .streamer import TextStreamer

from .sampling_params import SamplingParams, SamplingType

RequestId = str


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
@@ -40,7 +39,6 @@ def _from_json(config_cls, json_obj: Dict[Any, Any]):
}
)


def get_engine_config(dict_config):
engine_config = MLCServeEngineConfig._from_json(dict_config)
# Checks to make sure engine configs are set correctly
@@ -53,32 +51,26 @@ def get_engine_config(dict_config):
assert isinstance(engine_config.min_decode_steps, int)

# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
assert (
engine_config.max_num_batched_tokens == -1
), "`max_num_batched_tokens` is not supposed to be configured directly. \
assert engine_config.max_num_batched_tokens == -1, \
"`max_num_batched_tokens` is not supposed to be configured directly. \
Use `max_num_sequences` and `max_input_len` instead."
assert engine_config.max_input_len > 0
assert engine_config.max_num_sequences > 0
engine_config.max_num_batched_tokens = (
engine_config.max_num_sequences * engine_config.max_input_len
)
engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps

return engine_config


@dataclass
class StoppingCriteria:
"""
Parameters about when to stop text generation.
"""

max_tokens: Optional[int] = None
stop_sequences: Optional[list[str]] = None


@dataclass
class ChatMessage:
role: str
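
As a side note on the configuration invariants enforced by `get_engine_config` above: `max_num_batched_tokens` must be left at -1 in the incoming config and is derived from the sequence and input-length limits, and the decode-step bounds must be positive with `max_decode_steps > min_decode_steps`. A tiny worked sketch with made-up numbers (not project defaults):

```python
# Worked example of the invariants checked in get_engine_config above.
# The concrete numbers are illustrative, not defaults from the repository.
max_num_sequences = 8        # maximum concurrent sequences
max_input_len = 2048         # maximum prompt length per sequence
max_num_batched_tokens = -1  # must stay -1 in the config; derived below

assert max_num_batched_tokens == -1
max_num_batched_tokens = max_num_sequences * max_input_len   # 16384

min_decode_steps, max_decode_steps = 12, 48
assert (min_decode_steps > 0) and (max_decode_steps > 0)
assert max_decode_steps > min_decode_steps
```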
@@ -97,20 +89,16 @@ class FinishReason(Enum):
Length = "length"
Cancelled = "cancelled"


# A single token.
Token = int


@dataclass
class ValidationError:
msg: str


# The type signature of the token validation callback.
ValidateTokensCallback = Callable[["Request", List[Token]], ValidationError]


@dataclass
class Request:
request_id: RequestId
@@ -122,9 +110,7 @@ class Request:
# Options for sampling.
sampling_params: SamplingParams = field(default_factory=SamplingParams)
# Options for stopping.
stopping_criteria: StoppingCriteria = field(
default_factory=lambda: StoppingCriteria()
)
stopping_criteria: StoppingCriteria = field(default_factory=lambda: StoppingCriteria())
# Options for debugging.
debug_options: DebugOptions = field(default_factory=DebugOptions)
# Perform request validation post-tokenization, used by the HTTP layer to control validation.
@@ -268,10 +254,8 @@ class RequestState:
debug_options: DebugOptions
arrival_timestamp: float
is_ended: bool = False
text_streamer: Optional[TextStreamer] = None
validation_err: Optional[ValidationError] = None


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
@@ -284,8 +268,8 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
# While eventually we need to return "I "
if not output_text.endswith(t):
sub_index = output_text.find(t)
delta = delta[: -(len(output_text) - sub_index - len(t))]
output_text = output_text[: output_text.find(t) + len(t)]
delta = delta[:-(len(output_text) - sub_index - len(t))]
output_text = output_text[:output_text.find(t) + len(t)]
is_ended = True
break
return output_text, delta, is_ended
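
To make the stop-sequence trimming above concrete, here is a small self-contained re-implementation of the visible branch plus a worked call. The guard that the stop sequence actually occurs in `output_text` is assumed (that part of the function is collapsed in this diff), and the function name is illustrative.

```python
# Simplified re-implementation of the trimming branch shown above, for a single
# stop sequence.  The `stop in output_text` guard is assumed (that part of the
# function is collapsed in this diff); the rest mirrors the visible lines.
def trim_on_stop(output_text: str, delta: str, stop: str):
    if stop in output_text:
        if not output_text.endswith(stop):
            sub_index = output_text.find(stop)
            # Drop the characters generated after the stop sequence from the
            # newest delta, then truncate the accumulated text just past it.
            delta = delta[:-(len(output_text) - sub_index - len(stop))]
            output_text = output_text[:output_text.find(stop) + len(stop)]
        return output_text, delta, True
    return output_text, delta, False


# The model has produced "Hello. I" so far and the latest delta was "o. I";
# with stop = ".", the delta is trimmed to "o." and generation is marked ended.
print(trim_on_stop("Hello. I", "o. I", "."))   # ('Hello.', 'o.', True)
```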
26 changes: 22 additions & 4 deletions serve/mlc_serve/engine/staging_engine.py
@@ -6,6 +6,10 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable, Optional

import os

import structlog

from .base import (
@@ -26,7 +30,7 @@
ShutdownCommand,
run_generation_loop_worker,
)
from .streamer import TextStreamer

from ..logging_utils import log_every

LOG = structlog.stdlib.get_logger(__name__)
@@ -211,13 +215,12 @@ def step(self) -> InferenceStepResult:
state.token_ids.extend(seq_output.new_tokens)

# detokenize
delta = state.text_streamer.put([state.token_ids[-1]])
delta = self._decode_last_output(state)
state.output_text += delta

state.output_text, delta, state.is_ended = check_stopping_sequences(
state.stopping_criteria, state.output_text, delta, state.is_ended
)

# signal workers to stop generation
if state.is_ended:
self.stop_request(state.request_id)
@@ -269,6 +272,21 @@ def _get_new_request_state(self, request: Request) -> RequestState:
debug_options=request.debug_options,
output_text="",
validation_err=validation_err,
text_streamer=TextStreamer(self.tokenizer),
arrival_timestamp=time.time(),
)

def _decode_last_output(self, state: RequestState) -> str:
if len(state.output_text):
prefix_idx = max(0, state.next_start_position - 6)
else:
prefix_idx = state.next_start_position

if prefix_idx == 0:
return self.tokenizer.decode(state.token_ids)

prefix = self.tokenizer.decode(
state.token_ids[prefix_idx : state.next_start_position]
)
full = self.tokenizer.decode(state.token_ids[prefix_idx:])

return full[len(prefix) :]
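
The `_decode_last_output` helper restored here (and again in sync_engine.py below) re-decodes a short window of already-emitted tokens so the tokenizer resolves spacing and subword merges correctly, then returns only the new suffix. A standalone sketch of that mechanism follows; the `PieceTokenizer` and the explicit `next_start_position` bookkeeping are illustrative assumptions, not the project's actual `Tokenizer` interface.

```python
# Standalone sketch of the prefix-window detokenization restored above.
# PieceTokenizer is a toy stand-in with a decode(ids) -> str method; the real
# engine passes its own tokenizer object here.
class PieceTokenizer:
    # SentencePiece-style pieces: "▁" marks a word boundary and is rendered as
    # a space, except at the very start of the decoded string.
    pieces = {0: "▁The", 1: "▁quick", 2: "▁brown", 3: "▁fox",
              4: "▁jumps", 5: "▁over", 6: "▁the", 7: "▁lazy"}

    def decode(self, ids):
        text = "".join(self.pieces[i] for i in ids)
        return text.replace("▁", " ").lstrip(" ")


def decode_last_output(tokenizer, token_ids, next_start_position, output_text):
    # Mirrors the method above: re-decode up to 6 already-seen tokens so the
    # tokenizer gets word boundaries right, then return only the new suffix.
    if output_text:
        prefix_idx = max(0, next_start_position - 6)
    else:
        prefix_idx = next_start_position
    if prefix_idx == 0:
        return tokenizer.decode(token_ids)
    prefix = tokenizer.decode(token_ids[prefix_idx:next_start_position])
    full = tokenizer.decode(token_ids[prefix_idx:])
    return full[len(prefix):]


tok = PieceTokenizer()
prompt = [0, 1, 2, 3, 4, 5]  # "The quick brown fox jumps over"
# First decode step: token 6 was just generated, nothing decoded into output yet.
print(decode_last_output(tok, prompt + [6], next_start_position=6, output_text=""))        # "the"
# Next step: token 7 generated; the 6-token lookback preserves the leading space.
print(decode_last_output(tok, prompt + [6, 7], next_start_position=7, output_text="the"))  # " lazy"
```

A fixed 6-token lookback like this handles ordinary subword merges, but it can still emit a partially decoded character when one code point spans several tokens (emoji being the common case), which appears to be what the reverted TextStreamer from #102 was meant to handle.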
95 changes: 0 additions & 95 deletions serve/mlc_serve/engine/streamer.py

This file was deleted.
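
The contents of the deleted streamer.py are not shown in this diff. Purely as an assumption-level sketch of the kind of UTF-8-safe incremental detokenizer a `TextStreamer` with the `put()` interface used above typically implements (not the actual file):

```python
# Assumption-only sketch of a UTF-8-safe incremental detokenizer with the same
# put() interface used above; this is NOT the contents of the deleted file.
class NaiveTextStreamer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer   # anything with decode(list[int]) -> str
        self.token_ids = []
        self.printed_len = 0         # number of characters already emitted

    def put(self, new_token_ids):
        # Deliberately naive: re-decodes the whole sequence on every call.
        self.token_ids.extend(new_token_ids)
        text = self.tokenizer.decode(self.token_ids)
        if text.endswith("\ufffd"):
            # The last tokens decode to an incomplete character (e.g. a
            # partially generated emoji): hold the text back for now.
            return ""
        delta = text[self.printed_len:]
        self.printed_len = len(text)
        return delta
```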

66 changes: 30 additions & 36 deletions serve/mlc_serve/engine/sync_engine.py
@@ -20,15 +20,7 @@
check_stopping_sequences,
ValidationError,
)
from .streamer import TextStreamer
from .model_module import (
DecodeRequest,
ModelModule,
PrefillRequest,
SequenceId,
TextGenerator,
Tokenizer as TokenizerP,
)
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
from ..model.base import ModelArtifactConfig

logger = logging.getLogger(__name__)
@@ -39,7 +31,6 @@ class SynchronousInferenceEngine(InferenceEngine):
A implementation of InferenceEngine that does inference synchronously in the current thread
when `step` is called.
"""

text_generator: TextGenerator
tokenizer: TokenizerP
model_artifact_config: ModelArtifactConfig
@@ -62,14 +53,10 @@ def __init__(
self.conversation_template = model_module.conversation_template
self.cache_manager = model_module.cache_manager
self.model_artifact_config = model_module.model_artifact_config
assert (
self.model_artifact_config.max_context_length
), "max_context_length must not be zero"
assert self.model_artifact_config.max_context_length, "max_context_length must not be zero"
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
assert (
self.max_num_batched_tokens > 0
), "max_num_batched_tokens must be positive"
assert self.max_num_batched_tokens > 0, "max_num_batched_tokens must be positive"
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
@@ -108,18 +95,14 @@ def add(self, requests: list[Request]):

if (
state.validation_err is not None
or state.prompt_len
> min(self.max_context_length, self.max_num_batched_tokens)
or state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
or self.cache_manager.get_kv_cache_size() - state.prompt_len
< self.max_decode_steps
or self.cache_manager.get_kv_cache_size() - state.prompt_len < self.max_decode_steps
):
self.cancel(req.request_id)
if state.validation_err is None:
state.validation_err = ValidationError(
"The prompt is too long for the given set of engine parameters."
)
state.validation_err = ValidationError("The prompt is too long for the given set of engine parameters.")

with self.queue_lock:
self.queue.extend(new_request_states)
@@ -234,7 +217,7 @@ def step(self) -> InferenceStepResult:

state.token_ids.extend(new_token_ids)

delta = state.text_streamer.put([state.token_ids[-1]])
delta = self._decode_last_output(state)
state.output_text += delta

state.output_text, delta, state.is_ended = check_stopping_sequences(
@@ -304,14 +287,9 @@ def _adjust_batch(self):
# self.max_num_batched_tokens. In such cases, we need to discard the recent decode
# tokens that cannot fit into a batch, and recompute them after we fill the cache
# entries for the older tokens.
if (
not len(self.current_batch)
and num_new_batched_tokens > self.max_num_batched_tokens
):
state.token_ids = state.token_ids[: self.max_num_batched_tokens]
state.next_start_position = (
num_new_batched_tokens
) = num_tokens = self.max_num_batched_tokens
if not len(self.current_batch) and num_new_batched_tokens > self.max_num_batched_tokens:
state.token_ids = state.token_ids[:self.max_num_batched_tokens]
state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens
if num_new_batched_tokens > self.max_num_batched_tokens > 0:
logger.debug(
"Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
Expand All @@ -320,9 +298,10 @@ def _adjust_batch(self):
break
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
if (self.cache_manager.get_free_space() - num_tokens) / (
len(self.current_batch) + 1
) < self.max_decode_steps:
if (
(self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
< self.max_decode_steps
):
logger.debug(
"Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s",
self.cache_manager.get_free_space(),
Expand Down Expand Up @@ -405,10 +384,25 @@ def _get_new_request_state(self, request: Request) -> RequestState:
stopping_criteria=request.stopping_criteria,
debug_options=request.debug_options,
output_text="",
text_streamer=TextStreamer(self.tokenizer),
arrival_timestamp=time.time(),
)

def _decode_last_output(self, state: RequestState) -> str:
if len(state.output_text):
prefix_idx = max(0, state.next_start_position - 6)
else:
prefix_idx = state.next_start_position

if prefix_idx == 0:
return self.tokenizer.decode(state.token_ids)

prefix = self.tokenizer.decode(
state.token_ids[prefix_idx : state.next_start_position]
)
full = self.tokenizer.decode(state.token_ids[prefix_idx:])

return full[len(prefix) :]

def _should_stop_by_length(self, state: RequestState) -> bool:
# TODO: currently, we simply return true for both stopping reasons.
# in the future, we can differentiate these two.
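
Both the admission check in `add` and the batch-growth check in `_adjust_batch` above reserve room in the KV cache for at least `max_decode_steps` further tokens per sequence. A minimal sketch of that arithmetic with made-up numbers; the function names and the flat "one KV slot per token" accounting are illustrative assumptions, not the project's cache-manager API.

```python
# Minimal sketch of the two capacity checks above, with made-up numbers.
def can_admit(kv_cache_size: int, prompt_len: int, max_decode_steps: int) -> bool:
    # Mirrors the check in `add`: even before batching, the cache must be able
    # to hold the prompt plus at least max_decode_steps generated tokens.
    return kv_cache_size - prompt_len >= max_decode_steps


def can_grow_batch(free_space: int, num_tokens: int,
                   current_batch_size: int, max_decode_steps: int) -> bool:
    # Mirrors the check in `_adjust_batch`: after adding this request's tokens,
    # every sequence in the batch must still have room for max_decode_steps steps.
    return (free_space - num_tokens) / (current_batch_size + 1) >= max_decode_steps


print(can_admit(kv_cache_size=4096, prompt_len=4000, max_decode_steps=256))       # False
print(can_grow_batch(free_space=3000, num_tokens=512,
                     current_batch_size=7, max_decode_steps=256))                  # (3000-512)/8 ≈ 311 -> True
```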
