Skip to content

Commit

Permalink
use RequestType
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 2, 2024
1 parent bf719a5 commit 4a6156c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ConversationTemplate,
KVCacheManager,
ModelModule,
RequestType,
TextGenerator,
Tokenizer as TokenizerP,
)
Expand Down Expand Up @@ -226,8 +227,8 @@ def update_sequence(

def get_requests_to_process(
current_states: list[RequestState], cache_manager: KVCacheManager
) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
requests: list[Union[PrefillRequest, DecodeRequest]] = []
) -> Tuple[list[RequestType], bool, int]:
requests: list[RequestType] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
Expand Down
3 changes: 2 additions & 1 deletion serve/mlc_serve/model/dummy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DecodeRequest,
KVCache,
PrefillRequest,
RequestType,
SequenceId,
TextGenerationResult,
)
Expand Down Expand Up @@ -97,7 +98,7 @@ def get_max_new_tokens(self) -> int:
class DummyTextGenerator:
def generate(
self,
requests: list[Union[PrefillRequest, DecodeRequest]],
requests: list[RequestType],
kv_cache: DummyCache,
) -> list[TextGenerationResult]:
result = []
Expand Down
3 changes: 2 additions & 1 deletion serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DecodeRequest,
ModelModule,
PrefillRequest,
RequestType,
TextGenerationResult,
TextGenerator,
)
Expand All @@ -24,7 +25,7 @@ def __init__(self, model: TextGenerator):
self.model = model

def generate(
self, requests: List[Union[PrefillRequest, DecodeRequest]], kv_cache
self, requests: List[RequestType], kv_cache
) -> List[TextGenerationResult]:
prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)]
decode_requests = [r for r in requests if isinstance(r, DecodeRequest)]
Expand Down

0 comments on commit 4a6156c

Please sign in to comment.