Add lora support
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Varun Sundar Rabindranath committed Jan 4, 2025
1 parent e5d7ed0 commit 4fc158c
Showing 18 changed files with 451 additions and 62 deletions.
17 changes: 17 additions & 0 deletions tests/lora/conftest.py
@@ -298,3 +298,20 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[False, True])
def run_with_both_engines_lora(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
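The fixture above is parameterized over both engine modes, so any LoRA test module that pulls it in runs each test twice, once with vllm.envs.VLLM_USE_V1 patched to True and once to False, while tests carrying a skip_v1 marker only run the non-V1 case. A rough usage sketch (hypothetical test name; assumes the skip_v1 marker is registered in the project's pytest configuration):

import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Opt this module into the dual-engine parameterization.
    pass


@pytest.mark.skip_v1
def test_v0_only_feature():
    # Hypothetical test: runs once with VLLM_USE_V1 patched to False;
    # the V1 run is skipped inside run_with_both_engines_lora.
    pass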
8 changes: 8 additions & 0 deletions tests/lora/test_baichuan.py
@@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
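The autouse v1 wrapper above is repeated verbatim in each LoRA test module touched by this commit. As its comment notes, it could instead be promoted to tests/lora/conftest.py; a minimal sketch of what that promotion might look like (hypothetical, not part of this commit):

# tests/lora/conftest.py (hypothetical promotion of the per-module wrapper)
import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Runs every test in the package under both the V0 and V1 engines.
    pass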
10 changes: 10 additions & 0 deletions tests/lora/test_chatglm3_tp.py
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
@@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
8 changes: 8 additions & 0 deletions tests/lora/test_gemma.py
@@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
9 changes: 9 additions & 0 deletions tests/lora/test_llama_tp.py
@@ -1,5 +1,6 @@
from typing import List

import pytest
import ray

import vllm
@@ -71,6 +72,14 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

8 changes: 8 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
@@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
8 changes: 8 additions & 0 deletions tests/lora/test_minicpmv.py
@@ -56,6 +56,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
10 changes: 10 additions & 0 deletions tests/lora/test_phi.py
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_phi2_lora(phi2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
8 changes: 8 additions & 0 deletions tests/lora/test_quant_model.py
@@ -68,6 +68,14 @@ def format_prompt_tuples(prompt):
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
8 changes: 5 additions & 3 deletions vllm/lora/layers.py
@@ -14,8 +14,7 @@
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -1028,7 +1027,10 @@ def _get_logits(
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)

# Gather logits for TP
logits = self.base_layer._gather_logits(logits)

if logits is None:
return None

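In vllm/lora/layers.py, the LoRA logits-processor wrapper stops calling tensor_model_parallel_gather itself and defers to the wrapped base layer, so the gather/all-gather decision lives in exactly one place (see the _gather_logits hunk in the next file). A stripped-down sketch of the delegation pattern (abbreviated names, not verbatim vLLM code):

from typing import Optional

import torch


class BaseLogits:
    """Stand-in for LogitsProcessor; owns the gather/all-gather choice."""

    def _gather_logits(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        # Real code calls tensor_model_parallel_gather or
        # tensor_model_parallel_all_gather here.
        return logits


class LoRAWrapper:
    """Stand-in for the LoRA logits-processor wrapper."""

    def __init__(self, base_layer: BaseLogits) -> None:
        self.base_layer = base_layer

    def _get_logits(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        # Defer TP gathering to the base layer instead of duplicating it.
        return self.base_layer._gather_logits(logits)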
29 changes: 18 additions & 11 deletions vllm/model_executor/layers/logits_processor.py
@@ -43,7 +43,6 @@ def __init__(self,
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.

self.use_gather = not current_platform.is_tpu(
) and not envs.VLLM_USE_V1

@@ -78,16 +77,8 @@ def forward(

return logits

def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""gather/all-gather the logits tensor across model parallel group."""
if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
@@ -98,6 +89,22 @@ def _get_logits(
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
return logits

def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)

# Gather logits for TP
logits = self._gather_logits(logits)

# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
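The refactor above moves the TP gather out of _get_logits into _gather_logits: it uses tensor_model_parallel_gather (only the driver rank keeps the full logits) when use_gather is set, and tensor_model_parallel_all_gather (every rank keeps the full logits) on TPU and under V1, where use_gather is False. A toy, single-process illustration of the difference between the two collectives (not vLLM code; the shards stand in for per-rank vocab partitions):

from typing import List

import torch


def gather_vs_all_gather_demo(per_rank_logits: List[torch.Tensor]):
    """Toy model of the two collectives on per-rank vocab shards."""
    full = torch.cat(per_rank_logits, dim=-1)
    # gather: only rank 0 (the driver) receives the concatenated logits;
    # every other rank gets None, as noted in the diff above.
    gathered = [full if rank == 0 else None
                for rank in range(len(per_rank_logits))]
    # all_gather: every rank receives the full logits, which TPU's strict
    # SPMD execution and the V1 engine require.
    all_gathered = [full for _ in per_rank_logits]
    return gathered, all_gathered


# Example: tensor-parallel size 2, each rank holding half of an 8-logit vocab.
shards = [torch.randn(1, 4), torch.randn(1, 4)]
gathered, all_gathered = gather_vs_all_gather_demo(shards)
assert gathered[1] is None and all_gathered[1].shape == (1, 8)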
90 changes: 68 additions & 22 deletions vllm/v1/core/kv_cache_utils.py
@@ -164,14 +164,22 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:
return ret


def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
def need_extra_keys(request: Request) -> bool:
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")

return bool(mm_positions) or (request.lora_request is not None)


def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
end_token_idx: int,
start_mm_idx: int) -> Tuple[List[Any], int]:
"""Generate extra keys related to MultiModal request for block hash
computation. For multi-modal inputs, the extra keys are
(mm_hash, start_offset) that indicate a mm input contained in the
block and its starting offset in the block tokens.
Args:
request: The request object.
@@ -182,10 +190,11 @@ def generate_block_hash_extra_keys(
Returns:
A tuple of extra keys and the next multi-modal index.
"""
extra_keys: List[Any] = []

mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return None, start_mm_idx
return extra_keys, start_mm_idx

if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
@@ -198,14 +207,13 @@ def generate_block_hash_extra_keys(
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
return None, start_mm_idx
return extra_keys, start_mm_idx

# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx

extra_keys = []
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
@@ -231,7 +239,50 @@ def generate_block_hash_extra_keys(
else:
# This block has not reached the current mm input.
break
return tuple(extra_keys), curr_mm_idx
return extra_keys, curr_mm_idx


def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
"""Generate extra keys related to LoRA for block hash computation.
Args:
request: The request object.
Returns:
Return LoRA id of the request if it is a LoRA request. Return empty
list otherwise.
"""
if not request.lora_request:
return []
return [request.lora_request.lora_int_id]


def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_extra_keys: List[Any]
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx)
lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request)

extra_keys: List[Any] = lora_extra_keys + mm_extra_keys

if not extra_keys:
return None, new_start_mm_idx

return tuple(extra_keys), new_start_mm_idx


def hash_block_tokens(
@@ -274,14 +325,9 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
"""
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")

# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
req_need_extra_keys = need_extra_keys(request)
req_extra_keys = None
curr_mm_idx = 0

ret = []
@@ -294,12 +340,12 @@ def hash_request_tokens(block_size: int,
break

# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
if req_need_extra_keys:
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)

block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids, extra_keys)
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
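With these changes, a request's extra hash keys are its LoRA ID (if any) followed by the multi-modal keys, and hash_request_tokens folds them into each block's hash. The practical effect is that prefix-cache blocks computed under one LoRA adapter are never reused for a different adapter or for the base model, even when the token prefixes are identical. A standalone toy illustration of that property (uses Python's built-in hash in place of vLLM's hash_block_tokens):

from typing import Optional, Tuple


def toy_block_hash(parent_hash: Optional[int],
                   block_tokens: Tuple[int, ...],
                   lora_int_id: Optional[int]) -> int:
    # Mirrors the shape of hash_block_tokens(parent, tokens, extra_keys):
    # when the request carries a LoRA adapter, its integer ID is folded
    # into the extra keys that feed the hash.
    extra_keys = (lora_int_id, ) if lora_int_id is not None else None
    return hash((parent_hash, block_tokens, extra_keys))


tokens = (101, 7, 7, 42)
print(toy_block_hash(None, tokens, lora_int_id=None))  # base model
print(toy_block_hash(None, tokens, lora_int_id=1))     # adapter 1
print(toy_block_hash(None, tokens, lora_int_id=2))     # adapter 2
# Same tokens, different adapters -> different hashes -> separate cache blocks.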
