Fix protocol typing (#197)
And remove most "# type: ignore" comments. Fixes #138
yelite authored Feb 8, 2024
1 parent 62e51f5 commit ab14322
Showing 9 changed files with 25 additions and 25 deletions.
9 changes: 5 additions & 4 deletions serve/mlc_serve/api/handler.py
@@ -4,7 +4,7 @@
import os

from http import HTTPStatus
-from typing import Annotated, AsyncIterator, List
+from typing import Annotated, AsyncIterator, List, Optional

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -33,6 +33,7 @@
from ..model.base import ModelArtifactConfig
from ..engine.async_connector import AsyncEngineConnector
from .dependencies import get_async_engine_connector
+from ..openai_logprob_protocol import LogprobsContent


def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
@@ -226,10 +227,10 @@ async def collect_result_stream(
) -> ChatCompletionResponse:
created_time = int(time.time())
sequences: List[List[str]] = [[] for _ in range(num_sequences)]
-finish_reasons = [None] * num_sequences
+finish_reasons: List[Optional[str]] = [None] * num_sequences
num_prompt_tokens = 0
num_generated_tokens = [0 for _ in range(num_sequences)]
-logprob_infos = [[] for _ in range(num_sequences)] # type: ignore
+logprob_infos: List[List[Optional[LogprobsContent]]] = [[] for _ in range(num_sequences)]
async for res in result_generator:
# TODO: verify that the request cancellation happens after this returns
if res.error:
@@ -250,7 +251,7 @@

if seq.is_finished:
assert seq.finish_reason is not None
-finish_reasons[seq.index] = seq.finish_reason.value # type: ignore
+finish_reasons[seq.index] = seq.finish_reason.value

choices = []
for index, (logprob_info_seq, chunks, finish_reason) in enumerate(
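
The handler.py hunks replace blanket `# type: ignore` comments with explicit element types on containers that start out empty or filled with `None`. A minimal sketch of the pattern, outside the real handler (the `token_chunks` name is illustrative, not from the diff):

```python
from typing import List, Optional

num_sequences = 4

# Without the annotation, mypy infers List[None] and rejects the later
# assignment of a string, which the old "# type: ignore" papered over.
finish_reasons: List[Optional[str]] = [None] * num_sequences

# Same idea for a list of empty lists: naming the element type up front
# lets appends of the real payload type-check cleanly.
token_chunks: List[List[str]] = [[] for _ in range(num_sequences)]

finish_reasons[0] = "stop"       # OK: Optional[str] accepts a str
token_chunks[0].append("hello")  # OK: the inner lists are List[str]
```
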
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
@@ -94,7 +94,7 @@ def detokenize_incrementally(
# This is the first iteration for this sequence
if generation_sequence.prev_tokens is None:
# TODO(masahi): Figure out a way to remove this concat
-new_tokens: List[str] = tokenizer.convert_ids_to_tokens( # type: ignore
+new_tokens = tokenizer.convert_ids_to_tokens(
prompt_tokens + generation_sequence.generated_token_ids
)
output_tokens = new_tokens
@@ -110,7 +110,7 @@
prefix_end_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
-new_tokens: List[str] = tokenizer.convert_ids_to_tokens( # type: ignore
+new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id]
)
output_tokens = generation_sequence.prev_tokens + new_tokens
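
The engine_common.py hunks drop the duplicated `new_tokens: List[str]` annotation and let the tokenizer's return type be inferred; annotating the same name in both branches of an `if`/`else` is a common reason such lines end up needing `# type: ignore`, since mypy treats the second annotation as a redefinition. A minimal sketch of that rule, with a hypothetical `decode` helper standing in for `tokenizer.convert_ids_to_tokens`:

```python
from typing import List


def decode(ids: List[int]) -> List[str]:
    # Stand-in for the tokenizer call; assumed for illustration only.
    return [str(i) for i in ids]


def detokenize(prev: List[str], new_ids: List[int], first_call: bool) -> List[str]:
    if first_call:
        new_tokens: List[str] = decode(new_ids)  # annotate on the first binding only
    else:
        # Writing "new_tokens: List[str] = ..." again here makes mypy complain
        # about the name already being defined; a plain assignment is enough.
        new_tokens = decode(new_ids)
    return prev + new_tokens


print(detokenize(["he"], [108, 108, 111], first_call=False))
```
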
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -220,6 +220,6 @@ def conversation_template(self) -> ConversationTemplate:
...


-class ModelModule(TextTokenGeneratorModule, TokenizerModule):
+class ModelModule(TextTokenGeneratorModule, TokenizerModule, Protocol):
model_artifact_config: ModelArtifactConfig
engine_config: MLCServeEngineConfig
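
This one-line change in model_module.py is the core of the fix. Under PEP 544, a class that inherits from protocol classes is itself a protocol only if it also lists `Protocol` explicitly; without that, `ModelModule` becomes an ordinary nominal base class, and modules that merely match its shape fail to type-check against it, which is what the downstream `# type: ignore` comments were hiding. A minimal sketch of the difference, using toy protocols rather than the real mlc_serve interfaces:

```python
from typing import List, Protocol


class Tokenizes(Protocol):
    def tokenize(self, text: str) -> List[int]: ...


class Generates(Protocol):
    def generate(self, token_ids: List[int]) -> str: ...


# Listing Protocol explicitly keeps the combination structural: any class with
# both methods and the attribute is accepted, with no inheritance required.
class ModelLike(Tokenizes, Generates, Protocol):
    name: str


class ToyModel:  # note: does not inherit from ModelLike
    name = "toy"

    def tokenize(self, text: str) -> List[int]:
        return [ord(c) for c in text]

    def generate(self, token_ids: List[int]) -> str:
        return "".join(chr(i) for i in token_ids)


def run(module: ModelLike) -> str:
    return module.generate(module.tokenize("hi"))


print(run(ToyModel()))  # accepted with the explicit Protocol base; rejected without it
```
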
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/model_common.py
@@ -57,7 +57,7 @@ def prepare_textgen_result(
outputs = []
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(request, PrefillRequest)
-for seq_id in range(request.num_sequence): # type: ignore
+for seq_id in range(request.num_sequence):
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
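
The model_common.py hunk leans on the `assert isinstance(request, PrefillRequest)` just above the loop: the assert narrows `request` for the type checker, so `request.num_sequence` no longer needs an ignore. A minimal sketch of that narrowing, with toy request classes in place of the real ones:

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class PrefillRequest:
    num_sequence: int


@dataclass
class DecodeRequest:
    pass


RequestType = Union[PrefillRequest, DecodeRequest]


def sequence_indices(request: RequestType) -> List[int]:
    # Before the assert, request.num_sequence is an error because DecodeRequest
    # has no such attribute; after it, the checker knows this is a PrefillRequest.
    assert isinstance(request, PrefillRequest)
    return list(range(request.num_sequence))


print(sequence_indices(PrefillRequest(num_sequence=2)))  # [0, 1]
```
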
5 changes: 2 additions & 3 deletions serve/mlc_serve/model/paged_cache_manager.py
@@ -108,9 +108,8 @@ def __init__(
):
self.block_size = block_size

-# SequenceId -> list[int]
-self.prompt_block_tables = defaultdict(list) # type: ignore
-self.slot_mappings = defaultdict(list) # type: ignore
+self.prompt_block_tables = defaultdict[SequenceId, List[int]](list)
+self.slot_mappings = defaultdict[SequenceId, List[int]](list)

# The core data structure
self.decode_block_tables: dict = dict[SequenceId, DecodeBlockTable]()
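
The paged_cache_manager.py hunk replaces the comment-plus-ignore with a subscripted `defaultdict`: since Python 3.9 the standard-library containers can be parameterized directly, and the resulting alias is still callable, so the key and value types are visible to the checker without a separate annotation. A minimal sketch, using a plain `int` key in place of `SequenceId`:

```python
from collections import defaultdict
from typing import List

# The subscript has no effect at runtime, but it tells the checker the key and
# value types, so wrong-typed appends are caught without "# type: ignore".
slot_mappings = defaultdict[int, List[int]](list)

slot_mappings[0].append(42)        # OK
# slot_mappings[0].append("oops")  # mypy: expected "int"

print(dict(slot_mappings))  # {0: [42]}
```
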
6 changes: 3 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -1,6 +1,6 @@
from pathlib import Path
import structlog
-from typing import List
+from typing import Sequence, List

from .base import get_model_artifact_config
from .paged_cache_manager import CacheManager
@@ -27,7 +27,7 @@ def __init__(self, model: TextGenerator):

def generate(
self,
-requests: List[RequestType],
+requests: Sequence[RequestType],
kv_cache,
) -> List[TextGenerationResult]:
prefill_requests = []
@@ -94,4 +94,4 @@ def __init__(
self.conversation_template = tokenizer_module.conversation_template

def _check_implements_model_module(self) -> ModelModule:
-return self # type: ignore
+return self
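
Two things happen in paged_cache_model.py. First, `generate` now accepts `Sequence[RequestType]`: `List` is invariant, so a `List[PrefillRequest]` is not a `List[RequestType]`, while the read-only, covariant `Sequence` accepts it. Second, `_check_implements_model_module` can return `self` without an ignore once `ModelModule` is a real protocol. A minimal sketch of the variance point, with toy classes standing in for the request types:

```python
from typing import List, Sequence, Union


class PrefillRequest: ...
class DecodeRequest: ...


RequestType = Union[PrefillRequest, DecodeRequest]


def generate_from_list(requests: List[RequestType]) -> int:
    return len(requests)


def generate_from_seq(requests: Sequence[RequestType]) -> int:
    return len(requests)


prefill_only: List[PrefillRequest] = [PrefillRequest()]

# generate_from_list(prefill_only)      # mypy error: List is invariant
print(generate_from_seq(prefill_only))  # OK: Sequence is covariant, prints 1
```
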
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/tvm_model.py
@@ -299,7 +299,7 @@ def generate_multi_query(
return sample_from_logits(
last_query_logits,
sequence_ids,
-requests, # type: ignore
+requests,
sampling_metadata,
self.vocab_size,
self._copy_stream,
5 changes: 2 additions & 3 deletions serve/mlc_serve/run.py
@@ -41,11 +41,10 @@ def create_engine(
}
)

-# TODO(yelite, masahi): Protocol subtyping is not working
if args.use_staging_engine:
return StagingInferenceEngine(
tokenizer_module=HfTokenizerModule(args.model_artifact_path),
-model_module_loader=PagedCacheModelModule, # type: ignore
+model_module_loader=PagedCacheModelModule,
model_module_loader_kwargs={
"model_artifact_path": args.model_artifact_path,
"engine_config": engine_config,
@@ -56,7 +55,7 @@
PagedCacheModelModule(
model_artifact_path=args.model_artifact_path,
engine_config=engine_config,
-) # type: ignore
+)
)


15 changes: 8 additions & 7 deletions serve/mlc_serve/utils.py
@@ -6,7 +6,7 @@
import random
import argparse

-from mlc_serve.engine import get_engine_config
+from mlc_serve.engine import get_engine_config, InferenceEngine
from mlc_serve.logging_utils import configure_logging
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
@@ -46,7 +46,7 @@ def postproc_mlc_serve_args(args):
random.seed(args.seed)


-def create_mlc_engine(args: argparse.Namespace):
+def create_mlc_engine(args: argparse.Namespace) -> InferenceEngine:
engine_config = get_engine_config(
{
"use_staging_engine": args.use_staging_engine,
@@ -56,20 +56,21 @@
}
)

-# TODO(@team): There is a type mismatch in the definition. Let's fix this when have time.
+engine: InferenceEngine

if args.use_staging_engine:
-engine = StagingInferenceEngine( # type: ignore
+engine = StagingInferenceEngine(
tokenizer_module=HfTokenizerModule(args.model_artifact_path),
-model_module_loader=PagedCacheModelModule, # type: ignore
+model_module_loader=PagedCacheModelModule,
model_module_loader_kwargs={
"model_artifact_path": args.model_artifact_path,
"engine_config": engine_config,
},
)
engine.start()
else:
-engine = SynchronousInferenceEngine( # type: ignore
-PagedCacheModelModule( # type: ignore
+engine = SynchronousInferenceEngine(
+PagedCacheModelModule(
model_artifact_path=args.model_artifact_path,
engine_config=engine_config,
)
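
The utils.py hunks declare `engine: InferenceEngine` once, before the branch, so both engine variants are checked against a single declared type and the new `-> InferenceEngine` return annotation holds on every path. A minimal sketch of that pattern, with toy nominal classes standing in for the real protocol-based ones:

```python
class InferenceEngine:
    """Common interface both engines are expected to satisfy."""


class StagingInferenceEngine(InferenceEngine): ...


class SynchronousInferenceEngine(InferenceEngine): ...


def create_engine(use_staging: bool) -> InferenceEngine:
    # Declaring the type once, ahead of the branch, means each assignment below
    # is checked against InferenceEngine and the return needs no ignore.
    engine: InferenceEngine
    if use_staging:
        engine = StagingInferenceEngine()
    else:
        engine = SynchronousInferenceEngine()
    return engine


print(type(create_engine(use_staging=True)).__name__)  # StagingInferenceEngine
```
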
