Stop token follow-up for staging engine. Properly report stop reason #78

Merged 4 commits on Nov 22, 2023
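In short, when the engine detects a stop sequence it now sends the worker a dedicated StopRequestCommand (instead of reusing the cancel path), and the worker reports FinishReason.Stop for that request. The following is a minimal, self-contained sketch of that command-dispatch pattern; it borrows the names StopRequestCommand, CancelRequestCommand, and FinishReason from the diff below, but the helper finish_reason_for and the rest of the code are illustrative only, not the actual mlc_serve implementation.

from dataclasses import dataclass
from enum import Enum
from queue import Queue
from typing import Union

RequestId = str


class FinishReason(Enum):
    Stop = "stop"
    Cancelled = "cancelled"
    Length = "length"


@dataclass
class CancelRequestCommand:
    request_id: RequestId


@dataclass
class StopRequestCommand:
    request_id: RequestId


GenerationLoopWorkerCommand = Union[CancelRequestCommand, StopRequestCommand]


def finish_reason_for(cmd: GenerationLoopWorkerCommand) -> FinishReason:
    # The worker maps the command type to the finish reason it reports
    # for the request's sequence output.
    if isinstance(cmd, StopRequestCommand):
        return FinishReason.Stop        # stop sequence hit: report Stop
    if isinstance(cmd, CancelRequestCommand):
        return FinishReason.Cancelled   # user cancellation: report Cancelled
    raise TypeError(f"Unknown command type {type(cmd)}")


if __name__ == "__main__":
    command_queue: Queue = Queue()
    # Engine side: a stop sequence was detected for request "42", so it now
    # enqueues a stop command rather than a cancel command.
    command_queue.put(StopRequestCommand(request_id="42"))
    # Worker side: drain the queue and report the corresponding finish reason.
    print(finish_reason_for(command_queue.get()))  # FinishReason.Stop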
serve/mlc_serve/engine/staging_engine.py (16 additions, 8 deletions)
@@ -15,15 +15,17 @@
RequestState,
ScopedInferenceEngine,
SequenceOutput,
check_stopping_sequences
check_stopping_sequences,
)
from .model_module import ModelModule, TokenizerModule
from .staging_engine_worker import (
AddRequestsCommand,
CancelRequestCommand,
StopRequestCommand,
ShutdownCommand,
run_generation_loop_worker,
)

logger = logging.getLogger(__name__)


@@ -90,7 +92,9 @@ def add(self, requests: list[Request]):
# wrap the stop sequence with list if necessary
if req.stopping_criteria.stop_sequences:
if isinstance(req.stopping_criteria.stop_sequences, str):
req.stopping_criteria.stop_sequences = [req.stopping_criteria.stop_sequences]
req.stopping_criteria.stop_sequences = [
req.stopping_criteria.stop_sequences
]
assert isinstance(req.stopping_criteria.stop_sequences, list)

# If the request violates the tokenization, this returns None, so skip.
@@ -107,6 +111,11 @@ def cancel(self, request_id: RequestId):
raise RuntimeError("GenerationLoopWorker process is not running")
self.command_queue.put(CancelRequestCommand(request_id))

def stop_request(self, request_id: RequestId):
if not self._is_ready_to_serve():
raise RuntimeError("GenerationLoopWorker process is not running")
self.command_queue.put(StopRequestCommand(request_id))

def has_pending_requests(self) -> bool:
with self.requests_lock:
return len(self.requests) > 0
@@ -173,13 +182,12 @@ def step(self) -> InferenceStepResult:
delta = self._decode_last_output(state)
state.output_text += delta

state.output_text, delta, state.is_ended = check_stopping_sequences(state.stopping_criteria,
state.output_text,
delta,
state.is_ended)
# signal workers to stop generation
state.output_text, delta, state.is_ended = check_stopping_sequences(
state.stopping_criteria, state.output_text, delta, state.is_ended
)
# signal workers to stop generation
if state.is_ended:
self.cancel(state.request_id)
self.stop_request(state.request_id)

outputs.append(
RequestOutput(
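The diff only shows the call site of check_stopping_sequences. As a rough, hedged illustration of what such a helper plausibly does (truncate the accumulated text right after the first matched stop sequence and mark the request as ended, so the stop string appears exactly once, as the updated test_stop below expects), here is a self-contained sketch; the real mlc_serve helper may differ, and the StoppingCriteria stub and the _sketch suffix exist only for this example.

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class StoppingCriteria:
    max_tokens: int
    stop_sequences: Optional[list[str]] = None


def check_stopping_sequences_sketch(
    stopping_criteria: StoppingCriteria,
    output_text: str,
    delta: str,
    is_ended: bool,
) -> Tuple[str, str, bool]:
    # Already finished or no stop sequences configured: return inputs unchanged.
    if is_ended or not stopping_criteria.stop_sequences:
        return output_text, delta, is_ended
    for stop in stopping_criteria.stop_sequences:
        pos = output_text.find(stop)
        if pos == -1:
            continue
        # Keep the text up to and including the first stop sequence, so the
        # stop string appears exactly once in the final output.
        cut = pos + len(stop)
        removed = len(output_text) - cut
        output_text = output_text[:cut]
        # Trim the same number of trailing characters from the latest delta
        # so the streamed text stays consistent with the truncated total.
        delta = delta[: max(len(delta) - removed, 0)]
        return output_text, delta, True
    return output_text, delta, False


# Example: the stop sequence "\n\n" ends generation mid-delta.
text, delta, ended = check_stopping_sequences_sketch(
    StoppingCriteria(max_tokens=64, stop_sequences=["\n\n"]),
    "def merge_sort(xs):\n\n# trailing junk",
    "\n\n# trailing junk",
    False,
)
assert ended and text.endswith("\n\n")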
serve/mlc_serve/engine/staging_engine_worker.py (58 additions, 20 deletions)
@@ -15,6 +15,7 @@

logger = logging.getLogger(__name__)


@dataclass
class ShutdownCommand:
pass
@@ -30,6 +31,11 @@ class CancelRequestCommand:
request_id: RequestId


@dataclass
class StopRequestCommand:
request_id: RequestId


GenerationLoopWorkerCommand = Union[
ShutdownCommand, AddRequestsCommand, CancelRequestCommand
]
@@ -61,9 +67,12 @@ def __init__(
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
)
self.min_decode_steps = min(
self.max_decode_steps - 1, model_module.engine_config.min_decode_steps
)
self.min_decode_steps = min(self.max_decode_steps - 1, model_module.engine_config.min_decode_steps)
self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio
assert self.prompt_allocate_ratio >= 1.0

@@ -72,6 +81,7 @@ def __init__(
self.has_new_requests = Condition(lock=self.queue_lock)

self.cancelled_requests = list[RequestState]()
self.stopped_requests = list[RequestState]()

self.current_batch = dict[RequestId, RequestState]()

@@ -81,28 +91,40 @@ def add(self, request_states: list[RequestState]):
# cancel them instead.
valid_states = []
for request_state in request_states:
if request_state.validation_err is not None or request_state.prompt_len >= self.max_context_length:
if (
request_state.validation_err is not None
or request_state.prompt_len >= self.max_context_length
):
self.cancelled_requests.append(request_state)
else:
valid_states.append(request_state)

self.queue.extend(valid_states)
self.has_new_requests.notify_all()

def cancel(self, request_id: RequestId):
with self.queue_lock:
queue_index_to_delete = None
for i, state in enumerate(self.queue):
if state.request_id == request_id:
queue_index_to_delete = i
self.cancelled_requests.append(state)
break
def _get_request_state(self, request_id: RequestId) -> Optional[RequestState]:
for state in self.queue:
if state.request_id == request_id:
return state

if queue_index_to_delete is not None:
del self.queue[queue_index_to_delete]
return None

def _cacnel_or_stop_request(
self, request_id: RequestId, requests: list[RequestState]
):
with self.queue_lock:
state = self._get_request_state(request_id)
if state:
del state
Review comment from @masahi (Member, Author), Nov 22, 2023:
Not sure if we need to requests.append(state) here like in the original code. The original code has two appends in cancel, but they seem to be appending the same state. cc @yelite

Follow-up from @masahi, Nov 22, 2023:
Never mind, I realized that self.queue and current_batch are disjoint. Sent a hot fix to #79.


if request_id in self.current_batch:
self.cancelled_requests.append(self.current_batch[request_id])
requests.append(self.current_batch[request_id])

def cancel_request(self, request_id: RequestId):
self._cacnel_or_stop_request(request_id, self.cancelled_requests)

def stop_request(self, request_id: RequestId):
self._cacnel_or_stop_request(request_id, self.stopped_requests)

def wait_for_request(self, timeout_seconds=None) -> bool:
with self.queue_lock:
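The follow-up in the review thread above hinges on self.queue (requests waiting to be scheduled) and current_batch (requests currently being decoded) being disjoint. A toy model of that invariant, with names chosen only for this sketch:

from collections import deque

pending_queue: deque[str] = deque(["req-0", "req-1"])  # not yet scheduled
current_batch: dict[str, str] = {}                     # in-flight requests


def schedule_one() -> None:
    # Scheduling moves (rather than copies) a request out of the queue, so
    # at any point a request id lives in at most one of the two collections.
    request_id = pending_queue.popleft()
    current_batch[request_id] = "running"


schedule_one()
assert not set(pending_queue) & set(current_batch)  # disjoint by construction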
@@ -138,6 +160,20 @@ def step(self) -> GenerationLoopWorkerOutput:
)
self._remove_request_from_batch(state.request_id)

for state in self.stopped_requests:
outputs.append(
SequenceGenerationOutput(
# TODO: support multi-sequence
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Stop,
)
)
if state.request_id in self.current_batch:
self._remove_request_from_batch(state.request_id)

self.stopped_requests.clear()

for state in self.cancelled_requests:
err = None
if state.validation_err:
@@ -149,7 +185,7 @@ def step(self) -> GenerationLoopWorkerOutput:
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error = err
error=err,
)
)
if state.request_id in self.current_batch:
@@ -307,13 +343,13 @@ def _has_request_to_process(self) -> bool:
return self.queue or self.current_batch

def _should_stop_by_length(self, state: RequestState) -> bool:
# TODO: currently, we simply return true for both stopping reasons.
# in the future, we can differentiate these two.
# TODO: currently, we simply return true for both stopping reasons.
# in the future, we can differentiate these two.
# this include prompt tokens and gen tokens so far
num_context_tokens = len(state.token_ids)
num_context_tokens = len(state.token_ids)
if num_context_tokens >= self.model_artifact_config.max_context_length:
return True
num_gen_tokens = num_context_tokens - state.prompt_len
num_gen_tokens = num_context_tokens - state.prompt_len
if num_gen_tokens >= state.stopping_criteria.max_tokens:
return True
return False
@@ -366,7 +402,9 @@ def handle_command():
elif isinstance(cmd, AddRequestsCommand):
worker.add(cmd.request_states)
elif isinstance(cmd, CancelRequestCommand):
worker.cancel(cmd.request_id)
worker.cancel_request(cmd.request_id)
elif isinstance(cmd, StopRequestCommand):
worker.stop_request(cmd.request_id)
else:
logger.error("Unknown command type %s", type(cmd))
break
3 changes: 1 addition & 2 deletions serve/tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ def test(args: argparse.Namespace):
engine_config = get_engine_config({
"use_staging_engine": args.use_staging_engine,
"max_num_sequences": args.max_num_sequences,
"max_input_len": args.max_input_len,
"max_input_len": args.max_input_len,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"prompt_allocate_ratio": args.prompt_allocate_ratio
@@ -119,7 +119,6 @@ def test(args: argparse.Namespace):
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="dist")
parser.add_argument("--num-shards", type=int, default=1)
parser.add_argument("--max-input-len", type=int, default=512)
parser.add_argument("--max-num-sequences", type=int, default=8)
parser.add_argument("--max-output-len", type=int, default=20)
serve/tests/unittest/test_engine_with_samplers.py (27 additions, 29 deletions)
@@ -20,16 +20,16 @@
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

def create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,

):
engine_config = get_engine_config({
"use_staging_engine": use_staging_engine,
"max_num_sequences": max_num_sequences,
"max_input_len": max_input_len,
"max_num_sequences": max_num_sequences,
"max_input_len": max_input_len,
# Use defaults for "min_decode_steps", "max_decode_steps", "prompt_allocate_ratio"
})

@@ -57,27 +57,27 @@ def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
messages = [ChatMessage(role="user", content=prompt)],
sampling_params = SamplingParams(
temperature=0.0,
),
),
stopping_criteria = StoppingCriteria(
max_tokens=max_tokens,
max_tokens=max_tokens,
stop_sequences=stop
),
),
debug_options = DebugOptions(ignore_eos = ignore_eos)
)

def test_max_tokens(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
ignore_eos=False
):
prompt = "Write a merge sort program in Python."
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,
)

@@ -91,7 +91,7 @@ def test_max_tokens(
for res in results.outputs:
assert len(res.sequences) == 1
seq = res.sequences[0]

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
assert seq.finish_reason == FinishReason.Length
@@ -103,17 +103,17 @@


def test_ignore_eos(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
):
prompt = "hi"
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,
)
s = 113
@@ -141,7 +141,7 @@ def test_ignore_eos(
def test_stop(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
):
@@ -167,12 +167,10 @@ def test_stop(
seq = res.sequences[0]
req_id = int(res.request_id)
if seq.is_finished:
# TODO: Currently staging engine returns FinishReason.Cancelled.
# This needs to be fixed.
#assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}"
assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}"
assert not seq.delta
gen_txt = generated[req_id]

# stop token should appear only once in the gen text.
found = sum([gen_txt.count(str_stop) for str_stop in requests[req_id].stopping_criteria.stop_sequences])
assert found == 1, f"{gen_txt!r}, matches: {found}"
@@ -186,10 +184,10 @@ def test_stop(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="../../../dist")
parser.add_argument("--artifact-path", type=str, default="dist")
args = parser.parse_args()
model_artifact_path = os.path.join(args.artifact_path, args.local_id)

test_max_tokens(model_artifact_path, use_staging_engine=True)
test_max_tokens(model_artifact_path, use_staging_engine=False)
test_ignore_eos(model_artifact_path, use_staging_engine=True)