From 4a951cd267960aed2254728c7c7e6589e1401660 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 22 Nov 2023 01:55:50 +0000 Subject: [PATCH 1/4] wip --- serve/mlc_serve/engine/staging_engine.py | 10 ++- .../mlc_serve/engine/staging_engine_worker.py | 61 ++++++++++++++----- 2 files changed, 55 insertions(+), 16 deletions(-) diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index 757ed0bf38..be8113ec57 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -21,6 +21,7 @@ from .staging_engine_worker import ( AddRequestsCommand, CancelRequestCommand, + StopRequestCommand, ShutdownCommand, run_generation_loop_worker, ) @@ -107,6 +108,11 @@ def cancel(self, request_id: RequestId): raise RuntimeError("GenerationLoopWorker process is not running") self.command_queue.put(CancelRequestCommand(request_id)) + def stop_request(self, request_id: RequestId): + if not self._is_ready_to_serve(): + raise RuntimeError("GenerationLoopWorker process is not running") + self.command_queue.put(StopRequestCommand(request_id)) + def has_pending_requests(self) -> bool: with self.requests_lock: return len(self.requests) > 0 @@ -177,9 +183,9 @@ def step(self) -> InferenceStepResult: state.output_text, delta, state.is_ended) - # signal workers to stop generation + # signal workers to stop generation if state.is_ended: - self.cancel(state.request_id) + self.stop_request(state.request_id) outputs.append( RequestOutput( diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index ba81bebedc..60a9fcdeb2 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -30,6 +30,11 @@ class CancelRequestCommand: request_id: RequestId +@dataclass +class StopRequestCommand: + request_id: RequestId + + GenerationLoopWorkerCommand = Union[ ShutdownCommand, AddRequestsCommand, CancelRequestCommand ] @@ -72,6 +77,7 @@ def __init__( self.has_new_requests = Condition(lock=self.queue_lock) self.cancelled_requests = list[RequestState]() + self.stopped_requests = list[RequestState]() self.current_batch = dict[RequestId, RequestState]() @@ -89,21 +95,32 @@ def add(self, request_states: list[RequestState]): self.queue.extend(valid_states) self.has_new_requests.notify_all() + def _get_request_state(self, request_id: RequestId) -> Optional[RequestState]: + for i, state in enumerate(self.queue): + if state.request_id == request_id: + return state + + return None + def cancel(self, request_id: RequestId): with self.queue_lock: - queue_index_to_delete = None - for i, state in enumerate(self.queue): - if state.request_id == request_id: - queue_index_to_delete = i - self.cancelled_requests.append(state) - break - - if queue_index_to_delete is not None: - del self.queue[queue_index_to_delete] + state = self._get_request_state(request_id) + if state: + del state if request_id in self.current_batch: self.cancelled_requests.append(self.current_batch[request_id]) + def stop(self, request_id: RequestId): + print("stop ", request_id) + with self.queue_lock: + state = self._get_request_state(request_id) + if state: + del state + + if request_id in self.current_batch: + self.stopped_requests.append(self.current_batch[request_id]) + def wait_for_request(self, timeout_seconds=None) -> bool: with self.queue_lock: self.has_new_requests.wait_for( @@ -138,6 +155,20 @@ def step(self) -> GenerationLoopWorkerOutput: ) 
self._remove_request_from_batch(state.request_id) + for state in self.stopped_requests: + print("stopping",state.request_id) + outputs.append( + SequenceGenerationOutput( + # TODO: support multi-sequence + id=SequenceId(state.request_id, 0), + new_tokens=[], + finish_reason=FinishReason.Stop, + error=None, + ) + ) + if state.request_id in self.current_batch: + self._remove_request_from_batch(state.request_id) + for state in self.cancelled_requests: err = None if state.validation_err: @@ -149,7 +180,7 @@ def step(self) -> GenerationLoopWorkerOutput: id=SequenceId(state.request_id, 0), new_tokens=[], finish_reason=FinishReason.Cancelled, - error = err + error=err, ) ) if state.request_id in self.current_batch: @@ -307,13 +338,13 @@ def _has_request_to_process(self) -> bool: return self.queue or self.current_batch def _should_stop_by_length(self, state: RequestState) -> bool: - # TODO: currently, we simply return true for both stopping reasons. - # in the future, we can differentiate these two. + # TODO: currently, we simply return true for both stopping reasons. + # in the future, we can differentiate these two. # this include prompt tokens and gen tokens so far - num_context_tokens = len(state.token_ids) + num_context_tokens = len(state.token_ids) if num_context_tokens >= self.model_artifact_config.max_context_length: return True - num_gen_tokens = num_context_tokens - state.prompt_len + num_gen_tokens = num_context_tokens - state.prompt_len if num_gen_tokens >= state.stopping_criteria.max_tokens: return True return False @@ -367,6 +398,8 @@ def handle_command(): worker.add(cmd.request_states) elif isinstance(cmd, CancelRequestCommand): worker.cancel(cmd.request_id) + elif isinstance(cmd, StopRequestCommand): + worker.stop(cmd.request_id) else: logger.error("Unknown command type %s", type(cmd)) break From 66aa313b7203a3e544af6227c8c16c5a9d9a1578 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 22 Nov 2023 02:03:22 +0000 Subject: [PATCH 2/4] clean --- .../mlc_serve/engine/staging_engine_worker.py | 26 ++++----- .../unittest/test_engine_with_samplers.py | 56 +++++++++---------- 2 files changed, 38 insertions(+), 44 deletions(-) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index 60a9fcdeb2..be1eb53fb4 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -96,30 +96,26 @@ def add(self, request_states: list[RequestState]): self.has_new_requests.notify_all() def _get_request_state(self, request_id: RequestId) -> Optional[RequestState]: - for i, state in enumerate(self.queue): + for state in self.queue: if state.request_id == request_id: return state return None - def cancel(self, request_id: RequestId): + def _cacnel_or_stop_request(self, request_id: RequestId, requests: list[RequestState]): with self.queue_lock: state = self._get_request_state(request_id) if state: del state if request_id in self.current_batch: - self.cancelled_requests.append(self.current_batch[request_id]) + requests.append(self.current_batch[request_id]) - def stop(self, request_id: RequestId): - print("stop ", request_id) - with self.queue_lock: - state = self._get_request_state(request_id) - if state: - del state + def cancel_request(self, request_id: RequestId): + self._cacnel_or_stop_request(request_id, self.cancelled_requests) - if request_id in self.current_batch: - self.stopped_requests.append(self.current_batch[request_id]) + def stop_request(self, request_id: RequestId): + 
self._cacnel_or_stop_request(request_id, self.stopped_requests) def wait_for_request(self, timeout_seconds=None) -> bool: with self.queue_lock: @@ -156,19 +152,19 @@ def step(self) -> GenerationLoopWorkerOutput: self._remove_request_from_batch(state.request_id) for state in self.stopped_requests: - print("stopping",state.request_id) outputs.append( SequenceGenerationOutput( # TODO: support multi-sequence id=SequenceId(state.request_id, 0), new_tokens=[], finish_reason=FinishReason.Stop, - error=None, ) ) if state.request_id in self.current_batch: self._remove_request_from_batch(state.request_id) + self.stopped_requests.clear() + for state in self.cancelled_requests: err = None if state.validation_err: @@ -397,9 +393,9 @@ def handle_command(): elif isinstance(cmd, AddRequestsCommand): worker.add(cmd.request_states) elif isinstance(cmd, CancelRequestCommand): - worker.cancel(cmd.request_id) + worker.cancel_request(cmd.request_id) elif isinstance(cmd, StopRequestCommand): - worker.stop(cmd.request_id) + worker.stop_request(cmd.request_id) else: logger.error("Unknown command type %s", type(cmd)) break diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 296633dab5..40ba9b7e80 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -20,16 +20,16 @@ from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule def create_engine( - model_artifact_path, - use_staging_engine, - max_num_sequences, + model_artifact_path, + use_staging_engine, + max_num_sequences, max_input_len, - + ): engine_config = get_engine_config({ "use_staging_engine": use_staging_engine, - "max_num_sequences": max_num_sequences, - "max_input_len": max_input_len, + "max_num_sequences": max_num_sequences, + "max_input_len": max_input_len, # Use defaults for "min_decode_steps", "max_decode_steps", "prompt_allocate_ratio" }) @@ -57,27 +57,27 @@ def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos): messages = [ChatMessage(role="user", content=prompt)], sampling_params = SamplingParams( temperature=0.0, - ), + ), stopping_criteria = StoppingCriteria( - max_tokens=max_tokens, + max_tokens=max_tokens, stop_sequences=stop - ), + ), debug_options = DebugOptions(ignore_eos = ignore_eos) ) def test_max_tokens( - model_artifact_path, - use_staging_engine, - max_num_sequences=4, + model_artifact_path, + use_staging_engine, + max_num_sequences=4, max_input_len=512, num_requests=5, ignore_eos=False ): prompt = "Write a merge sort program in Python." 
engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_sequences, + model_artifact_path, + use_staging_engine, + max_num_sequences, max_input_len, ) @@ -91,7 +91,7 @@ def test_max_tokens( for res in results.outputs: assert len(res.sequences) == 1 seq = res.sequences[0] - + if seq.is_finished: assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens assert seq.finish_reason == FinishReason.Length @@ -103,17 +103,17 @@ def test_max_tokens( def test_ignore_eos( - model_artifact_path, - use_staging_engine, - max_num_sequences=4, + model_artifact_path, + use_staging_engine, + max_num_sequences=4, max_input_len=512, num_requests=5, ): prompt = "hi" engine = create_engine( - model_artifact_path, - use_staging_engine, - max_num_sequences, + model_artifact_path, + use_staging_engine, + max_num_sequences, max_input_len, ) s = 113 @@ -141,7 +141,7 @@ def test_ignore_eos( def test_stop( model_artifact_path, use_staging_engine, - max_num_sequences=4, + max_num_sequences=4, max_input_len=512, num_requests=5, ): @@ -167,12 +167,10 @@ def test_stop( seq = res.sequences[0] req_id = int(res.request_id) if seq.is_finished: - # TODO: Currently staging engine returns FinishReason.Cancelled. - # This needs to be fixed. - #assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}" + assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}" assert not seq.delta gen_txt = generated[req_id] - + # stop token should appear only once in the gen text. found = sum([gen_txt.count(str_stop) for str_stop in requests[req_id].stopping_criteria.stop_sequences]) assert found == 1, f"{gen_txt!r}, matches: {found}" @@ -186,10 +184,10 @@ def test_stop( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--local-id", type=str, required=True) - parser.add_argument("--artifact-path", type=str, default="../../../dist") + parser.add_argument("--artifact-path", type=str, default="dist") args = parser.parse_args() model_artifact_path = os.path.join(args.artifact_path, args.local_id) - + test_max_tokens(model_artifact_path, use_staging_engine=True) test_max_tokens(model_artifact_path, use_staging_engine=False) test_ignore_eos(model_artifact_path, use_staging_engine=True) From 45cd4a5cea92a4da5ae3b939699a4321b3c5365d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 22 Nov 2023 02:04:05 +0000 Subject: [PATCH 3/4] black --- serve/mlc_serve/engine/staging_engine.py | 14 ++++++++------ serve/mlc_serve/engine/staging_engine_worker.py | 17 +++++++++++++---- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py index be8113ec57..7be1c6ba24 100644 --- a/serve/mlc_serve/engine/staging_engine.py +++ b/serve/mlc_serve/engine/staging_engine.py @@ -15,7 +15,7 @@ RequestState, ScopedInferenceEngine, SequenceOutput, - check_stopping_sequences + check_stopping_sequences, ) from .model_module import ModelModule, TokenizerModule from .staging_engine_worker import ( @@ -25,6 +25,7 @@ ShutdownCommand, run_generation_loop_worker, ) + logger = logging.getLogger(__name__) @@ -91,7 +92,9 @@ def add(self, requests: list[Request]): # wrap the stop sequence with list if necessary if req.stopping_criteria.stop_sequences: if isinstance(req.stopping_criteria.stop_sequences, str): - req.stopping_criteria.stop_sequences = [req.stopping_criteria.stop_sequences] + req.stopping_criteria.stop_sequences = [ + req.stopping_criteria.stop_sequences + ] 
assert isinstance(req.stopping_criteria.stop_sequences, list) # If the request violates the tokenization, this returns None, so skip. @@ -179,10 +182,9 @@ def step(self) -> InferenceStepResult: delta = self._decode_last_output(state) state.output_text += delta - state.output_text, delta, state.is_ended = check_stopping_sequences(state.stopping_criteria, - state.output_text, - delta, - state.is_ended) + state.output_text, delta, state.is_ended = check_stopping_sequences( + state.stopping_criteria, state.output_text, delta, state.is_ended + ) # signal workers to stop generation if state.is_ended: self.stop_request(state.request_id) diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index be1eb53fb4..d912e58610 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) + @dataclass class ShutdownCommand: pass @@ -66,9 +67,12 @@ def __init__( self.max_context_length = self.model_artifact_config.max_context_length self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens self.max_decode_steps = min( - self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps + self.cache_manager.get_kv_cache_size(), + model_module.engine_config.max_decode_steps, + ) + self.min_decode_steps = min( + self.max_decode_steps - 1, model_module.engine_config.min_decode_steps ) - self.min_decode_steps = min(self.max_decode_steps - 1, model_module.engine_config.min_decode_steps) self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio assert self.prompt_allocate_ratio >= 1.0 @@ -87,7 +91,10 @@ def add(self, request_states: list[RequestState]): # cancel them instead. 
valid_states = [] for request_state in request_states: - if request_state.validation_err is not None or request_state.prompt_len >= self.max_context_length: + if ( + request_state.validation_err is not None + or request_state.prompt_len >= self.max_context_length + ): self.cancelled_requests.append(request_state) else: valid_states.append(request_state) @@ -102,7 +109,9 @@ def _get_request_state(self, request_id: RequestId) -> Optional[RequestState]: return None - def _cacnel_or_stop_request(self, request_id: RequestId, requests: list[RequestState]): + def _cacnel_or_stop_request( + self, request_id: RequestId, requests: list[RequestState] + ): with self.queue_lock: state = self._get_request_state(request_id) if state: From 4e0fb8f4c35235d6b8c6668c97e5b1d5ceea31ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 22 Nov 2023 02:08:49 +0000 Subject: [PATCH 4/4] minor --- serve/tests/test_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 9d1782796b..9ec2024768 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -34,7 +34,7 @@ def test(args: argparse.Namespace): engine_config = get_engine_config({ "use_staging_engine": args.use_staging_engine, "max_num_sequences": args.max_num_sequences, - "max_input_len": args.max_input_len, + "max_input_len": args.max_input_len, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, "prompt_allocate_ratio": args.prompt_allocate_ratio @@ -119,7 +119,6 @@ def test(args: argparse.Namespace): parser = argparse.ArgumentParser() parser.add_argument("--local-id", type=str, required=True) parser.add_argument("--artifact-path", type=str, default="dist") - parser.add_argument("--num-shards", type=int, default=1) parser.add_argument("--max-input-len", type=int, default=512) parser.add_argument("--max-num-sequences", type=int, default=8) parser.add_argument("--max-output-len", type=int, default=20)
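
Note: the series distinguishes an engine-initiated stop (a stop sequence was detected) from a user cancellation, routing each through its own command (StopRequestCommand vs. CancelRequestCommand) and finish reason (FinishReason.Stop vs. FinishReason.Cancelled). The following is a minimal, self-contained sketch of that control flow using simplified stand-ins; the real classes live in serve/mlc_serve/engine/staging_engine_worker.py, and everything below is illustrative, not the actual implementation.

    # Illustrative stand-ins only -- mirrors the command routing added by the patches above.
    from dataclasses import dataclass, field
    from enum import Enum, auto


    class FinishReason(Enum):
        Stop = auto()
        Cancelled = auto()


    @dataclass
    class StopRequestCommand:
        request_id: str


    @dataclass
    class CancelRequestCommand:
        request_id: str


    @dataclass
    class MockWorker:
        current_batch: dict = field(default_factory=dict)
        stopped_requests: list = field(default_factory=list)
        cancelled_requests: list = field(default_factory=list)

        def handle(self, cmd):
            # Mirrors handle_command(): a StopRequestCommand routes the in-flight
            # request state to stopped_requests, a CancelRequestCommand to
            # cancelled_requests.
            target = (
                self.stopped_requests
                if isinstance(cmd, StopRequestCommand)
                else self.cancelled_requests
            )
            if cmd.request_id in self.current_batch:
                target.append(self.current_batch.pop(cmd.request_id))

        def step(self):
            # Mirrors step(): stopped requests finish with FinishReason.Stop,
            # cancelled ones with FinishReason.Cancelled, then both lists are cleared.
            outputs = [(rid, FinishReason.Stop) for rid in self.stopped_requests]
            outputs += [(rid, FinishReason.Cancelled) for rid in self.cancelled_requests]
            self.stopped_requests.clear()
            self.cancelled_requests.clear()
            return outputs


    worker = MockWorker(current_batch={"req-0": "state-0", "req-1": "state-1"})
    worker.handle(StopRequestCommand("req-0"))    # engine hit a stop sequence
    worker.handle(CancelRequestCommand("req-1"))  # user cancelled the request
    print(worker.step())  # req-0 finishes with Stop, req-1 with Cancelled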