Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Param] Recheck and update repetition penalty parameter #202

Merged
merged 9 commits into from
Feb 14, 2024

Conversation

vvchernov
Copy link

@vvchernov vvchernov commented Feb 12, 2024

  • Recheck that HF TGI API and vLLM approach for repetition penalty is the same
  • Correct calculation of repetition penalty, add prompt tokens for it
  • Clean code: rename sampling_metadata -> sampling_state anywhere

@vvchernov vvchernov marked this pull request as draft February 12, 2024 07:43
@vvchernov vvchernov marked this pull request as ready for review February 12, 2024 17:31
@vvchernov
Copy link
Author

cc @sunggg

Copy link
Member

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the quick action, @vvchernov! A couple questions.

serve/mlc_serve/engine/engine_common.py Outdated Show resolved Hide resolved
serve/mlc_serve/engine/engine_common.py Show resolved Hide resolved
@@ -372,20 +382,39 @@ def from_sampling_params(
)


def adjust_logits(logits, sampling_metadata, vocab_size):
def get_bin_counts_and_mask(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment with above. Would it make sense to move this to SamplingState prep and perform this on CPU? If the perf impact is not bad, it seems better to prepare there as we are looking for an option to make the prep process more async.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valery talked about this in the issue, repasting some of it here:

In my point of view the good place is SamplingState, but it calculates each time for new request. It looks like it is not good thing due to in general it can be done once, but redo this requires much time. May be SamplingParams in Request should be replaced by SamplingState with some sharing for parameter n > 1.
What do you think about it?
Just now I plan to save it in SamplingParams for quick fix, but we need a better place

Copy link
Member

@sunggg sunggg Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks @binarybana and this is great point. @vvchernov, can we follow-up about this? We can separate what needs/does not need to be updated each iteration like you raised.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, thanks @binarybana to share my ideas! I repeat it again in other words above.

can we follow-up about this? We can separate what needs/does not need to be updated each iteration like you raised.

First of all it looks like a separate task, if we need to support correct repetition penalty in high priority it is better to do it first. I see the task is redesign of API, it should be done carefully and we need more discussions in details. I can prepare separate draft of PR with some dirty code which implements my idea and we will discuss there how it should be. And I'm still aware about logits_processor. Now it looks like it was done for json mode, but in general view it does the same as sampler. It is better also rethink this moment in our redesign

serve/mlc_serve/model/sampler.py Show resolved Hide resolved
@binarybana
Copy link
Member

@vvchernov , thanks for the quick action on this one. Can you also see what the performance impact is with and without this enabled? Ideally for sequences of ~2k input, 50 tokens under fairly loaded (5-15 VUs) output to match customers.


expected = torch.clone(logits)
for batch_size in range(batch_size):
logit_bias = list_logit_bias[batch_size]
for idx, val in logit_bias.items():
expected[batch_size][idx - 1] += val
new_logits = adjust_logits(logits, sampling_metadata, vocab_size)
new_logits = adjust_logits(logits, sampling_state, vocab_size)
assert torch.allclose(expected, new_logits)


def _test_penalties_checker():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add the testcase?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My colleague extends tests for sampling parameters, particularly for repetition_penalty. See #200

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is okay to be very basic, so please add one. We've not been doing this so far but I'd like to recommend every PR to have unittests that validates the very basic functionality at least. With SLM migration, we will install the CI as well.

@vvchernov vvchernov force-pushed the vc/repetition_penalty branch from 4c554a0 to f439b97 Compare February 14, 2024 09:21
sampling_params,
list_past_output_tokens=past_output_tokens,
dtype=dtype,
dev=dev,
vocab_size=vocab_size,
)
torch.cuda.current_stream().wait_stream(_copy_stream)
return sampling_metadata
return sampling_state


def _test_temperature(temp=0, batch_size=1):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sunggg why do we have _ symbol before test? I added such symbol to functions which cannot be run using pytests, but functions from test_sampler.py does not require any compiled model and can be called under pytest

If there is no special reason, I propose to remove underscore symbol

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can ask Iliya to do it in his PR (#200) due to he extends these tests now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No specific reason, I just followed the other testcase. I'm okay to remove this for pytest.

@binarybana binarybana merged commit 2ebbacf into octoml:batch-serving Feb 14, 2024
1 check passed
@sunggg
Copy link
Member

sunggg commented Feb 14, 2024

Based on the offline discussion, we decided to merge this for the sense of urgency. @vvchernov, let's add the test case in #200.

@vvchernov vvchernov deleted the vc/repetition_penalty branch February 16, 2024 09:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants