From ae5be1e85d2c91a5d83117458e427538e45c6b3f Mon Sep 17 00:00:00 2001 From: hzjane Date: Tue, 7 Jan 2025 09:23:29 +0800 Subject: [PATCH 1/2] init --- .../src/ipex_llm/vllm/xpu/engine/engine.py | 7 ++--- .../src/ipex_llm/vllm/xpu/model_convert.py | 31 +++++++++---------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index c4bed0c52e7..fbb41e9c3f6 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -19,7 +19,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.llm import LLM from vllm.utils import Counter -from vllm.config import EngineConfig +from vllm.config import VllmConfig from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert from vllm.usage.usage_lib import UsageContext from vllm.engine.metrics import StatLoggerBase @@ -35,7 +35,7 @@ def __init__(self, *args, **kwargs): def from_engine_args( cls, engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, + engine_config: Optional[VllmConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, load_in_low_bit: str = "sym_int4", @@ -67,7 +67,6 @@ def __init__( swap_space: int = 4, cpu_offload_gb: float = 0, enforce_eager: bool = False, - max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, load_in_low_bit: str = "sym_int4", @@ -96,11 +95,11 @@ def __init__( swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) + self.engine_class = self.get_engine_class() self.llm_engine = IPEXLLMLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.LLM_CLASS, load_in_low_bit=load_in_low_bit) diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py index 3d88d8f9edc..da88316e8d2 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py +++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py @@ -77,19 +77,16 @@ def _ipex_llm_load_model(self) -> None: # from vllm.utils import measure_device_memory from vllm.utils import DeviceMemoryProfiler with DeviceMemoryProfiler() as m: + from dataclasses import replace + new_device_config = DeviceConfig("cpu") + new_vllm_config = replace(self.vllm_config, device_config=new_device_config) self.model = get_model( - model_config=self.model_config, - device_config=DeviceConfig("cpu"), - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, + vllm_config=new_vllm_config ) - if "qwen" in self.model_config.model.lower() or \ - "baichuan" in self.model_config.model.lower() or \ - "codegeex4-all" in self.model_config.model.lower() or \ - "chatglm" in self.model_config.model.lower(): + if "qwen" in self.vllm_config.model_config.model.lower() or \ + "baichuan" in self.vllm_config.model_config.model.lower() or \ + "codegeex4-all" in self.vllm_config.model_config.model.lower() or \ + "chatglm" in self.vllm_config.model_config.model.lower(): self.model.apply(padding_mlp) from ipex_llm import optimize_model import os @@ -99,18 +96,18 @@ def _ipex_llm_load_model(self) -> None: modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"] else: modules = None - if "minicpm" in self.model_config.model.lower(): + if "minicpm" in self.vllm_config.model_config.model.lower(): modules = ["vpm", "resampler"] # only for minicpm_2_6 - if "minicpm-v" in self.model_config.model.lower(): + if "minicpm-v" in self.vllm_config.model_config.model.lower(): from ipex_llm.transformers.models.minicpmv import merge_qkv self.model.vpm.apply(merge_qkv) - if "internvl2" in self.model_config.model.lower(): + if "internvl2" in self.vllm_config.model_config.model.lower(): modules = ["vision_model", "mlp1"] - optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype, + optimize_model(self.model, low_bit=low_bit, torch_dtype=self.vllm_config.model_config.dtype, modules_to_not_convert=modules) - self.model = self.model.to(device=self.device_config.device, - dtype=self.model_config.dtype) + self.model = self.model.to(device=self.vllm_config.device_config.device, + dtype=self.vllm_config.model_config.dtype) self.model_memory_usage = m.consumed_memory logger = init_logger(__name__) From 58743e5e574e24ecc51863a16571d142fa598bb3 Mon Sep 17 00:00:00 2001 From: hzjane Date: Tue, 7 Jan 2025 15:10:35 +0800 Subject: [PATCH 2/2] update engine init --- .../src/ipex_llm/vllm/xpu/engine/engine.py | 21 +- .../vllm/xpu/entrypoints/openai/api_server.py | 344 ++++++++++++++---- .../vllm/xpu/entrypoints/openai/cli_args.py | 76 +++- 3 files changed, 352 insertions(+), 89 deletions(-) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index fbb41e9c3f6..312fd70a903 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from vllm.logger import init_logger from typing import Dict, Optional from vllm.engine.llm_engine import LLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -26,6 +27,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine import signal +logger = init_logger(__name__) class IPEXLLMAsyncLLMEngine(AsyncLLMEngine): def __init__(self, *args, **kwargs): @@ -133,16 +135,21 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, - ipc_path: str, load_in_low_bit: str): + ipc_path: str, load_in_low_bit: str, engine_alive): def signal_handler(*_) -> None: # Interrupt server on sigterm raise KeyboardInterrupt("MQLLMEngine terminated") # noqa - signal.signal(signal.SIGTERM, signal_handler) + try: + signal.signal(signal.SIGTERM, signal_handler) - engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args, - usage_context=usage_context, - ipc_path=ipc_path, - load_in_low_bit=load_in_low_bit) - engine.start() + engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path, + load_in_low_bit=load_in_low_bit) + engine.start() + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py index 1bd9835fe03..680edc993b0 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py @@ -1,4 +1,5 @@ import asyncio +import atexit import importlib import inspect import multiprocessing @@ -7,11 +8,12 @@ import signal import socket import tempfile +import uuid from argparse import Namespace from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Set +from typing import AsyncIterator, Optional, Set, Tuple import uvloop from fastapi import APIRouter, FastAPI, Request @@ -29,8 +31,10 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from ipex_llm.vllm.xpu.engine import run_mp_engine from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import validate_parsed_serve_args from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block # yapf: disable @@ -41,8 +45,12 @@ DetokenizeRequest, DetokenizeResponse, EmbeddingRequest, - EmbeddingResponse, ErrorResponse, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, LoadLoraAdapterRequest, + PoolingRequest, PoolingResponse, + ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) @@ -50,12 +58,17 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path +from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, + is_valid_ipv6_address, set_ulimit) from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -103,7 +116,7 @@ async def build_async_engine_client( engine_args = AsyncEngineArgs.from_cli_args(args) async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine: + engine_args, args.disable_frontend_multiprocessing) as engine: yield engine @@ -111,7 +124,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, - load_in_low_bit: str = 'sym_int4', + load_in_low_bit: str = "sym_int4", ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -124,14 +137,15 @@ async def build_async_engine_client_from_engine_args( # Fall back # TODO: fill out feature matrix. if (MQLLMEngineClient.is_unsupported_config(engine_args) - or disable_frontend_multiprocessing): - engine_config = engine_args.create_engine_config() + or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): + engine_config = engine_args.create_engine_config( + UsageContext.OPENAI_API_SERVER) uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), "uses_ray", False) build_engine = partial(AsyncLLMEngine.from_engine_args, - engine_args=engine_args, load_in_low_bit=load_in_low_bit, + engine_args=engine_args, engine_config=engine_config, usage_context=UsageContext.OPENAI_API_SERVER) if uses_ray: @@ -142,6 +156,8 @@ async def build_async_engine_client_from_engine_args( None, build_engine) yield engine_client + if hasattr(engine_client, "shutdown"): + engine_client.shutdown() return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -163,45 +179,60 @@ async def build_async_engine_client_from_engine_args( # Select random path for IPC. ipc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) + logger.debug("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process context = multiprocessing.get_context("spawn") + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) engine_process = context.Process(target=run_mp_engine, args=(engine_args, UsageContext.OPENAI_API_SERVER, - ipc_path, - load_in_low_bit)) + ipc_path, load_in_low_bit, engine_alive)) engine_process.start() - logger.info("Started engine process with PID %d", engine_process.pid) + engine_pid = engine_process.pid + assert engine_pid is not None, "Engine process failed to start." + logger.info("Started engine process with PID %d", engine_pid) + + def _cleanup_ipc_path(): + socket_path = ipc_path.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + + # Ensure we clean up the local IPC socket file on exit. + atexit.register(_cleanup_ipc_path) # Build RPCClient, which conforms to EngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) engine_config = engine_args.create_engine_config() - mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) - + build_client = partial(MQLLMEngineClient, ipc_path, engine_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) try: while True: try: - await mp_engine_client.setup() + await mq_engine_client.setup() break except TimeoutError: - if not engine_process.is_alive(): + if (not engine_process.is_alive() + or not engine_alive.value): raise RuntimeError( - "Engine process failed to start") from None + "Engine process failed to start. See stack " + "trace for the root cause.") from None - yield mp_engine_client # type: ignore[misc] + yield mq_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated engine_process.terminate() # Close all open connections to the backend - mp_engine_client.close() + mq_engine_client.close() # Wait for engine process to join engine_process.join(4) @@ -230,8 +261,8 @@ def mount_metrics(app: FastAPI): prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) if prometheus_multiproc_dir_path is not None: - logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) + logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) @@ -246,22 +277,35 @@ def mount_metrics(app: FastAPI): app.routes.append(metrics_route) -def chat(request: Request) -> OpenAIServingChat: +def base(request: Request) -> OpenAIServing: + # Reuse the existing instance + return tokenization(request) + + +def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat -def completion(request: Request) -> OpenAIServingCompletion: +def completion(request: Request) -> Optional[OpenAIServingCompletion]: return request.app.state.openai_serving_completion -def tokenization(request: Request) -> OpenAIServingTokenization: - return request.app.state.openai_serving_tokenization +def pooling(request: Request) -> Optional[OpenAIServingPooling]: + return request.app.state.openai_serving_pooling -def embedding(request: Request) -> OpenAIServingEmbedding: +def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: return request.app.state.openai_serving_embedding +def score(request: Request) -> Optional[OpenAIServingScores]: + return request.app.state.openai_serving_scores + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -274,8 +318,11 @@ async def health(raw_request: Request) -> Response: @router.post("/tokenize") +@with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): - generator = await tokenization(raw_request).create_tokenize(request) + handler = tokenization(raw_request) + + generator = await handler.create_tokenize(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -286,8 +333,11 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post("/detokenize") +@with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): - generator = await tokenization(raw_request).create_detokenize(request) + handler = tokenization(raw_request) + + generator = await handler.create_detokenize(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -299,7 +349,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): @router.get("/v1/models") async def show_available_models(raw_request: Request): - models = await completion(raw_request).show_available_models() + handler = base(raw_request) + + models = await handler.show_available_models() return JSONResponse(content=models.model_dump()) @@ -310,11 +362,15 @@ async def show_version(): @router.post("/v1/chat/completions") +@with_cancellation async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + handler = chat(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Chat Completions API") - generator = await chat(raw_request).create_chat_completion( - request, raw_request) + generator = await handler.create_chat_completion(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -327,9 +383,14 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") +@with_cancellation async def create_completion(request: CompletionRequest, raw_request: Request): - generator = await completion(raw_request).create_completion( - request, raw_request) + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + generator = await handler.create_completion(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -340,9 +401,40 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") +@with_cancellation async def create_embedding(request: EmbeddingRequest, raw_request: Request): - generator = await embedding(raw_request).create_embedding( - request, raw_request) + handler = embedding(raw_request) + if handler is None: + fallback_handler = pooling(raw_request) + if fallback_handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + logger.warning( + "Embeddings API will become exclusive to embedding models " + "in a future release. To return the hidden states directly, " + "use the Pooling API (`/pooling`) instead.") + + res = await fallback_handler.create_pooling(request, raw_request) + if isinstance(res, PoolingResponse): + generator = EmbeddingResponse( + id=res.id, + object=res.object, + created=res.created, + model=res.model, + data=[ + EmbeddingResponseData( + index=d.index, + embedding=d.data, # type: ignore + ) for d in res.data + ], + usage=res.usage, + ) + else: + generator = res + else: + generator = await handler.create_embedding(request, raw_request) + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -352,6 +444,52 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) +@router.post("/pooling") +@with_cancellation +async def create_pooling(request: PoolingRequest, raw_request: Request): + handler = pooling(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Pooling API") + + generator = await handler.create_pooling(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, PoolingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/score") +@with_cancellation +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/score") +@with_cancellation +async def create_score_v1(request: ScoreRequest, raw_request: Request): + logger.warning( + "To indicate that Score API is not part of standard OpenAI API, we " + "have moved it to `/score`. Please update your client accordingly.") + + return await create_score(request, raw_request) + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -380,30 +518,26 @@ async def stop_profile(raw_request: Request): @router.post("/v1/load_lora_adapter") async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): - response = await chat(raw_request).load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) - - response = await completion(raw_request).load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + for route in [chat, completion, embedding]: + handler = route(raw_request) + if handler is not None: + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): - response = await chat(raw_request).unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) - - response = await completion(raw_request).unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + for route in [chat, completion, embedding]: + handler = route(raw_request) + if handler is not None: + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @@ -431,8 +565,9 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - chat = app.state.openai_serving_chat - err = chat.create_error_response(message=str(exc)) + err = ErrorResponse(message=str(exc), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -440,16 +575,31 @@ async def validation_exception_handler(_, exc): @app.middleware("http") async def authentication(request: Request, call_next): - root_path = "" if args.root_path is None else args.root_path if request.method == "OPTIONS": return await call_next(request) - if not request.url.path.startswith(f"{root_path}/v1"): + url_path = request.url.path + if app.root_path and url_path.startswith(app.root_path): + url_path = url_path[len(app.root_path):] + if not url_path.startswith("/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"}, status_code=401) return await call_next(request) + if args.enable_request_id_headers: + logger.warning( + "CAUTION: Enabling X-Request-Id headers in the API Server. " + "This can harm performance at high QPS.") + + @app.middleware("http") + async def add_request_id(request: Request, call_next): + request_id = request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + response = await call_next(request) + response.headers["X-Request-Id"] = request_id + return response + for middleware in args.middleware: module_path, object_name = middleware.rsplit(".", 1) imported = getattr(importlib.import_module(module_path), object_name) @@ -488,6 +638,9 @@ def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats + resolved_chat_template = load_chat_template(args.chat_template) + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, @@ -496,10 +649,13 @@ def init_app_state( lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser) + tool_parser=args.tool_call_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) if model_config.runner_type == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, @@ -508,29 +664,74 @@ def init_app_state( prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) + ) if model_config.runner_type == "generate" else None + state.openai_serving_pooling = OpenAIServingPooling( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.runner_type == "pooling" else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, base_model_paths, request_logger=request_logger, - ) + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.task == "embed" else None + state.openai_serving_scores = OpenAIServingScores( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger + ) if model_config.task == "score" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) +def create_server_socket(addr: Tuple[str, int]) -> socket.socket: + family = socket.AF_INET + if is_valid_ipv6_address(addr[0]): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addr) + + return sock + + async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - temp_socket.bind(("", args.port)) + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valide_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valide_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valide_tool_parses)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing @@ -544,8 +745,6 @@ def signal_handler(*_) -> None: model_config = await engine_client.get_model_config() init_app_state(engine_client, model_config, app.state, args) - temp_socket.close() - shutdown_task = await serve_http( app, host=args.host, @@ -562,13 +761,18 @@ def signal_handler(*_) -> None: # NB: Await server shutdown only after the backend context is exited await shutdown_task + sock.close() + if __name__ == "__main__": # NOTE(simon): # This section should be in sync with vllm/scripts.py for CLI entrypoints. + logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` " + "instead of `vllm.entrypoints.openai.api_server` to start the API server") parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) args = parser.parse_args() + validate_parsed_serve_args(args) uvloop.run(run_server(args)) diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py index 4110f11b483..b7cdb58541f 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py @@ -7,11 +7,14 @@ import argparse import json import ssl -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, get_args from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + validate_chat_template) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.utils import FlexibleArgumentParser @@ -27,7 +30,7 @@ def __call__( if values is None: values = [] if isinstance(values, str): - raise TypeError("Expected values to be a list") # noqa + raise TypeError("Expected values to be a list") lora_list: List[LoRAModulePath] = [] for item in values: @@ -63,7 +66,7 @@ def __call__( if values is None: values = [] if isinstance(values, str): - raise TypeError("Expected values to be a list") # noqa + raise TypeError("Expected values to be a list") adapter_list: List[PromptAdapterPath] = [] for item in values: @@ -130,10 +133,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="The file path to the chat template, " "or the template in single-line form " "for the specified model") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: "Hello World"\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: [{"type": "text", "text": "Hello world!"}]') parser.add_argument("--response-role", type=nullable_str, default="assistant", - help="The role name to return if `request.add_generation_prompt=true`.") + help="The role name to return if " + "`request.add_generation_prompt=true`.") parser.add_argument("--ssl-keyfile", type=nullable_str, default=None, @@ -180,28 +196,39 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action="store_true", help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") - + parser.add_argument( + "--enable-request-id-headers", + action="store_true", + help="If specified, API server will add X-Request-Id header to " + "responses. Caution: this hurts performance at high QPS.") parser.add_argument( "--enable-auto-tool-choice", action="store_true", default=False, - help="Enable auto tool choice for supported models. Use --tool-call-parser" - "to specify which parser to use") + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + " to specify which parser to use") + valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( "--tool-call-parser", type=str, - choices=["mistral", "hermes"], + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", default=None, - help="Select the tool call parser depending on the model that you're using." + help= + "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format. Required for --enable-auto-tool-choice.") parser.add_argument( - "--load-in-low-bit", + "--tool-parser-plugin", type=str, - default="sym_int4", - help="Low-bit quantization for IPEX-LLM models") + default="", + help= + "Special the tool parser plugin write to parse the model-generated tool" + " into OpenAI API format, the name register in this plugin can be used " + "in --tool-call-parser.") parser = AsyncEngineArgs.add_cli_args(parser) @@ -218,10 +245,35 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + + parser.add_argument( + "--load-in-low-bit", + type=str, + default="sym_int4", + help="Low-bit quantization for IPEX-LLM models") return parser +def validate_parsed_serve_args(args: argparse.Namespace): + """Quick checks for model serve args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "serve": + return + + # Ensure that the chat template is valid; raises if it likely isn't + validate_chat_template(args.chat_template) + + # Enable auto tool needs a tool call parser to be valid + if args.enable_auto_tool_choice and not args.tool_call_parser: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + + def create_parser_for_docs() -> FlexibleArgumentParser: parser_for_docs = FlexibleArgumentParser( prog="-m vllm.entrypoints.openai.api_server")