From 2149d5dd8930e18a311263afe03e2e191bf62d22 Mon Sep 17 00:00:00 2001
From: Sunghyun Park
Date: Tue, 9 Jan 2024 00:05:53 +0000
Subject: [PATCH] fix

---
 serve/benchmarks/benchmark_latency.py |  2 +-
 serve/mlc_serve/utils.py              | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/serve/benchmarks/benchmark_latency.py b/serve/benchmarks/benchmark_latency.py
index eab04ce54a..6d3b897edf 100644
--- a/serve/benchmarks/benchmark_latency.py
+++ b/serve/benchmarks/benchmark_latency.py
@@ -1,4 +1,4 @@
-"""Benchmark offline user metric."""
+"""Benchmark latency offline."""
 import argparse
 import time, numpy as np
 from mlc_serve.engine import (
diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py
index e47596e784..f2a2e1220c 100644
--- a/serve/mlc_serve/utils.py
+++ b/serve/mlc_serve/utils.py
@@ -56,11 +56,11 @@ def create_mlc_engine(args: argparse.Namespace):
             "max_decode_steps": args.max_decode_steps,
         }
     )
-    # type: off
+
     if args.use_staging_engine:
-        engine = StagingInferenceEngine(
+        engine = StagingInferenceEngine(  # type: ignore
             tokenizer_module=HfTokenizerModule(args.model_artifact_path),
-            model_module_loader=PagedCacheModelModule,
+            model_module_loader=PagedCacheModelModule,  # type: ignore
             model_module_loader_kwargs={
                 "model_artifact_path": args.model_artifact_path,
                 "engine_config": engine_config,
@@ -68,12 +68,11 @@ def create_mlc_engine(args: argparse.Namespace):
         )
         engine.start()
     else:
-        engine = SynchronousInferenceEngine(
-            PagedCacheModelModule(
+        engine = SynchronousInferenceEngine(  # type: ignore
+            PagedCacheModelModule(  # type: ignore
                 model_artifact_path=args.model_artifact_path,
                 engine_config=engine_config,
             )
         )
-    # type: on

     return engine
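--
Note: mypy does not recognize `# type: off` / `# type: on` block markers, so the
removed comments suppressed nothing; the supported per-line form is a trailing
`# type: ignore` on each offending statement (a top-of-file `# mypy: ignore-errors`
comment is the module-wide alternative). A minimal sketch of the per-line form,
with `untyped_factory` as a hypothetical stand-in for the engine constructors
that lack usable type information:

    from typing import Any

    def untyped_factory(**kwargs: Any):  # hypothetical: no return annotation,
        ...                              # so assignments from it may not check

    # The trailing ignore silences the type checker for this one statement
    # only, which is why the patch attaches it to each engine construction.
    engine = untyped_factory(model_artifact_path="...")  # type: ignore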