[Param] Recheck and update repetition penalty parameter #202
Conversation
vvchernov commented Feb 12, 2024 (edited):
- Recheck that the HF TGI API and the vLLM approach to repetition penalty are the same
- Correct the repetition penalty calculation so that prompt tokens are also taken into account (see the sketch below)
- Clean up code: rename sampling_metadata -> sampling_state everywhere
cc @sunggg
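For context, a sketch of the penalty rule being rechecked here: both HF TGI and vLLM divide positive logits by the penalty and multiply negative ones, for every token already seen. Having the mask cover prompt tokens as well as generated ones is the fix this PR makes. Names below are illustrative, not this repo's API:

```python
import torch

def apply_repetition_penalty(
    logits: torch.Tensor,     # (batch, vocab) raw scores
    seen_mask: torch.Tensor,  # (batch, vocab) bool, True where the token occurred
    penalty: float,           # > 1.0 discourages repetition
) -> torch.Tensor:
    # Divide positive logits by the penalty and multiply negative ones,
    # so a seen token always becomes less likely when penalty > 1.
    penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
    return torch.where(seen_mask, penalized, logits)
```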
sunggg commented:
Thank you for the quick action, @vvchernov! A couple of questions.
In the sampler module:

```diff
@@ -372,20 +382,39 @@ def from_sampling_params(
     )


-def adjust_logits(logits, sampling_metadata, vocab_size):
+def get_bin_counts_and_mask(
```
sunggg commented:
Same comment as above. Would it make sense to move this into SamplingState prep and perform it on the CPU? If the performance impact is not bad, it seems better to prepare it there, since we are looking for a way to make the prep process more asynchronous.
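For context on the helper under discussion: vLLM's function of the same name builds per-request token occurrence counts with a scatter-add and derives a presence mask. A sketch along those lines, not necessarily this PR's exact code:

```python
import torch

def get_bin_counts_and_mask(
    tokens: torch.Tensor,  # (batch, max_len) token ids, padded with vocab_size
    vocab_size: int,
    batch_size: int,
):
    # One extra bin absorbs the padding id so it never pollutes real counts.
    bin_counts = torch.zeros(
        (batch_size, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0  # True where a token appears at least once
    return bin_counts, mask
```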
binarybana commented:
Valery talked about this in the issue; repasting some of it here:

> In my view, the right place is SamplingState, but it is recalculated for every new request. That does not look good, since in general it can be computed once, but redoing this would take significant time. Maybe SamplingParams in Request should be replaced by SamplingState, with some sharing for the parameter n > 1. What do you think?
>
> For now I plan to store it in SamplingParams as a quick fix, but we need a better place.
sunggg commented:
Ah, thanks @binarybana, this is a great point. @vvchernov, can we follow up on this? We can separate what does and does not need to be updated each iteration, as you raised.
vvchernov commented:
Oh, thanks @binarybana for sharing my ideas! I restated them in other words above.

> can we follow up on this? We can separate what does and does not need to be updated each iteration, as you raised.

First of all, this looks like a separate task; if supporting a correct repetition penalty is the high priority, it is better to do that first. I see the larger task as a redesign of the API: it should be done carefully, and we need more detailed discussion. I can prepare a separate draft PR with some rough code that implements my idea, and we can discuss there how it should look. I am also still mindful of logits_processor. Right now it looks like it was added for JSON mode, but in the general view it does the same thing as the sampler. It would be better to rethink that as part of the redesign too.
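A rough illustration of the split being proposed: per-request data built once versus state refreshed each decode step. Class and field names here are hypothetical, not the repo's actual API:

```python
from dataclasses import dataclass, field
from typing import List

import torch

@dataclass
class RequestSamplingData:
    """Hypothetical: built once when the request arrives."""
    repetition_penalty: float
    prompt_mask: torch.Tensor  # (vocab,) bool, precomputed from the prompt

@dataclass
class StepSamplingState:
    """Hypothetical: refreshed every decode iteration."""
    output_token_ids: List[int] = field(default_factory=list)

    def seen_mask(self, request: RequestSamplingData, vocab_size: int) -> torch.Tensor:
        # Only the output side changes per step; the prompt mask is reused.
        out_mask = torch.zeros(vocab_size, dtype=torch.bool)
        if self.output_token_ids:
            out_mask[torch.tensor(self.output_token_ids)] = True
        return request.prompt_mask | out_mask
```

With a split like this, only the output-token mask needs recomputing per iteration, which is the separation raised above.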
@vvchernov, thanks for the quick action on this one. Can you also measure the performance impact with and without this enabled? Ideally for sequences of ~2k input tokens and 50 output tokens under fairly loaded conditions (5-15 VUs), to match customer workloads.
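Before a full load test, a micro-benchmark of just the penalty arithmetic can give a first signal. A hedged sketch assuming CUDA and illustrative shapes; the 5-15 VU scenario above would still need a proper load-testing harness:

```python
import time

import torch

batch, vocab = 16, 32000
logits = torch.randn(batch, vocab, device="cuda")
# Pretend roughly 2k distinct vocab entries were seen per request (~6%).
seen = torch.rand(batch, vocab, device="cuda") < 0.06

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(1000):
    out = torch.where(logits > 0, logits / 1.1, logits * 1.1)
    out = torch.where(seen, out, logits)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"{elapsed / 1000 * 1e6:.1f} us per call")
```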
In test_sampler.py:

```diff
     expected = torch.clone(logits)
     for batch_size in range(batch_size):
         logit_bias = list_logit_bias[batch_size]
         for idx, val in logit_bias.items():
             expected[batch_size][idx - 1] += val
-    new_logits = adjust_logits(logits, sampling_metadata, vocab_size)
+    new_logits = adjust_logits(logits, sampling_state, vocab_size)
     assert torch.allclose(expected, new_logits)


 def _test_penalties_checker():
```
sunggg commented:
Can we also add a test case?
vvchernov commented:
My colleague is extending the tests for sampling parameters, including repetition_penalty. See #200.
sunggg commented:
It is okay for it to be very basic, so please add one. We have not been doing this so far, but I'd like to recommend that every PR include unit tests that validate at least the basic functionality. With the SLM migration, we will set up CI as well.
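A very basic test of the kind requested could assert just the direction of the effect. A sketch with the penalty rule inlined; a real test would call the repo's adjust_logits with a SamplingState instead:

```python
import torch

def test_repetition_penalty_lowers_seen_token_logits():
    vocab_size = 8
    logits = torch.full((1, vocab_size), 2.0)
    seen = torch.zeros(1, vocab_size, dtype=torch.bool)
    seen[0, 3] = True  # pretend token 3 appeared in the prompt or output

    penalty = 1.5  # inlined penalty rule for illustration
    out = torch.where(logits > 0, logits / penalty, logits * penalty)
    out = torch.where(seen, out, logits)

    assert out[0, 3] < logits[0, 3]                # the seen token is penalized
    assert torch.equal(out[0, :3], logits[0, :3])  # unseen tokens are untouched
```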
Force-pushed from 4c554a0 to f439b97.
Also in test_sampler.py:

```diff
         sampling_params,
         list_past_output_tokens=past_output_tokens,
         dtype=dtype,
         dev=dev,
         vocab_size=vocab_size,
     )
     torch.cuda.current_stream().wait_stream(_copy_stream)
-    return sampling_metadata
+    return sampling_state


 def _test_temperature(temp=0, batch_size=1):
```
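For readers unfamiliar with the wait_stream pattern in the hunk above: tensors are prepared on a side stream so host-to-device copies can overlap with compute, and the main stream then waits on it. A generic sketch of the idiom, not this repo's code:

```python
import torch

_copy_stream = torch.cuda.Stream()

def copy_to_gpu_async(host_tensor: torch.Tensor) -> torch.Tensor:
    # host_tensor should live in pinned memory for the copy to be truly async.
    with torch.cuda.stream(_copy_stream):
        dev_tensor = host_tensor.to("cuda", non_blocking=True)
    # Work queued on the current stream after this point waits for the copy.
    torch.cuda.current_stream().wait_stream(_copy_stream)
    return dev_tensor
```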
vvchernov commented:
@sunggg, why do we have the _ prefix on these tests? I added that prefix to functions that cannot be run with pytest, but the functions in test_sampler.py do not require a compiled model and can be called under pytest. If there is no special reason, I propose removing the underscore.
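For reference, pytest's default collection pattern only picks up functions whose names start with test_, which is why the underscore prefix hides these. A minimal illustration:

```python
# test_naming_demo.py
def _test_hidden():    # NOT collected: doesn't match pytest's default "test_*" pattern
    raise AssertionError("pytest never runs this")

def test_collected():  # collected and run by a plain `pytest` invocation
    assert True
```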
vvchernov commented:
We can ask Iliya to do it in his PR (#200), since he is extending these tests now.
sunggg commented:
No specific reason; I just followed the other test cases. I'm okay with removing it for pytest.

Based on the offline discussion, we decided to merge this given the urgency. @vvchernov, let's add the test case in #200.