Support Emoji (#102)
* done

* fix

* lint
sunggg authored Dec 11, 2023
1 parent 9eeedfe commit 3774516
Showing 6 changed files with 371 additions and 112 deletions.
30 changes: 23 additions & 7 deletions serve/mlc_serve/engine/base.py
@@ -6,11 +6,12 @@

from typing import List, Callable, Any, Optional, Dict
import inspect

from .streamer import TextStreamer
from .sampling_params import SamplingParams, SamplingType

RequestId = str


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
@@ -39,6 +40,7 @@ def _from_json(config_cls, json_obj: Dict[Any, Any]):
}
)


def get_engine_config(dict_config):
engine_config = MLCServeEngineConfig._from_json(dict_config)
# Checks to make sure engine configs are set correctly
@@ -51,26 +53,32 @@ def get_engine_config(dict_config):
assert isinstance(engine_config.min_decode_steps, int)

# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
assert engine_config.max_num_batched_tokens == -1, \
"`max_num_batched_tokens` is not supposed to be configured directly. \
assert (
engine_config.max_num_batched_tokens == -1
), "`max_num_batched_tokens` is not supposed to be configured directly. \
Use `max_num_sequences` and `max_input_len` instead."
assert engine_config.max_input_len > 0
assert engine_config.max_num_sequences > 0
engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len
engine_config.max_num_batched_tokens = (
engine_config.max_num_sequences * engine_config.max_input_len
)

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps

return engine_config

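As a rough illustration of the derivation above, here is a hypothetical call to get_engine_config. The values are made up, only fields that appear in the asserts are passed, and the remaining MLCServeEngineConfig fields are assumed to keep their defaults (including max_num_batched_tokens staying at -1):

from mlc_serve.engine.base import get_engine_config

# Illustrative values only; max_num_batched_tokens is assumed to default to -1
# and is derived as max_num_sequences * max_input_len.
config = get_engine_config(
    {
        "max_num_sequences": 8,
        "max_input_len": 512,
        "min_decode_steps": 12,
        "max_decode_steps": 16,
    }
)
assert config.max_num_batched_tokens == 8 * 512  # 4096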

@dataclass
class StoppingCriteria:
"""
Parameters about when to stop text generation.
"""

max_tokens: Optional[int] = None
stop_sequences: Optional[list[str]] = None


@dataclass
class ChatMessage:
role: str
@@ -89,16 +97,20 @@ class FinishReason(Enum):
Length = "length"
Cancelled = "cancelled"


# A single token.
Token = int


@dataclass
class ValidationError:
msg: str


# The type signature of the token validation callback.
ValidateTokensCallback = Callable[["Request", List[Token]], ValidationError]


@dataclass
class Request:
request_id: RequestId
@@ -110,7 +122,9 @@ class Request:
# Options for sampling.
sampling_params: SamplingParams = field(default_factory=SamplingParams)
# Options for stopping.
stopping_criteria: StoppingCriteria = field(default_factory=lambda: StoppingCriteria())
stopping_criteria: StoppingCriteria = field(
default_factory=lambda: StoppingCriteria()
)
# Options for debugging.
debug_options: DebugOptions = field(default_factory=DebugOptions)
# Perform request validation post-tokenization, used by the HTTP layer to control validation.
@@ -254,8 +268,10 @@ class RequestState:
debug_options: DebugOptions
arrival_timestamp: float
is_ended: bool = False
text_streamer: Optional[TextStreamer] = None
validation_err: Optional[ValidationError] = None


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
@@ -268,8 +284,8 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
# While eventually we need to return "I "
if not output_text.endswith(t):
sub_index = output_text.find(t)
delta = delta[:-(len(output_text) - sub_index - len(t))]
output_text = output_text[:output_text.find(t) + len(t)]
delta = delta[: -(len(output_text) - sub_index - len(t))]
output_text = output_text[: output_text.find(t) + len(t)]
is_ended = True
break
return output_text, delta, is_ended
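For reference, a minimal sketch of how this helper truncates a streamed delta at a stop sequence, assuming the package is importable as mlc_serve; the strings are illustrative only:

from mlc_serve.engine.base import StoppingCriteria, check_stopping_sequences

# The request has accumulated "I like it" and the latest detokenized chunk is "ke it";
# generation should stop as soon as the stop sequence "like" appears.
out, delta, ended = check_stopping_sequences(
    StoppingCriteria(stop_sequences=["like"]),
    output_text="I like it",
    delta="ke it",
    is_ended=False,
)
assert (out, delta, ended) == ("I like", "ke", True)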
26 changes: 4 additions & 22 deletions serve/mlc_serve/engine/staging_engine.py
@@ -6,10 +6,6 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable, Optional

import os

import structlog

from .base import (
@@ -30,7 +26,7 @@
ShutdownCommand,
run_generation_loop_worker,
)

from .streamer import TextStreamer
from ..logging_utils import log_every

LOG = structlog.stdlib.get_logger(__name__)
@@ -215,12 +211,13 @@ def step(self) -> InferenceStepResult:
state.token_ids.extend(seq_output.new_tokens)

# detokenize
delta = self._decode_last_output(state)
delta = state.text_streamer.put([state.token_ids[-1]])
state.output_text += delta

state.output_text, delta, state.is_ended = check_stopping_sequences(
state.stopping_criteria, state.output_text, delta, state.is_ended
)

# signal workers to stop generation
if state.is_ended:
self.stop_request(state.request_id)
@@ -272,21 +269,6 @@ def _get_new_request_state(self, request: Request) -> RequestState:
debug_options=request.debug_options,
output_text="",
validation_err=validation_err,
text_streamer=TextStreamer(self.tokenizer),
arrival_timestamp=time.time(),
)

def _decode_last_output(self, state: RequestState) -> str:
if len(state.output_text):
prefix_idx = max(0, state.next_start_position - 6)
else:
prefix_idx = state.next_start_position

if prefix_idx == 0:
return self.tokenizer.decode(state.token_ids)

prefix = self.tokenizer.decode(
state.token_ids[prefix_idx : state.next_start_position]
)
full = self.tokenizer.decode(state.token_ids[prefix_idx:])

return full[len(prefix) :]
95 changes: 95 additions & 0 deletions serve/mlc_serve/engine/streamer.py
@@ -0,0 +1,95 @@
from typing import List, Deque
from collections import deque

kReplacementCharacter = b"\xef\xbf\xbd".decode("utf8")


class TextStreamer:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.prefix_tokens: List[int] = []
self.pending_tokens: Deque[int] = deque([])

def put(self, delta_tokens: List[int]) -> str:
if len(delta_tokens) == 0:
return ""

ret = ""
for delta_token in delta_tokens:
self.pending_tokens.append(delta_token)
all_tokens = self.prefix_tokens + list(self.pending_tokens)

prefix_str = (
self.tokenizer.decode(self.prefix_tokens)
if len(self.prefix_tokens) > 0
else ""
)
full_str = self.tokenizer.decode(all_tokens)
prefix_len = len(prefix_str)

new_pending_tokens: Deque[int] = deque([])
if full_str[:prefix_len] == prefix_str:
# Case 1. prefix_str is a prefix of `full_str`.
# We cannot naively do `validated_str = self.tokenizer.decode(validated_tokens)`
# since it will lose the contextual information, such as ' '.
validated_str = full_str[prefix_len:]
while (
len(self.pending_tokens) > 0
and len(new_pending_tokens) < 3
and len(validated_str) >= 1
and validated_str[len(validated_str) - 1 :] == kReplacementCharacter
):
new_pending_tokens.appendleft(self.pending_tokens.pop())
validated_str = validated_str[: len(validated_str) - 1]
else:
# Case 2. prefix_str is not a prefix of `full_str`.
# Pop pending tokens from the back.
# - Pop until prefix_str is indeed a prefix of full_str.
# - A valid UTF-8 character is at most 4 bytes long,
# so at most 3 tokens will be popped.
# - If there are fewer than 3 pending tokens, skip popping.
# This is because it is impossible to make full_str contain
# prefix_str without popping all the pending tokens.
if len(self.pending_tokens) < 3:
continue
get_valid_full_str = False
while len(self.pending_tokens) > 0 and len(new_pending_tokens) < 3:
new_pending_tokens.appendleft(self.pending_tokens.pop())
all_tokens.pop()
full_str = self.tokenizer.decode(all_tokens)
if full_str[:prefix_len] == prefix_str:
get_valid_full_str = True
break
if get_valid_full_str:
# We found a full_str that starts with prefix_str,
# so we return the full string with the prefix sliced off.
validated_str = full_str[prefix_len:]
else:
# We could not find a full_str that starts with prefix_str by
# popping up to 3 tokens.
# In this case, the remaining pending tokens are invalid UTF-8
# characters already, so we return the decoded pending tokens.
validated_str = self.tokenizer.decode(self.pending_tokens)

if len(self.pending_tokens) > 0:
# set the new prefix
self.prefix_tokens = list(self.pending_tokens)
self.pending_tokens = new_pending_tokens

ret += validated_str
return ret

def finish(self) -> str:
all_tokens = self.prefix_tokens + list(self.pending_tokens)
prefix_str = (
self.tokenizer.decode(self.prefix_tokens)
if len(self.prefix_tokens) > 0
else ""
)
full_str = self.tokenizer.decode(all_tokens) if len(all_tokens) > 0 else ""
prefix_len = len(prefix_str)

if full_str[:prefix_len] == prefix_str:
return full_str[prefix_len:]
else:
return self.tokenizer.decode(self.pending_tokens)
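To make the intended call pattern concrete, here is a hypothetical wiring sketch. The ByteTokenizer below is a toy stand-in (one token id per UTF-8 byte) that exists only to keep the snippet self-contained; in the engines the streamer is constructed with the model tokenizer and fed one token id per decode step, as in state.text_streamer.put([state.token_ids[-1]]) in this commit:

from mlc_serve.engine.streamer import TextStreamer


class ByteTokenizer:
    """Toy stand-in for the model tokenizer: one token id per UTF-8 byte."""

    def decode(self, token_ids) -> str:
        return bytes(token_ids).decode("utf-8", errors="replace")


streamer = TextStreamer(ByteTokenizer())

output_text = ""
for token_id in "Hello".encode("utf-8"):  # pretend the token ids arrive one decode step at a time
    output_text += streamer.put([token_id])  # may return "" while a multi-byte character is incomplete
output_text += streamer.finish()  # flush whatever is still pending at the end of generation

print(output_text)  # Hello

With a real tokenizer, put() holds back tokens whose partial decode ends in the U+FFFD replacement character, which is what lets emoji and other multi-byte characters stream out only once they decode cleanly.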
66 changes: 36 additions & 30 deletions serve/mlc_serve/engine/sync_engine.py
@@ -20,7 +20,15 @@
check_stopping_sequences,
ValidationError,
)
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
from .streamer import TextStreamer
from .model_module import (
DecodeRequest,
ModelModule,
PrefillRequest,
SequenceId,
TextGenerator,
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig

logger = logging.getLogger(__name__)
@@ -31,6 +39,7 @@ class SynchronousInferenceEngine(InferenceEngine):
An implementation of InferenceEngine that runs inference synchronously in the current thread
when `step` is called.
"""

text_generator: TextGenerator
tokenizer: TokenizerP
model_artifact_config: ModelArtifactConfig
@@ -53,10 +62,14 @@ def __init__(
self.conversation_template = model_module.conversation_template
self.cache_manager = model_module.cache_manager
self.model_artifact_config = model_module.model_artifact_config
assert self.model_artifact_config.max_context_length, "max_context_length must not be zero"
assert (
self.model_artifact_config.max_context_length
), "max_context_length must not be zero"
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
assert self.max_num_batched_tokens > 0, "max_num_batched_tokens must be positive"
assert (
self.max_num_batched_tokens > 0
), "max_num_batched_tokens must be positive"
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
@@ -95,14 +108,18 @@ def add(self, requests: list[Request]):

if (
state.validation_err is not None
or state.prompt_len > min(self.max_context_length, self.max_num_batched_tokens)
or state.prompt_len
> min(self.max_context_length, self.max_num_batched_tokens)
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
or self.cache_manager.get_kv_cache_size() - state.prompt_len < self.max_decode_steps
or self.cache_manager.get_kv_cache_size() - state.prompt_len
< self.max_decode_steps
):
self.cancel(req.request_id)
if state.validation_err is None:
state.validation_err = ValidationError("The prompt is too long for the given set of engine parameters.")
state.validation_err = ValidationError(
"The prompt is too long for the given set of engine parameters."
)

with self.queue_lock:
self.queue.extend(new_request_states)
@@ -217,7 +234,7 @@ def step(self) -> InferenceStepResult:

state.token_ids.extend(new_token_ids)

delta = self._decode_last_output(state)
delta = state.text_streamer.put([state.token_ids[-1]])
state.output_text += delta

state.output_text, delta, state.is_ended = check_stopping_sequences(
@@ -287,9 +304,14 @@ def _adjust_batch(self):
# self.max_num_batched_tokens. In such cases, we need to discard the recent decode
# tokens that cannot fit into a batch, and recompute them after we fill the cache
# entries for the older tokens.
if not len(self.current_batch) and num_new_batched_tokens > self.max_num_batched_tokens:
state.token_ids = state.token_ids[:self.max_num_batched_tokens]
state.next_start_position = num_new_batched_tokens = num_tokens = self.max_num_batched_tokens
if (
not len(self.current_batch)
and num_new_batched_tokens > self.max_num_batched_tokens
):
state.token_ids = state.token_ids[: self.max_num_batched_tokens]
state.next_start_position = (
num_new_batched_tokens
) = num_tokens = self.max_num_batched_tokens
if num_new_batched_tokens > self.max_num_batched_tokens > 0:
logger.debug(
"Stop growing the batch due to max_num_batched_tokens. Batched tokens: %s",
@@ -298,10 +320,9 @@ def _adjust_batch(self):
break
# We make sure that the KV cache will have enough free space for this request to proceed
# decoding for at least self.max_decode_steps steps.
if (
(self.cache_manager.get_free_space() - num_tokens) / (len(self.current_batch) + 1)
< self.max_decode_steps
):
if (self.cache_manager.get_free_space() - num_tokens) / (
len(self.current_batch) + 1
) < self.max_decode_steps:
logger.debug(
"Stop growing the batch due to not enough free space. Free: %s, Num tokens: %s",
self.cache_manager.get_free_space(),
@@ -384,25 +405,10 @@ def _get_new_request_state(self, request: Request) -> RequestState:
stopping_criteria=request.stopping_criteria,
debug_options=request.debug_options,
output_text="",
text_streamer=TextStreamer(self.tokenizer),
arrival_timestamp=time.time(),
)

def _decode_last_output(self, state: RequestState) -> str:
if len(state.output_text):
prefix_idx = max(0, state.next_start_position - 6)
else:
prefix_idx = state.next_start_position

if prefix_idx == 0:
return self.tokenizer.decode(state.token_ids)

prefix = self.tokenizer.decode(
state.token_ids[prefix_idx : state.next_start_position]
)
full = self.tokenizer.decode(state.token_ids[prefix_idx:])

return full[len(prefix) :]

def _should_stop_by_length(self, state: RequestState) -> bool:
# TODO: currently, we simply return true for both stopping reasons.
# in the future, we can differentiate these two.
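To make the free-space guard in _adjust_batch concrete, here is a small worked example with made-up numbers; the identifiers mirror the code above, but the values are purely illustrative:

# Hypothetical snapshot of the scheduler state while considering one more request.
free_space = 4096        # cache_manager.get_free_space()
num_tokens = 1024        # tokens the growing batch would need, including this prompt
current_batch_size = 3   # len(self.current_batch)
max_decode_steps = 48    # from the engine config

headroom_per_request = (free_space - num_tokens) / (current_batch_size + 1)
print(headroom_per_request)                     # 768.0
print(headroom_per_request < max_decode_steps)  # False, so there is enough room to keep growing the batch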