diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 675e97e173..4d40899ab5 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -25,6 +25,7 @@
     ConversationTemplate,
     KVCacheManager,
     ModelModule,
+    RequestType,
     TextGenerator,
     Tokenizer as TokenizerP,
 )
@@ -226,8 +227,8 @@ def update_sequence(
 
 def get_requests_to_process(
     current_states: list[RequestState], cache_manager: KVCacheManager
-) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
-    requests: list[Union[PrefillRequest, DecodeRequest]] = []
+) -> Tuple[list[RequestType], bool, int]:
+    requests: list[RequestType] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
     # mixing prefill and decode.
     is_prompt_batch = any(not state.is_prefilled for state in current_states)
diff --git a/serve/mlc_serve/model/dummy_model.py b/serve/mlc_serve/model/dummy_model.py
index 630ed7c267..b8900273a6 100644
--- a/serve/mlc_serve/model/dummy_model.py
+++ b/serve/mlc_serve/model/dummy_model.py
@@ -11,6 +11,7 @@
     DecodeRequest,
     KVCache,
     PrefillRequest,
+    RequestType,
     SequenceId,
     TextGenerationResult,
 )
@@ -97,7 +98,7 @@ def get_max_new_tokens(self) -> int:
 class DummyTextGenerator:
     def generate(
         self,
-        requests: list[Union[PrefillRequest, DecodeRequest]],
+        requests: list[RequestType],
         kv_cache: DummyCache,
     ) -> list[TextGenerationResult]:
         result = []
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 7daf3336f4..6c45621dcb 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -12,6 +12,7 @@
     DecodeRequest,
     ModelModule,
     PrefillRequest,
+    RequestType,
     TextGenerationResult,
     TextGenerator,
 )
@@ -24,7 +25,7 @@ def __init__(self, model: TextGenerator):
         self.model = model
 
     def generate(
-        self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache
+        self, requests: List[RequestType], kv_cache
     ) -> List[TextGenerationResult]:
         prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)]
         decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]