
Commit

Merge branch 'tmp' into bugfix/2023-Nov/stop-token-staging-engine
sunggg committed Nov 21, 2023
2 parents b708330 + 753b732 commit c5e7017
Showing 5 changed files with 228 additions and 53 deletions.
2 changes: 0 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -56,8 +56,6 @@ class StoppingCriteria:
"""
max_tokens: Optional[int]
stop_sequences: Optional[list[str]] = None
list_stop_token_ids: Optional[list[list[str]]] = None


@dataclass
class ChatMessage:
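With list_stop_token_ids gone, stop conditions are expressed purely as strings. As a minimal illustrative sketch (not part of this commit), a stopping criterion for the simplified dataclass could be built like this:

from mlc_serve.engine import StoppingCriteria

# Stop after at most 300 new tokens, or as soon as a blank line is generated.
criteria = StoppingCriteria(max_tokens=300, stop_sequences=["\n\n"])
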
51 changes: 22 additions & 29 deletions serve/mlc_serve/engine/staging_engine.py
@@ -24,7 +24,6 @@
    ShutdownCommand,
    run_generation_loop_worker,
)
-#logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


@@ -65,11 +64,14 @@ def __init__(
        )

    def start(self):
-        self.worker_process.start()
-        if not self.ready_event.wait(timeout=180):
-            raise RuntimeError(
-                "StagingInferenceEngine worker is not ready before timeout."
-            )
+        try:
+            self.worker_process.start()
+            if not self.ready_event.wait(timeout=90):
+                raise RuntimeError(
+                    "StagingInferenceEngine worker is not ready before timeout."
+                )
+        except Exception as e:
+            raise RuntimeError("Failed to start StagingInferenceEngine worker process.") from e

    def stop(self):
        self.command_queue.put(ShutdownCommand())
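For context, a hedged sketch of the start/stop lifecycle these methods expose, reusing the constructor and config arguments from the new test file below; the artifact path is a placeholder, not something defined by this commit:

import os
from mlc_serve.engine import get_engine_config
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

# Placeholder path: point this at a compiled model artifact directory.
model_artifact_path = os.path.join("dist", "my-model")
engine_config = get_engine_config({
    "use_staging_engine": True,
    "max_num_batched_tokens": 2560,
    "max_input_len": 2560,
})

engine = StagingInferenceEngine(
    tokenizer_module=HfTokenizerModule(model_artifact_path),
    model_module_loader=PagedCacheModelModule,
    model_module_loader_kwargs={
        "model_artifact_path": model_artifact_path,
        "engine_config": engine_config,
    },
)
engine.start()      # spawns the worker process and waits for its ready event
try:
    pass            # add requests and drive engine.step() here
finally:
    engine.stop()   # sends ShutdownCommand to the worker
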
@@ -86,23 +88,12 @@ def add(self, requests: list[Request]):
            if req.num_sequences > 1:
                raise RuntimeError("num_sequences > 1 is not supported for now")

-            if req.stopping_criteria.stop_sequences and not req.stopping_criteria.list_stop_token_ids:
-                # TODO: verify tokenizer setting
-                list_stop_token_ids = []
-                for stop in req.stopping_criteria.stop_sequences:
-                    stop_token_ids = self.tokenizer._tokenizer.encode(stop, add_special_tokens=False, padding=False)
-                    # If there is a special token `SPIECE_UNDERLINE`, truncate it.
-                    # You will see this for stop tokens like `\n`.
-                    # Currently, there is no easy way to disable its insertion,
-                    # so truncate it manually.
-                    # Related discussion: https://github.com/huggingface/transformers/issues/26273
-                    if stop_token_ids[0] == 29871:
-                        stop_token_ids = stop_token_ids[1:]
-                    # TODO: Currently, the staging engine can only handle single-token stop strings.
-                    # Extend it to multi-token stops.
-                    assert len(stop_token_ids) == 1
-                    list_stop_token_ids.append(stop_token_ids)
-                req.stopping_criteria.list_stop_token_ids = list_stop_token_ids
+            # wrap the stop sequence in a list if necessary
+            if req.stopping_criteria.stop_sequences:
+                if isinstance(req.stopping_criteria.stop_sequences, str):
+                    req.stopping_criteria.stop_sequences = [req.stopping_criteria.stop_sequences]
+                assert isinstance(req.stopping_criteria.stop_sequences, list)

            # If the request violates tokenization constraints, this returns None, so skip it.
            state = self._get_new_request_state(req)
            new_request_states.append(state)
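The deleted branch above had to encode each stop string into token ids and strip a leading SPIECE_UNDERLINE token (id 29871 for Llama-style SentencePiece tokenizers, per the linked transformers issue). A hedged illustration of that quirk, assuming a Llama-2-style Hugging Face tokenizer is available; the model name is only an example:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
ids = tok.encode("\n", add_special_tokens=False)
print(ids)  # typically [29871, 13]: a spurious SPIECE_UNDERLINE precedes the newline token

Because a one-character stop string does not reliably map to a single token id, this commit drops the token-id path and matches stop strings on the decoded text instead.
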
@@ -186,12 +177,14 @@ def step(self) -> InferenceStepResult:
            delta = self._decode_last_output(state)
            state.output_text += delta


-            #state.output_text, delta, state.is_ended = check_stopping_sequences(state.stopping_criteria,
-            #                                                                     state.output_text,
-            #                                                                     delta,
-            #                                                                     state.is_ended)

+            state.output_text, delta, state.is_ended = check_stopping_sequences(state.stopping_criteria,
+                                                                                state.output_text,
+                                                                                delta,
+                                                                                state.is_ended)
+            # signal workers to stop generation
+            if state.is_ended:
+                self.cancel(state.request_id)

            outputs.append(
                RequestOutput(
                    request_id,
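check_stopping_sequences, imported from .base, now runs inside step() for both engines; its body is not shown in this diff. As a rough, hedged sketch only (names and exact semantics are assumptions, not the repository's implementation), a string-level check of this shape would truncate the accumulated text at the first stop match and mark the sequence as ended:

def check_stopping_sequences_sketch(stopping_criteria, output_text, delta, is_ended):
    # Illustrative only: look for any stop string in the accumulated output.
    if stopping_criteria.stop_sequences:
        for stop in stopping_criteria.stop_sequences:
            pos = output_text.find(stop)
            if pos != -1:
                # Keep the stop string itself, drop anything generated after it.
                output_text = output_text[: pos + len(stop)]
                delta = ""
                is_ended = True
                break
    return output_text, delta, is_ended
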
24 changes: 2 additions & 22 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -11,7 +11,7 @@

from .base import FinishReason, RequestId, RequestState, check_stopping_sequences
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId
-import structlog


#logging.basicConfig(filename='example.log', filemode='w')
logger = logging.getLogger(__name__)
@@ -193,27 +193,7 @@ def step(self) -> GenerationLoopWorkerOutput:
                    new_tokens = new_tokens[:i+1]
                    state.is_ended = True
                    break

-                list_stop_token_ids = state.stopping_criteria.list_stop_token_ids
-                if list_stop_token_ids:
-                    for stop_token_ids in list_stop_token_ids:
-                        num = len(stop_token_ids)
-                        # TODO: currently it is tricky to observe multiple generated tokens within the worker.
-                        assert num == 1
-
-                        # TODO(@team): any better way?
-                        found = (len(new_tokens[i:i+num]) == num) and all([n1 == n2 for n1, n2 in zip(new_tokens[i:i+num], stop_token_ids)])
-                        if found:
-                            new_tokens = new_tokens[:i+num]
-                            state.is_ended = True
-                            break
-
-                if state.is_ended:
-                    break
-                #if token_id == state.stopping_criteria.list_stop_token_ids[0][-1]:
-                #    new_tokens = new_tokens[:i+1]
-                #    state.is_ended = True
-                #    break

            state.token_ids.extend(new_tokens)
            outputs.append(
                SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens)
7 changes: 7 additions & 0 deletions serve/mlc_serve/engine/sync_engine.py
@@ -66,6 +66,13 @@ def add(self, requests: list[Request]):
            # TODO: verify that request id is unique
            if req.num_sequences > 1:
                raise RuntimeError("num_sequences > 1 is not supported for now")

+            # wrap the stop sequence in a list if necessary
+            if req.stopping_criteria.stop_sequences:
+                if isinstance(req.stopping_criteria.stop_sequences, str):
+                    req.stopping_criteria.stop_sequences = [req.stopping_criteria.stop_sequences]
+                assert isinstance(req.stopping_criteria.stop_sequences, list)

            state = self._get_new_request_state(req)
            new_request_states.append(state)

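With the same normalization in both engines, callers can pass stop_sequences as either a single string or a list. A brief, hedged usage sketch of the add/step loop, mirroring the test file added below (the engine is assumed to be constructed as in its create_engine helper):

from mlc_serve.engine import ChatMessage, Request, SamplingParams, StoppingCriteria

def run_until_done(engine, prompt: str) -> str:
    # A single stop string is accepted; the engine wraps it into a list internally.
    engine.add([
        Request(
            request_id="0",
            messages=[ChatMessage(role="user", content=prompt)],
            sampling_params=SamplingParams(temperature=0.0),
            stopping_criteria=StoppingCriteria(max_tokens=300, stop_sequences="\n\n"),
        )
    ])
    text = ""
    while engine.has_pending_requests():
        for res in engine.step().outputs:
            seq = res.sequences[0]
            if not seq.is_finished:
                text += seq.delta
    return text
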
197 changes: 197 additions & 0 deletions serve/tests/unittest/test_engine_with_samplers.py
@@ -0,0 +1,197 @@
import torch

import argparse
import json
import random
import os

from mlc_llm import utils
from mlc_serve.engine import (
    Request,
    ChatMessage,
    DebugOptions,
    SamplingParams,
    StoppingCriteria,
    FinishReason,
    get_engine_config
)
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

def create_engine(
    model_artifact_path,
    use_staging_engine,
    max_num_batched_tokens,
    max_input_len,
):
    engine_config = get_engine_config({
        "use_staging_engine": use_staging_engine,
        "max_num_batched_tokens": max_num_batched_tokens,
        "max_input_len": max_input_len,
        # Use defaults for "min_decode_steps", "max_decode_steps", "prompt_allocate_ratio"
    })

    if use_staging_engine:
        engine = StagingInferenceEngine(
            tokenizer_module=HfTokenizerModule(model_artifact_path),
            model_module_loader=PagedCacheModelModule,
            model_module_loader_kwargs={
                "model_artifact_path": model_artifact_path,
                "engine_config": engine_config,
            },
        )
        engine.start()
    else:
        engine = SynchronousInferenceEngine(
            PagedCacheModelModule(
                model_artifact_path=model_artifact_path,
                engine_config=engine_config,
            ))
    return engine

def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
    return Request(
        request_id=str(idx),
        messages=[ChatMessage(role="user", content=prompt)],
        sampling_params=SamplingParams(
            temperature=temp,
        ),
        stopping_criteria=StoppingCriteria(
            max_tokens=max_tokens,
            stop_sequences=stop,
        ),
        debug_options=DebugOptions(ignore_eos=ignore_eos),
    )

def test_max_tokens(
    model_artifact_path,
    use_staging_engine,
    max_num_batched_tokens=2560,
    max_input_len=2560,
    num_requests=5,
    ignore_eos=False,
):
    prompt = "Write a merge sort program in Python."
    engine = create_engine(
        model_artifact_path,
        use_staging_engine,
        max_num_batched_tokens,
        max_input_len,
    )

    requests = [create_request(idx=str(n-1), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=ignore_eos) for n in range(1, num_requests)]
    engine.add(requests)

    generated = ["" for _ in range(num_requests)]

    while engine.has_pending_requests():
        results = engine.step()
        for res in results.outputs:
            assert len(res.sequences) == 1
            seq = res.sequences[0]

            if seq.is_finished:
                assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
                assert seq.finish_reason == FinishReason.Length
            else:
                generated[int(res.request_id)] += seq.delta

    if use_staging_engine:
        engine.stop()


def test_ignore_eos(
    model_artifact_path,
    use_staging_engine,
    max_num_batched_tokens=2560,
    max_input_len=2560,
    num_requests=5,
):
    prompt = "hi"
    engine = create_engine(
        model_artifact_path,
        use_staging_engine,
        max_num_batched_tokens,
        max_input_len,
    )
    s = 113
    requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True) for n in range(s, s+num_requests)]
    engine.add(requests)

    generated = ["" for _ in range(num_requests)]

    while engine.has_pending_requests():
        results = engine.step()
        for res in results.outputs:
            assert len(res.sequences) == 1
            seq = res.sequences[0]

            if seq.is_finished:
                assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
                assert seq.finish_reason == FinishReason.Length
            else:
                generated[int(res.request_id)] += seq.delta

    if use_staging_engine:
        engine.stop()


def test_stop(
    model_artifact_path,
    use_staging_engine,
    max_num_batched_tokens=2560,
    max_input_len=2560,
    num_requests=5,
):
    prompt = "Write a merge sort program in Python."
    engine = create_engine(
        model_artifact_path,
        use_staging_engine,
        max_num_batched_tokens,
        max_input_len,
    )
    ignore_eos = False
    requests = []
    for n, stop in enumerate(["\n", ["\n"], "\n\n", "!", ["n", "!"]]):
        requests.append(create_request(idx=str(n), prompt=prompt, temp=0, max_tokens=300, stop=stop, ignore_eos=ignore_eos))
    engine.add(requests)

    generated = ["" for _ in range(num_requests)]

    while engine.has_pending_requests():
        results = engine.step()
        for res in results.outputs:
            assert len(res.sequences) == 1
            seq = res.sequences[0]
            req_id = int(res.request_id)
            if seq.is_finished:
                #assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}"
                assert not seq.delta
                gen_txt = generated[req_id]
                print(f"request id {req_id} : {gen_txt!r}")
                # The stop string should appear exactly once in the generated text.
                found = sum([gen_txt.count(str_stop) for str_stop in requests[req_id].stopping_criteria.stop_sequences])
                assert found == 1, f"{gen_txt!r}, matches: {found}"
            else:
                generated[int(res.request_id)] += seq.delta

    if use_staging_engine:
        engine.stop()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="../../../dist")
args = parser.parse_args()
model_artifact_path = os.path.join(args.artifact_path, args.local_id)

#test_max_tokens(model_artifact_path, use_staging_engine=True)
#test_max_tokens(model_artifact_path, use_staging_engine=False)
#test_ignore_eos(model_artifact_path, use_staging_engine=True)
#test_ignore_eos(model_artifact_path, use_staging_engine=False)

#test_stop(model_artifact_path, use_staging_engine=False)
test_stop(model_artifact_path, use_staging_engine=True)
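Assuming the module is run from its own directory (so the default ../../../dist artifact path resolves), a typical invocation might look like the following; the local id is only a placeholder for a compiled model directory name:

python3 test_engine_with_samplers.py --local-id <compiled-model-local-id>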
