-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1] LoRA Support #10957
base: main
Are you sure you want to change the base?
[V1] LoRA Support #10957
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
tokenizer_name=tokenizer_name, | ||
tokenizer_mode=tokenizer_mode, | ||
trust_remote_code=trust_remote_code, | ||
revision=revision) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vllm/v1/worker/gpu_model_runner.py
Outdated
@@ -602,269 +633,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: | |||
if batch_size <= size: | |||
return size | |||
return None | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactor : Moved CachedRequestState and InputBatch to input_batch.py. It looked like a good refactor to reduce file-size. In this PR it lets both gpu_model_runner.py
and lora_model_runner_mixin.py
import these datastructures from InputBatch.
vllm/v1/worker/input_batch.py
Outdated
max_num_logprobs=self.max_num_logprobs, | ||
) | ||
|
||
def make_lora_inputs(self, num_scheduled_tokens: np.array) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added for LoRA
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for doing this! Left a few early comments. Will look into more details later.
vllm/v1/core/scheduler.py
Outdated
if self.lora_config: | ||
requested_loras = \ | ||
set(req.lora_request.lora_int_id \ | ||
for req in scheduled_running_reqs \ | ||
if req.lora_request and \ | ||
req.lora_request.lora_int_id > 0) | ||
assert len(requested_loras) <= self.lora_config.max_loras |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we cache this state and incrementally update it whenever new request joins or finishes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I explored this a bit. Tracking the additions and deletions to the running queue in the current code is hard. The updates happen in more than one place (with new requests, finish requests and requests moving between running to preempted state and back). one way is to replace the append/remove/pop with
self.running.<operation>()
if lora_config:
update_active_loras()
A better way is to subclass List and after any Create, Update, Delete operation we can update the active LoRAs. This is a considerable change. I believe we can do this after some profiling to see how bad this code is.
For the moment, I think this localized update is nicer as it doesn't introduce a bunch of if self.lora_config
s .
Is there a better way I am missing ?
vllm/v1/worker/input_batch.py
Outdated
req_lora_mapping = self.request_lora_mapping[:self.num_reqs] | ||
prompt_lora_mapping = tuple(req_lora_mapping) | ||
token_lora_mapping = tuple( | ||
req_lora_mapping.repeat(num_scheduled_tokens)) | ||
|
||
active_lora_ids: set[int] = set(np.unique(req_lora_mapping)) | ||
active_lora_requests: set[LoRARequest] = \ | ||
set({lr for lr in self.lora_requests \ | ||
if lr.lora_int_id in active_lora_ids}) | ||
# Update lora requests | ||
self.lora_requests = active_lora_requests | ||
|
||
return prompt_lora_mapping, token_lora_mapping, self.lora_requests |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this work with tunica kernels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use the punica SGMV kernel always (as set in
lora_mapping = LoRAMapping(token_lora_mapping, |
The SGMV kernel codepath merges the sequences that have the same lora-id together in
Line 28 in 7406274
def compute_meta( |
I'll profile with both SGMV and BGMV kernels and choose the best. For now, SGMV looked like a good default/placeholder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding V0 LoRA,SGMV implements group gemm, which provides better performance for prefill stage . BGMV implements group gemv, which is better optimized for decoding stage . If only one can be chosen, SGMV is likely more suitable.
d21df49
to
797dab2
Compare
This pull request has merge conflicts that must be resolved before it can be |
d4d70cc
to
550da53
Compare
51ef92a
to
3200ed4
Compare
This pull request has merge conflicts that must be resolved before it can be |
3200ed4
to
48e9185
Compare
logits = lm_head.linear_method.apply(lm_head, | ||
hidden_states, | ||
bias=embedding_bias) | ||
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactor : introduce _gather_logits()
that LogitsProcessorWithLoRA
also uses.
return [request.lora_request.lora_int_id] | ||
|
||
|
||
def generate_block_hash_extra_keys( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactor for using prefix caching with LoRA.
del hidden_states, logits | ||
self.encoder_cache.clear() | ||
# For profile, have maximum num_reqs and that collectively have | ||
# maximum num_tokens. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Setup num_scheduled_tokens
for initializing LoRA for profile_run. @ywang96 will this change interfere with the multi modal setup above ? Can you point me to a test / command that I should confirm that it works ? Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bump.
I'd like some review on this part please.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
v1/core LGTM
if self.lora_config: | ||
requested_loras = set( | ||
req.lora_request.lora_int_id for req in scheduled_running_reqs | ||
if req.lora_request and req.lora_request.lora_int_id > 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ooc, why LoRA ID 0 is reserved?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a V0 requirement. LoRA ID 0 is reserved for requests without LoRA.
reference :
Line 469 in 23c1b10
def lora_int_id(self) -> int: |
it is used as,
vllm/tests/lora/test_llama_tp.py
Line 58 in 23c1b10
def generate_and_test(llm, sql_lora_files): |
The requirement is plumbed down to the kernel.
vllm/vllm/lora/punica_wrapper/utils.py
Line 93 in 23c1b10
prompt_mapping: List[int] = [ |
vllm/vllm/lora/ops/sgmv_shrink.py
Line 55 in 23c1b10
if lora_index == -1: |
I believe this is an implementation detail and can be moved down to be fully handled at the Kernel level. However, I am not sure if reserving LoRA ID 0 is a norm with LoRA users.
@jeejeelee any comments ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yeah now I remembered. This does introduce some confusions to me in the past so it'd be better to change it, but not necessary in this PR.
This pull request has merge conflicts that must be resolved before it can be |
d04d56d
to
4fc158c
Compare
tests/lora/test_minicpmv.py
Outdated
# test in a package | ||
pass | ||
|
||
|
||
@pytest.mark.xfail( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minicpmv does not support v1 yet, see:https://docs.vllm.ai/en/latest/models/supported_models.html#id3
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Thanks for the call out 👍 . I was hoping to catch these errors when the PR goes /ready
.
vllm/v1/core/kv_cache_utils.py
Outdated
def need_extra_keys(request: Request) -> bool: | ||
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes | ||
if mm_positions and len(mm_positions) != len(mm_hashes): | ||
raise ValueError( | ||
"The number of multi-modal positions and hashes must match.") | ||
|
||
return bool(mm_positions) or (request.lora_request is not None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just help add the docstring and comments.
def need_extra_keys(request: Request) -> bool:
"""Check whether the blocks allocated to this request need extra hash keys.
Args:
request: The request.
Returns:
Whether the blocks allocated to this request need extra hash keys.
""""
# LoRA requests need to include LoRA ID.
if request.lora_request is not None:
return True
# Requests with MM inputs need to include MM hash.
return bool(request.mm_positions)
Also I feel the following validation logic should not be in this function.
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
Maybe putting it to _gen_mm_extra_hash_keys
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I feel the following validation logic should not be in this function.
Maybe putting it to _gen_mm_extra_hash_keys?
Yes. I realize that check is already in _gen_mm_extra_hash_keys
. The one in need_extra_keys
was redundant. I have removed the check from need_extra_keys
👍
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
b57ca04
to
5fc59ef
Compare
Changes:
Benchmarks:
Machine : 1xA100
V1
Throughput: 2.42 requests/s, 1225.95 total tokens/s, 628.29 output tokens/s
V0
Throughput: 5.95 requests/s, 3021.90 total tokens/s, 1548.71 output tokens/s
The performance gap between V0 and V1 is due to CUDA Graphs. Refer to benchmarks in reference PR #11613 .