Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] LoRA Support #10957

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

varun-sundar-rabindranath
Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath commented Dec 6, 2024

Changes:

  • Run LoRA requests through V1
    • All LoRA functionality is put in a LoRAGPUModelRunnerMixin class that the GPUModelRunner inherits.
    • Changes to GPUModelRunner for loading lora models and setting active loras before every run.
  • Prefix caching
    • Add lora_id as a key to prefix caching hash.
  • Scheduler:
    • Add code to track Current and Newly added LoRA requests.
  • Detokenizer:
    • Use LoRA tokenizers for LoRA requests.

Benchmarks:
Machine : 1xA100
V1

VLLM_USE_V1="1" python3 benchmarks/benchmark_throughput.py --model  meta-llama/Llama-2-7b-hf --backend vllm   --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 200 --max-loras 4 --max-lora-rank 8  --enable-lora --lora-path "yard1/llama-2-7b-sql-lora-test"

Throughput: 2.42 requests/s, 1225.95 total tokens/s, 628.29 output tokens/s

V0

python3 benchmarks/benchmark_throughput.py --model  meta-llama/Llama-2-7b-hf --backend vllm   --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 200 --max-loras 4 --max-lora-rank 8  --enable-lora --lora-path "yard1/llama-2-7b-sql-lora-test"

Throughput: 5.95 requests/s, 3021.90 total tokens/s, 1548.71 output tokens/s

The performance gap between V0 and V1 is due to CUDA Graphs. Refer to benchmarks in reference PR #11613 .

Copy link

github-actions bot commented Dec 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

tokenizer_name=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=revision)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ywang96 @njhill small refactor to allow for per-request tokenizers.

@@ -602,269 +633,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
if batch_size <= size:
return size
return None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor : Moved CachedRequestState and InputBatch to input_batch.py. It looked like a good refactor to reduce file-size. In this PR it lets both gpu_model_runner.py and lora_model_runner_mixin.py import these datastructures from InputBatch.

max_num_logprobs=self.max_num_logprobs,
)

def make_lora_inputs(self, num_scheduled_tokens: np.array) \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added for LoRA

@varun-sundar-rabindranath varun-sundar-rabindranath changed the title V1 LoRA Support [V1] LoRA Support Dec 6, 2024
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this! Left a few early comments. Will look into more details later.

vllm/v1/engine/processor.py Outdated Show resolved Hide resolved
vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
Comment on lines 175 to 179
if self.lora_config:
requested_loras = \
set(req.lora_request.lora_int_id \
for req in scheduled_running_reqs \
if req.lora_request and \
req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we cache this state and incrementally update it whenever new request joins or finishes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explored this a bit. Tracking the additions and deletions to the running queue in the current code is hard. The updates happen in more than one place (with new requests, finish requests and requests moving between running to preempted state and back). one way is to replace the append/remove/pop with

self.running.<operation>()
if lora_config: 
    update_active_loras()

A better way is to subclass List and after any Create, Update, Delete operation we can update the active LoRAs. This is a considerable change. I believe we can do this after some profiling to see how bad this code is.
For the moment, I think this localized update is nicer as it doesn't introduce a bunch of if self.lora_configs .

Is there a better way I am missing ?

vllm/v1/worker/input_batch.py Outdated Show resolved Hide resolved
vllm/v1/worker/input_batch.py Outdated Show resolved Hide resolved
vllm/v1/worker/input_batch.py Outdated Show resolved Hide resolved
vllm/v1/worker/input_batch.py Outdated Show resolved Hide resolved
Comment on lines 290 to 302
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))

active_lora_ids: set[int] = set(np.unique(req_lora_mapping))
active_lora_requests: set[LoRARequest] = \
set({lr for lr in self.lora_requests \
if lr.lora_int_id in active_lora_ids})
# Update lora requests
self.lora_requests = active_lora_requests

return prompt_lora_mapping, token_lora_mapping, self.lora_requests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this work with tunica kernels?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use the punica SGMV kernel always (as set in

lora_mapping = LoRAMapping(token_lora_mapping,
). Internally the kernels launch a thread-block set for each request separately. So, as long as the prompt_lora_mapping is correct, the kernels work correctly.

The SGMV kernel codepath merges the sequences that have the same lora-id together in

def compute_meta(
. I chose the SGMV kernel so this merging happens wherever possible.

I'll profile with both SGMV and BGMV kernels and choose the best. For now, SGMV looked like a good default/placeholder.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding V0 LoRA,SGMV implements group gemm, which provides better performance for prefill stage . BGMV implements group gemv, which is better optimized for decoding stage . If only one can be chosen, SGMV is likely more suitable.

@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/v1-lora-support-attempt-2 branch from d21df49 to 797dab2 Compare December 17, 2024 03:47
Copy link

mergify bot commented Dec 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link

mergify bot commented Dec 31, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 31, 2024
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/v1-lora-support-attempt-2 branch from 3200ed4 to 48e9185 Compare December 31, 2024 01:53
@mergify mergify bot removed the needs-rebase label Dec 31, 2024
@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review December 31, 2024 01:54
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor : introduce _gather_logits() that LogitsProcessorWithLoRA also uses.

return [request.lora_request.lora_int_id]


def generate_block_hash_extra_keys(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor for using prefix caching with LoRA.

del hidden_states, logits
self.encoder_cache.clear()
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setup num_scheduled_tokens for initializing LoRA for profile_run. @ywang96 will this change interfere with the multi modal setup above ? Can you point me to a test / command that I should confirm that it works ? Thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bump.

I'd like some review on this part please.

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

v1/core LGTM

vllm/v1/core/kv_cache_utils.py Outdated Show resolved Hide resolved
vllm/v1/core/kv_cache_utils.py Outdated Show resolved Hide resolved
vllm/v1/core/kv_cache_utils.py Outdated Show resolved Hide resolved
vllm/v1/core/kv_cache_utils.py Outdated Show resolved Hide resolved
if self.lora_config:
requested_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc, why LoRA ID 0 is reserved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a V0 requirement. LoRA ID 0 is reserved for requests without LoRA.
reference :

def lora_int_id(self) -> int:

it is used as,
def generate_and_test(llm, sql_lora_files):

The requirement is plumbed down to the kernel.

prompt_mapping: List[int] = [
. 0 LoRA ID is translated to -1 so the LoRA kernels ignore them. e.g.
if lora_index == -1:

I believe this is an implementation detail and can be moved down to be fully handled at the Kernel level. However, I am not sure if reserving LoRA ID 0 is a norm with LoRA users.

@jeejeelee any comments ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah now I remembered. This does introduce some confusions to me in the past so it'd be better to change it, but not necessary in this PR.

Copy link

mergify bot commented Jan 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 4, 2025
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/v1-lora-support-attempt-2 branch from d04d56d to 4fc158c Compare January 4, 2025 05:42
@mergify mergify bot removed the needs-rebase label Jan 4, 2025
# test in a package
pass


@pytest.mark.xfail(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Thanks for the call out 👍 . I was hoping to catch these errors when the PR goes /ready.

Comment on lines 167 to 173
def need_extra_keys(request: Request) -> bool:
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")

return bool(mm_positions) or (request.lora_request is not None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just help add the docstring and comments.

def need_extra_keys(request: Request) -> bool:
    """Check whether the blocks allocated to this request need extra hash keys.
    
    Args:
        request: The request.
        
    Returns:
        Whether the blocks allocated to this request need extra hash keys.
    """"
    # LoRA requests need to include LoRA ID.
    if request.lora_request is not None:
        return True
    
    # Requests with MM inputs need to include MM hash.
    return bool(request.mm_positions)

Also I feel the following validation logic should not be in this function.

    if mm_positions and len(mm_positions) != len(mm_hashes):
        raise ValueError(
            "The number of multi-modal positions and hashes must match.")

Maybe putting it to _gen_mm_extra_hash_keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I feel the following validation logic should not be in this function.
Maybe putting it to _gen_mm_extra_hash_keys?

Yes. I realize that check is already in _gen_mm_extra_hash_keys. The one in need_extra_keys was redundant. I have removed the check from need_extra_keys 👍

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/v1-lora-support-attempt-2 branch from b57ca04 to 5fc59ef Compare January 10, 2025 09:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants