ENH: Some improvements for Xavier #2777

Merged 7 commits on Jan 24, 2025
26 changes: 9 additions & 17 deletions doc/source/locale/zh_CN/LC_MESSAGES/user_guide/vllm_enhancement.po
@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Xinference \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2025-01-10 14:44+0800\n"
"POT-Creation-Date: 2025-01-23 14:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <[email protected]>\n"
@@ -31,9 +31,10 @@ msgid ""
" instances. This allows KV cache computed by other replicas to be "
"directly reused, avoiding redundant computations."
msgstr ""
"对于长文档查询和多轮对话等场景,在推理预填充阶段的计算可能特别繁重,这会影响整体吞吐量和单次推理的延迟。"
"Xinference 通过引入 ``Xavier`` 框架来增强 vllm 引擎,支持在多个 vllm 实例之间共享 KV 缓存。"
"这使得其他副本计算出的 KV 缓存可以被直接重用,从而避免了冗余计算。"
"对于长文档查询和多轮对话等场景,在推理预填充阶段的计算可能特别繁重,这会"
"影响整体吞吐量和单次推理的延迟。Xinference 通过引入 ``Xavier`` 框架来增强"
" vllm 引擎,支持在多个 vllm 实例之间共享 KV 缓存。这使得其他副本计算出的 "
"KV 缓存可以被直接重用,从而避免了冗余计算。"

#: ../../source/user_guide/vllm_enhancement.rst:15
msgid "Usage"
@@ -43,31 +44,22 @@ msgstr "使用"
msgid ""
"Simply add the parameter ``enable_xavier=True`` when starting the vllm "
"model."
msgstr ""
"启动 vllm 模型时设置选项 ``enable_xavier=True`` 即可。"
msgstr "启动 vllm 模型时设置选项 ``enable_xavier=True`` 即可。"

#: ../../source/user_guide/vllm_enhancement.rst:20
msgid "Limitations"
msgstr "限制"

#: ../../source/user_guide/vllm_enhancement.rst:21
msgid "Xavier requires vllm version >= ``0.6.5``."
msgstr ""
"Xavier 要求 vllm 版本不低于 ``0.6.5`` 。"
msgstr "Xavier 要求 vllm 版本不低于 ``0.6.5`` 。"

#: ../../source/user_guide/vllm_enhancement.rst:22
msgid ""
"Xavier is currently not compatible with model reloading after CUDA OOM in"
" Xinference. (it will be supported in the future)"
msgstr ""
"目前 Xavier 与 Xinference 中模型 CUDA OOM 后的重新拉起特性不兼容(未来将解决此问题)。"

#: ../../source/user_guide/vllm_enhancement.rst:23
msgid ""
"Due to the underlying communication not recognizing ``0.0.0.0``, the "
"actual IP address needs to be passed when starting Xinference, for "
"example: ``xinference-local -H 192.168.xx.xx``."
msgstr ""
"由于底层通信无法识别 ``0.0.0.0`` 地址,启动 xinference 时需要配置实际的 IP 地址,"
"例如:``xinference-local -H 192.168.xx.xx`` 。"
"由于底层通信无法识别 ``0.0.0.0`` 地址,启动 xinference 时需要配置实际的 "
"IP 地址,例如:``xinference-local -H 192.168.xx.xx`` 。"

1 change: 0 additions & 1 deletion doc/source/user_guide/vllm_enhancement.rst
@@ -19,5 +19,4 @@ Simply add the parameter ``enable_xavier=True`` when starting the vllm model.
Limitations
***********
* Xavier requires vllm version >= ``0.6.5``.
* Xavier is currently not compatible with model reloading after CUDA OOM in Xinference. (it will be supported in the future)
* Due to the underlying communication not recognizing ``0.0.0.0``, the actual IP address needs to be passed when starting Xinference, for example: ``xinference-local -H 192.168.xx.xx``.
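
As a usage note, the documented flag maps onto the launch kwargs handled in supervisor.py below. Here is a minimal sketch from the Python client, assuming an illustrative endpoint address and model name; per this PR, enable_xavier only takes effect on the vllm engine with replica > 1:

    from xinference.client import Client

    # Illustrative endpoint and model name; enable_xavier passes through
    # the launch kwargs and is only meaningful for the vllm engine.
    client = Client("http://192.168.1.10:9997")
    model_uid = client.launch_model(
        model_name="qwen2.5-instruct",
        model_engine="vllm",
        replica=2,
        enable_xavier=True,
    )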
2 changes: 2 additions & 0 deletions xinference/core/model.py
@@ -35,6 +35,7 @@
List,
Optional,
Union,
no_type_check,
)

import sse_starlette.sse
@@ -302,6 +303,7 @@ def __repr__(self) -> str:
def decrease_serve_count(self):
self._serve_count -= 1

@no_type_check
async def start_transfer_for_vllm(self, rank_addresses: List[str]):
from ..model.llm.vllm.core import VLLMModel
from ..model.llm.vllm.xavier.transfer import TransferActor
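
The new decorator is the standard-library typing.no_type_check; it tells static checkers to skip the decorated function entirely, which suits start_transfer_for_vllm since its xoscar actor refs resolve dynamically. A minimal sketch of the decorator's behavior (standard-library semantics, independent of this PR):

    from typing import no_type_check

    @no_type_check
    def mix(a: int, b: str):
        # Static checkers ignore annotations inside this function, so the
        # int + str below is not flagged (it would still fail at runtime).
        return a + b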
137 changes: 114 additions & 23 deletions xinference/core/supervisor.py
@@ -268,8 +268,12 @@ async def signal_handler():
)

from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
from ..model.llm.vllm.xavier.collective_manager import CollectiveManager

self._block_tracker: Optional[xo.ActorRefType[VLLMBlockTracker]] = None
self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {}
self._collective_manager_mapping: Dict[
str, xo.ActorRefType[CollectiveManager]
] = {}

@typing.no_type_check
async def get_cluster_device_info(self, detailed: bool = False) -> List:
@@ -960,26 +964,40 @@ async def launch_builtin_model(
]:
raise ValueError("Tensorizer is not supported for %s." % model_name)

if model_uid is None:
model_uid = self._gen_model_uid(model_name)

# Xavier-related
enable_xavier: bool = (
bool(kwargs.pop("enable_xavier", False))
and model_engine is not None
and model_engine.lower() == "vllm"
)
store_address = None
store_port = None
world_size = None
if enable_xavier:
if replica <= 1:
logger.warning(f"Enabling xavier when `replica<=1` is meaningless.")
enable_xavier = False
else:
from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
from ..model.llm.vllm.xavier.collective_manager import CollectiveManager

self._block_tracker = await xo.create_actor(
self._block_tracker_mapping[model_uid] = await xo.create_actor(
VLLMBlockTracker,
address=self.address,
uid=VLLMBlockTracker.default_uid(),
uid=f"{VLLMBlockTracker.default_uid()}-{model_uid}",
)

if model_uid is None:
model_uid = self._gen_model_uid(model_name)
world_size = replica + 1
logger.info(f"Going to start xavier with world size: {world_size}")
self._collective_manager_mapping[model_uid] = await xo.create_actor(
CollectiveManager,
address=self.address,
uid=f"{CollectiveManager.default_uid()}-{model_uid}",
model_uid=model_uid,
)
logger.info(f"Start collective manager for {model_uid} done.")

model_size = str(model_size_in_billions) if model_size_in_billions else ""
logger.debug(
@@ -988,13 +1006,38 @@
f"kwargs: {kwargs}"
)

async def _launch_one_model(
worker_ref, _replica_model_uid, rank: int, store_port: int
):
async def _launch_one_model(worker_ref, _replica_model_uid, rank: int):
if _replica_model_uid in self._replica_model_uid_to_worker:
raise ValueError(
f"Model is already in the model list, uid: {_replica_model_uid}"
)

nonlocal store_address
nonlocal store_port
xavier_config = (
{
"block_tracker_uid": self._block_tracker_mapping[model_uid].uid,
"block_tracker_address": self._block_tracker_mapping[
model_uid
].address,
"rank": rank,
"world_size": world_size,
"store_address": store_address,
"store_port": store_port,
}
if enable_xavier
else None
)

if enable_xavier and rank == 0:
rank0_address, _port = await worker_ref.launch_rank0_model(
_replica_model_uid, xavier_config
)
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
store_address = rank0_address.split(":")[0]
store_port = _port
return rank0_address

replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx)
nonlocal model_type

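The store_address/store_port captured from the rank-0 launch resemble the host and port of a rendezvous store for collective initialization. As a hedged illustration only (the PR's actual transfer layer is not shown in this diff), this is how such a store is commonly used with torch.distributed:

    import torch.distributed as dist

    def join_collective(host: str, port: int, rank: int, world_size: int) -> None:
        # Rank 0 hosts the TCP store; all other ranks connect to it.
        store = dist.TCPStore(host, port, world_size, is_master=(rank == 0))
        dist.init_process_group("gloo", store=store, rank=rank, world_size=world_size)
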
@@ -1014,37 +1057,36 @@ async def _launch_one_model(
gpu_idx=replica_gpu_idx,
download_hub=download_hub,
model_path=model_path,
xavier_config={
"block_tracker_address": self._block_tracker.address
if self._block_tracker is not None
else None,
"rank": rank,
"world_size": replica,
"store_address": self.address.split(":")[0],
"store_port": store_port,
}
if enable_xavier
else None,
xavier_config=xavier_config,
**kwargs,
)
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
return subpool_address

async def _launch_model():
try:
store_port = xo.utils.get_next_port()
worker_refs = []
rank_addresses = []
for rank, rep_model_uid in enumerate(
for _idx, rep_model_uid in enumerate(
iter_replica_model_uid(model_uid, replica)
):
worker_ref = (
target_ip_worker_ref
if target_ip_worker_ref is not None
else await self._choose_worker()
)
if enable_xavier and _idx == 0:
"""
Start the rank 0 model actor on the worker that holds the rank 1 replica,
solely for constructing the collective communication world.
"""
_uid = model_uid + "-rank0"
rank0_address = await _launch_one_model(worker_ref, _uid, 0)
worker_refs.append((worker_ref, _uid))
rank_addresses.append(rank0_address)

subpool_address = await _launch_one_model(
worker_ref, rep_model_uid, rank, store_port
worker_ref, rep_model_uid, _idx + 1
)
worker_refs.append((worker_ref, rep_model_uid))
rank_addresses.append(subpool_address)
@@ -1054,6 +1096,7 @@ async def _launch_model():
# because the transfer actor needs all the rank addresses used for collective communication
if enable_xavier:
logger.debug(f"Init transfer component for xavier...")
collective_manager_ref = self._collective_manager_mapping[model_uid]
tasks = []
for worker_ref, rep_model_uid in worker_refs:
tasks.append(
Expand All @@ -1064,6 +1107,13 @@ async def _launch_model():
# Here you must use asyncio.gather, not a for loop,
# or you will get stuck.
await asyncio.gather(*tasks)

# init collective_manager
for idx, addr in enumerate(rank_addresses):
await collective_manager_ref.register_rank(
idx, addr, update=False
)

logger.debug(f"Init transfer component for xavier done.")
except Exception:
# terminate_model will remove the replica info.
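
The gather comment above is worth a concrete illustration: if each start_transfer_for_vllm call blocks until all peer ranks are connected, awaiting the calls sequentially deadlocks on the first one. A self-contained sketch of that failure mode (not the PR's actual transfer code):

    import asyncio

    async def start_rank(rank: int, ready: list) -> None:
        ready[rank].set()  # announce this rank is up
        # Barrier-like wait: block until every rank has announced itself.
        await asyncio.gather(*(e.wait() for e in ready))

    async def main() -> None:
        ready = [asyncio.Event() for _ in range(3)]
        # Correct: all ranks start concurrently and the barrier clears.
        await asyncio.gather(*(start_rank(r, ready) for r in range(3)))
        # Deadlock: rank 0 would wait forever for ranks 1 and 2.
        # for r in range(3):
        #     await start_rank(r, ready)

    asyncio.run(main())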
@@ -1193,6 +1243,38 @@ async def _terminate_one_model(_replica_model_uid):
raise
self._model_uid_to_replica_info.pop(model_uid, None)

# clear for xavier
rank0_uid = model_uid + "-rank0"
if rank0_uid in self._replica_model_uid_to_worker:
await _terminate_one_model(rank0_uid)

collective_manager_ref = self._collective_manager_mapping.pop(model_uid, None)
if collective_manager_ref is not None:
try:
await xo.destroy_actor(collective_manager_ref)
except Exception as e:
logger.debug(
"Destroy collective_manager_ref failed, model uid: %s, error: %s",
model_uid,
e,
)
finally:
logger.debug(
f"Destroy collective_manager_ref done. model uid: {model_uid}"
)
block_tracker_ref = self._block_tracker_mapping.pop(model_uid, None)
if block_tracker_ref is not None:
try:
await xo.destroy_actor(block_tracker_ref)
except Exception as e:
logger.debug(
"Destroy block_tracker_ref failed, model uid: %s, error: %s",
model_uid,
e,
)
finally:
logger.debug(f"Destroy block_tracker_ref done. model uid: {model_uid}")

@log_async(logger=logger)
async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
replica_info = self._model_uid_to_replica_info.get(model_uid, None)
@@ -1448,3 +1530,12 @@ def record_metrics(name, op, kwargs):

async def get_progress(self, request_id: str) -> float:
return await self._progress_tracker.get_progress(request_id)

async def call_collective_manager(
self, model_uid: str, func_name: str, *args, **kwargs
):
"""
Used by worker.
"""
collective_manager_ref = self._collective_manager_mapping[model_uid]
await getattr(collective_manager_ref, func_name)(*args, **kwargs)
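
A hedged sketch of how a worker might use this hook, reusing the register_rank method seen earlier in this diff (the function name and argument values below are placeholders):

    async def on_rank_recovered(supervisor_ref, model_uid: str, rank: int, addr: str):
        # Hypothetical worker-side hook: re-register a rank after a replica restarts.
        # register_rank and its update kwarg appear earlier in this diff.
        await supervisor_ref.call_collective_manager(
            model_uid, "register_rank", rank, addr, update=True
        )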