diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index 1ac19de062..fc2245b769 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -4,7 +4,7 @@
 
 import os
 from http import HTTPStatus
-from typing import Annotated, AsyncIterator, List
+from typing import Annotated, AsyncIterator, List, Optional
 
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -33,6 +33,7 @@
 from ..model.base import ModelArtifactConfig
 from ..engine.async_connector import AsyncEngineConnector
 from .dependencies import get_async_engine_connector
+from ..openai_logprob_protocol import LogprobsContent
 
 
 def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
@@ -226,10 +227,10 @@ async def collect_result_stream(
 ) -> ChatCompletionResponse:
     created_time = int(time.time())
     sequences: List[List[str]] = [[] for _ in range(num_sequences)]
-    finish_reasons = [None] * num_sequences
+    finish_reasons: List[Optional[str]] = [None] * num_sequences
     num_prompt_tokens = 0
     num_generated_tokens = [0 for _ in range(num_sequences)]
-    logprob_infos = [[] for _ in range(num_sequences)]  # type: ignore
+    logprob_infos: List[List[Optional[LogprobsContent]]] = [[] for _ in range(num_sequences)]
     async for res in result_generator:
         # TODO: verify that the request cancellation happens after this returns
         if res.error:
@@ -250,7 +251,7 @@
 
             if seq.is_finished:
                 assert seq.finish_reason is not None
-                finish_reasons[seq.index] = seq.finish_reason.value  # type: ignore
+                finish_reasons[seq.index] = seq.finish_reason.value
 
     choices = []
     for index, (logprob_info_seq, chunks, finish_reason) in enumerate(
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index acf1faa90d..d7e15cd1e9 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -94,7 +94,7 @@ def detokenize_incrementally(
     # This is the first iteration for this sequence
    if generation_sequence.prev_tokens is None:
         # TODO(masahi): Figure out a way to remove this concat
-        new_tokens: List[str] = tokenizer.convert_ids_to_tokens(  # type: ignore
+        new_tokens = tokenizer.convert_ids_to_tokens(
             prompt_tokens + generation_sequence.generated_token_ids
         )
         output_tokens = new_tokens
@@ -110,7 +110,7 @@
         prefix_end_offset = max(len(output_tokens) - 1, 0)
     else:
         # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens: List[str] = tokenizer.convert_ids_to_tokens(  # type: ignore
+        new_tokens = tokenizer.convert_ids_to_tokens(
             [new_token_id]
         )
         output_tokens = generation_sequence.prev_tokens + new_tokens
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 867616d0d5..d542d06440 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -220,6 +220,6 @@ def conversation_template(self) -> ConversationTemplate:
         ...
 
 
-class ModelModule(TextTokenGeneratorModule, TokenizerModule):
+class ModelModule(TextTokenGeneratorModule, TokenizerModule, Protocol):
     model_artifact_config: ModelArtifactConfig
     engine_config: MLCServeEngineConfig
diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py
index d5a010e4e5..ead841f33b 100644
--- a/serve/mlc_serve/model/model_common.py
+++ b/serve/mlc_serve/model/model_common.py
@@ -57,7 +57,7 @@ def prepare_textgen_result(
     outputs = []
     if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
         assert isinstance(request, PrefillRequest)
-        for seq_id in range(request.num_sequence):  # type: ignore
+        for seq_id in range(request.num_sequence):
             outputs.append(
                 TextGenerationResult(
                     sequence_id=SequenceId(sequence_id.request_id, seq_id),
diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py
index 40ee1b4c68..b26e41e817 100644
--- a/serve/mlc_serve/model/paged_cache_manager.py
+++ b/serve/mlc_serve/model/paged_cache_manager.py
@@ -108,9 +108,8 @@ def __init__(
     ):
         self.block_size = block_size
 
-        # SequenceId -> list[int]
-        self.prompt_block_tables = defaultdict(list)  # type: ignore
-        self.slot_mappings = defaultdict(list)  # type: ignore
+        self.prompt_block_tables = defaultdict[SequenceId, List[int]](list)
+        self.slot_mappings = defaultdict[SequenceId, List[int]](list)
 
         # The core data structure
         self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]()
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 433ca2baa3..34755a8c36 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 import structlog
-from typing import List
+from typing import Sequence, List
 
 from .base import get_model_artifact_config
 from .paged_cache_manager import CacheManager
@@ -27,7 +27,7 @@ def __init__(self, model: TextGenerator):
 
     def generate(
         self,
-        requests: List[RequestType],
+        requests: Sequence[RequestType],
         kv_cache,
     ) -> List[TextGenerationResult]:
         prefill_requests = []
@@ -94,4 +94,4 @@ def __init__(
         self.conversation_template = tokenizer_module.conversation_template
 
     def _check_implements_model_module(self) -> ModelModule:
-        return self  # type: ignore
+        return self
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index a69e98118e..f806e944bc 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -299,7 +299,7 @@ def generate_multi_query(
         return sample_from_logits(
             last_query_logits,
             sequence_ids,
-            requests,  # type: ignore
+            requests,
             sampling_metadata,
             self.vocab_size,
             self._copy_stream,
diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py
index 3e77c7e2d5..50947912cb 100644
--- a/serve/mlc_serve/run.py
+++ b/serve/mlc_serve/run.py
@@ -41,11 +41,10 @@ def create_engine(
         }
     )
 
-    # TODO(yelite, masahi): Protocol subtyping is not working
     if args.use_staging_engine:
         return StagingInferenceEngine(
             tokenizer_module=HfTokenizerModule(args.model_artifact_path),
-            model_module_loader=PagedCacheModelModule,  # type: ignore
+            model_module_loader=PagedCacheModelModule,
             model_module_loader_kwargs={
                 "model_artifact_path": args.model_artifact_path,
                 "engine_config": engine_config,
@@ -56,7 +55,7 @@
         PagedCacheModelModule(
             model_artifact_path=args.model_artifact_path,
             engine_config=engine_config,
-        )  # type: ignore
+        )
     )
 
 
diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py
index ed3be6a451..10c616ac18 100644
--- a/serve/mlc_serve/utils.py
+++ b/serve/mlc_serve/utils.py
@@ -6,7 +6,7 @@
 import random
 import argparse
 
-from mlc_serve.engine import get_engine_config
+from mlc_serve.engine import get_engine_config, InferenceEngine
 from mlc_serve.logging_utils import configure_logging
 from mlc_serve.engine.staging_engine import StagingInferenceEngine
 from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
@@ -46,7 +46,7 @@ def postproc_mlc_serve_args(args):
     random.seed(args.seed)
 
 
-def create_mlc_engine(args: argparse.Namespace):
+def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine:
     engine_config = get_engine_config(
         {
             "use_staging_engine": args.use_staging_engine,
@@ -56,11 +56,12 @@
         }
     )
 
-    # TODO(@team): There is a type mismatch in the definition. Let's fix this when have time.
+    engine: InferenceEngine
+
     if args.use_staging_engine:
-        engine = StagingInferenceEngine(  # type: ignore
+        engine = StagingInferenceEngine(
             tokenizer_module=HfTokenizerModule(args.model_artifact_path),
-            model_module_loader=PagedCacheModelModule,  # type: ignore
+            model_module_loader=PagedCacheModelModule,
             model_module_loader_kwargs={
                 "model_artifact_path": args.model_artifact_path,
                 "engine_config": engine_config,
@@ -68,8 +69,8 @@
         )
         engine.start()
     else:
-        engine = SynchronousInferenceEngine(  # type: ignore
-            PagedCacheModelModule(  # type: ignore
+        engine = SynchronousInferenceEngine(
+            PagedCacheModelModule(
                 model_artifact_path=args.model_artifact_path,
                 engine_config=engine_config,
             )