diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 8cb503de06..4a02ce2f60 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -27,6 +27,7 @@
     ConversationTemplate,
     KVCacheManager,
     ModelModule,
+    RequestType,
    TextGenerator,
     Tokenizer as TokenizerP,
 )
@@ -228,10 +229,8 @@ def update_sequence(
 
 def get_requests_to_process(
     current_states: list[RequestState], cache_manager: KVCacheManager
-) -> Tuple[
-    list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
-]:
-    requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
+) -> Tuple[list[RequestType], bool, int]:
+    requests: list[RequestType] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
     # mixing prefill and decode.
     is_prompt_batch = any(not state.is_prefilled for state in current_states)
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 00893efa44..a5b86d69b9 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -66,6 +66,10 @@ class EvalMultiQueryRequest:
     sampling_params: SamplingParams
 
 
+RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
+RequestsType = Sequence[RequestType]
+
+
 @dataclass
 class TextGenerationResult:
     """
diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py
index fee952ad4d..22ebec7ebb 100644
--- a/serve/mlc_serve/model/model_common.py
+++ b/serve/mlc_serve/model/model_common.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union, Sequence
+from typing import List, Optional, Tuple, Union
 
 import structlog
 import numpy as np
@@ -11,16 +11,16 @@
     SamplingParams,
     get_prompt_sequence_id,
     LOGPROB_TOP_K_MAX,
-    RawLogprobsInfo,
-    RawLogprobsInfos,
     PROMPT_SEQEUNCE_INDEX,
+    RawLogprobsInfo,
+    RawLogprobsInfos,
     SequenceId,
 )
 from ..engine.model_module import (
-    DecodeRequest,
     PrefillRequest,
     EvalMultiQueryRequest,
+    RequestType,
+    RequestsType,
     TextGenerationResult,
 )
@@ -302,49 +302,73 @@ def _is_safe_to_sample(prob_like):
     return res, check_logprob_infos(logprob_infos)
 
 
+def update_tokens_frequency(
+    request: RequestType,
+    new_token: int
+):
+    if not new_token in request.sampling_params.appeared_tokens_freq:
+        request.sampling_params.appeared_tokens_freq[new_token] = 0
+    request.sampling_params.appeared_tokens_freq[new_token] += 1
+
+
+def append_text_gen_res(
+    outputs: List[TextGenerationResult],
+    request: RequestType,
+    new_token: List[int],
+    sequence_id: SequenceId,
+    logprob_info: Optional[RawLogprobsInfos],
+    err_msg: Optional[str]=None,
+) -> List[TextGenerationResult]:
+    if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
+        assert isinstance(request, PrefillRequest)
+        for seq_id in range(request.num_sequence):  # type: ignore
+            outputs.append(
+                TextGenerationResult(
+                    sequence_id=SequenceId(sequence_id.request_id, seq_id),
+                    generated_tokens=new_token,
+                    error=err_msg,
+                    logprob_info=logprob_info,
+                )
+            )
+    else:
+        outputs.append(
+            TextGenerationResult(
+                sequence_id=sequence_id,
+                generated_tokens=new_token,
+                error=err_msg,
+                logprob_info=logprob_info,
+            )
+        )
+    return outputs
+
+
 def sample_from_logits(
     logits: Union[tvm.nd.NDArray, torch.Tensor],
     sequence_ids: List[SequenceId],
-    requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
+    requests: RequestsType,
     vocab_size,
 ) -> List[TextGenerationResult]:
     assert logits.shape[0] == len(requests)
 
     sampling_params = [req.sampling_params for req in requests]
+    outputs: List[TextGenerationResult] = []
 
     try:
         next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)
         assert next_tokens is not None
-        outputs = []
         for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)):
-            if not new_token in sampling_params[i].appeared_tokens_freq:
-                requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
-            requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
-            if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
-                assert isinstance(requests[i], PrefillRequest)
-                for seq_id in range(requests[i].num_sequence):  # type: ignore
-                    outputs.append(
-                        TextGenerationResult(
-                            sequence_id=SequenceId(sequence_id.request_id, seq_id),
-                            generated_tokens=[new_token],
-                            error=None,
-                            logprob_info=get_logprob_infos(i, logprob_infos),
-                        )
-                    )
-            else:
-                outputs.append(
-                    TextGenerationResult(
-                        sequence_id=sequence_id,
-                        generated_tokens=[new_token],
-                        error=None,
-                        logprob_info=get_logprob_infos(i, logprob_infos),
-                    )
-                )
+            update_tokens_frequency(requests[i], new_token)
+            outputs = append_text_gen_res(
+                outputs,
+                requests[i],
+                [new_token],
+                sequence_id,
+                get_logprob_infos(i, logprob_infos),
+            )
 
         return outputs
     except RuntimeError:
         # Fallback to per-token sampling in case some logits values are corrupted.
-        outputs = []
         err_msg = (
             "Error from sampling: probability tensor contains either `inf`, `nan`"
             " or element < 0"
@@ -362,50 +386,23 @@ def sample_from_logits(
 
             if maybe_new_token is not None:
                 new_token = maybe_new_token[0]
-                if not new_token in requests[i].sampling_params.appeared_tokens_freq:
-                    requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
-                requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
-                if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
-                    assert isinstance(requests[i], PrefillRequest)
-                    for seq_id in range(requests[i].num_sequence):  # type: ignore
-                        outputs.append(
-                            TextGenerationResult(
-                                sequence_id=SequenceId(sequence_id.request_id, seq_id),
-                                generated_tokens=[new_token],  # type: ignore
-                                error=None,
-                                logprob_info=get_logprob_infos(0, logprob_infos),
-                            )
-                        )
-                else:
-                    outputs.append(
-                        TextGenerationResult(
-                            sequence_id=sequence_id,
-                            generated_tokens=[new_token],  # type: ignore
-                            error=None,
-                            logprob_info=get_logprob_infos(0, logprob_infos),
-                        )
-                    )
+                update_tokens_frequency(requests[i], new_token)
+                outputs = append_text_gen_res(
+                    outputs,
+                    requests[i],
+                    [new_token],
+                    sequence_id,
+                    get_logprob_infos(0, logprob_infos),
+                )
             else:
-                if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
-                    assert isinstance(requests[i], PrefillRequest)
-                    for seq_id in range(requests[i].num_sequence):  # type: ignore
-                        outputs.append(
-                            TextGenerationResult(
-                                sequence_id=SequenceId(sequence_id.request_id, seq_id),
-                                generated_tokens=[],
-                                error=err_msg,
-                                logprob_info=get_logprob_infos(0, logprob_infos),
-                            )
-                        )
-                else:
-                    outputs.append(
-                        TextGenerationResult(
-                            sequence_id=sequence_id,
-                            generated_tokens=[],
-                            error=err_msg,
-                            logprob_info=get_logprob_infos(0, logprob_infos),
-                        )
-                    )
+                outputs = append_text_gen_res(
+                    outputs,
+                    requests[i],
+                    [],  # new_token
+                    sequence_id,
+                    get_logprob_infos(0, logprob_infos),
+                    err_msg,
+                )
 
         return outputs
 
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 0b16ab0b3c..433ca2baa3 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 import structlog
-from typing import List, Union
+from typing import List
 
 from .base import get_model_artifact_config
 from .paged_cache_manager import CacheManager
@@ -13,6 +13,7 @@
     ModelModule,
     PrefillRequest,
     EvalMultiQueryRequest,
+    RequestType,
     TextGenerationResult,
     TextGenerator,
 )
@@ -26,9 +27,9 @@ def __init__(self, model: TextGenerator):
 
     def generate(
         self,
-        requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
+        requests: List[RequestType],
         kv_cache,
-    ) -> list[TextGenerationResult]:
+    ) -> List[TextGenerationResult]:
         prefill_requests = []
         decode_requests = []
         multi_query_decode_requests = []
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index 202a04e30d..0c28ff7003 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -1,6 +1,6 @@
 import math
 import os
-from typing import List, Union, Tuple, Sequence
+from typing import List, Tuple
 
 import structlog
 import numpy as np
@@ -24,9 +24,10 @@
 )
 from ..engine.model_module import (
     DecodeRequest,
-    PrefillRequest,
     DraftTokens,
     EvalMultiQueryRequest,
+    PrefillRequest,
+    RequestsType,
     TextGenerationResult,
     TextGenerator,
 )
@@ -276,9 +277,7 @@ def generate_multi_query(
 
     def generate(
         self,
-        requests: Sequence[
-            Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
-        ],
+        requests: RequestsType,
         cache: KVCacheInfo,
     ) -> List[TextGenerationResult]:
         if len(requests) == 0: