From e53ba4f1c7a89fcaa1ed041e090f5ec3d3f8dcc9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Feb 2024 17:21:23 +0400 Subject: [PATCH 1/9] sampling_metadata -> sampling_state --- serve/mlc_serve/model/model_common.py | 12 ++-- serve/mlc_serve/model/sampler.py | 33 ++++++----- serve/mlc_serve/model/tvm_model.py | 8 +-- serve/tests/unittest/test_sampler.py | 84 +++++++++++++-------------- 4 files changed, 70 insertions(+), 67 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9ccea2ba1d..54414f45da 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -83,7 +83,7 @@ def sample_from_logits( logits: Union[tvm.nd.NDArray, torch.Tensor], sequence_ids: List[SequenceId], requests: Sequence[RequestType], - sampling_metadata: SamplingState, + sampling_state: SamplingState, vocab_size: int, copy_stream: torch.cuda.Stream, torch_dtype: torch.dtype, @@ -110,13 +110,13 @@ def sample_from_logits( sequence_id, cs_input_ids, logits[i] ) - logits = adjust_logits(logits, sampling_metadata, vocab_size) + logits = adjust_logits(logits, sampling_state, vocab_size) outputs: List[TextGenerationResult] = [] try: sampling_output: Optional[SamplingOutput] = sample( logits, - sampling_metadata, + sampling_state, ) for i, (new_token, logprob_info) in enumerate( @@ -142,13 +142,13 @@ def sample_from_logits( for i in range(batch_size): sequence_id = sequence_ids[i] logits_per_token = logits[i] - sampling_param = sampling_metadata.sampling_params[i] + sampling_param = sampling_state.sampling_params[i] past_decode_tokens_per_request = past_decode_tokens[i] # NOTE: Rerun the preparation for simplicity. # Assume this code path is taken rarely and the recomputation overhead is # marginal. with torch.cuda.stream(copy_stream): - new_sampling_metadata = SamplingState.from_sampling_params( + new_sampling_state = SamplingState.from_sampling_params( [sampling_param], [past_decode_tokens_per_request], torch_dtype, @@ -158,7 +158,7 @@ def sample_from_logits( torch.cuda.current_stream().wait_stream(copy_stream) maybe_sampling_output: Optional[SamplingOutput] = sample( torch.unsqueeze(logits_per_token, 0), - new_sampling_metadata, + new_sampling_state, check_safety=True, ) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 0562121b8f..36516e48b7 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -372,7 +372,10 @@ def from_sampling_params( ) -def adjust_logits(logits, sampling_metadata, vocab_size): +def adjust_logits( + logits: torch.Tensor, + sampling_state: SamplingState, + vocab_size: int): batch_size = logits.shape[0] ( apply_top_p_top_k, @@ -380,10 +383,10 @@ def adjust_logits(logits, sampling_metadata, vocab_size): apply_bias, sampling_tensors, ) = ( - sampling_metadata.apply_top_p_top_k, - sampling_metadata.apply_penalty, - sampling_metadata.apply_bias, - sampling_metadata.sampling_tensors, + sampling_state.apply_top_p_top_k, + sampling_state.apply_penalty, + sampling_state.apply_bias, + sampling_state.sampling_tensors, ) ( temp_t, @@ -447,7 +450,7 @@ class SamplingOutput: def sample( logits: torch.Tensor, - sampling_metadata: SamplingState, + sampling_state: SamplingState, check_safety: bool = False, ) -> SamplingOutput: def _is_safe_to_sample(prob_like): @@ -457,7 +460,7 @@ def _is_safe_to_sample(prob_like): ) res_greedy, res_random = None, None - sampling_tensors = sampling_metadata.sampling_tensors + sampling_tensors = sampling_state.sampling_tensors batch_size = logits.shape[0] mask_greedy_t, mask_random_t = ( @@ -466,13 +469,13 @@ def _is_safe_to_sample(prob_like): ) next_tokens = np.empty((batch_size,), dtype=np.int64) - if sampling_metadata.has_greedy: + if sampling_state.has_greedy: res_greedy = torch.argmax(logits[mask_greedy_t], -1) np_mask_greedy = mask_greedy_t.cpu().numpy() next_tokens[np_mask_greedy] = res_greedy.cpu().numpy() probs_random = None - if sampling_metadata.has_random: + if sampling_state.has_random: probs_random = torch.softmax(logits[mask_random_t], dim=-1) if check_safety and not _is_safe_to_sample(probs_random): return None @@ -481,9 +484,9 @@ def _is_safe_to_sample(prob_like): next_tokens[np_mask_random] = res_random.cpu().numpy() logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size - if sampling_metadata.has_logprob: + if sampling_state.has_logprob: # If everything is random sampling, save one extra softmax - if not sampling_metadata.has_greedy: + if not sampling_state.has_greedy: assert probs_random is not None logprobs = torch.log(probs_random) else: @@ -494,13 +497,13 @@ def _is_safe_to_sample(prob_like): all_top_logprobs, all_top_tokens = torch.topk( extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True ) - mask = sampling_metadata.sampling_tensors.mask_top_logprob + mask = sampling_state.sampling_tensors.mask_top_logprob top_tokens = all_top_tokens[mask] top_logprobs = all_top_logprobs[mask] - for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices): + for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices): next_token = next_tokens[batch_idx] - assert sampling_metadata.sampling_params[batch_idx].logprobs - top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs + assert sampling_state.sampling_params[batch_idx].logprobs + top_k = sampling_state.sampling_params[batch_idx].top_logprobs logprob_infos[batch_idx] = RawLogprobsInfo( current_token_id=next_token, current_logprob=logprobs[batch_idx][next_token], diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 98a52ecbf8..f20370275b 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -259,7 +259,7 @@ def generate_multi_query( # Prepare sampling tensors in another stream to overlap # CPU<->GPU data transfer with GPU computation in forward pass. with torch.cuda.stream(self._copy_stream): - sampling_metadata = SamplingState.from_sampling_params( + sampling_state = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, self.torch_dtype, @@ -318,7 +318,7 @@ def generate_multi_query( last_query_logits, sequence_ids, requests, - sampling_metadata, + sampling_state, self.vocab_size, self._copy_stream, self.torch_dtype, @@ -381,7 +381,7 @@ def generate( # Prepare sampling tensors in another stream to overlap # CPU<->GPU data transfer with GPU computation in forward pass. with torch.cuda.stream(self._copy_stream): - sampling_metadata = SamplingState.from_sampling_params( + sampling_state = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, self.torch_dtype, @@ -502,7 +502,7 @@ def generate( logits, sequence_ids, requests, - sampling_metadata, + sampling_state, self.vocab_size, self._copy_stream, self.torch_dtype, diff --git a/serve/tests/unittest/test_sampler.py b/serve/tests/unittest/test_sampler.py index b526a2fa39..5017220617 100644 --- a/serve/tests/unittest/test_sampler.py +++ b/serve/tests/unittest/test_sampler.py @@ -8,13 +8,13 @@ vocab_size = 32000 -def get_sampling_metadata(sampling_params, past_output_tokens=None): +def get_sampling_state(sampling_params, past_output_tokens=None): batch_size = len(sampling_params) if past_output_tokens is None: past_output_tokens = [[] for _ in range(batch_size)] _copy_stream: torch.cuda.Stream = torch.cuda.Stream() with torch.cuda.stream(_copy_stream): - sampling_metadata = SamplingState.from_sampling_params( + sampling_state = SamplingState.from_sampling_params( sampling_params, list_past_output_tokens=past_output_tokens, dtype=dtype, @@ -22,7 +22,7 @@ def get_sampling_metadata(sampling_params, past_output_tokens=None): vocab_size=vocab_size, ) torch.cuda.current_stream().wait_stream(_copy_stream) - return sampling_metadata + return sampling_state def _test_temperature(temp=0, batch_size=1): @@ -32,10 +32,10 @@ def _test_temperature(temp=0, batch_size=1): temperature=temp, ) - sampling_metadata = get_sampling_metadata([sampling_param]) + sampling_state = get_sampling_state([sampling_param]) expected = logits / temp if abs(temp) > SAMPLING_EPS else logits - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) @@ -44,36 +44,36 @@ def _test_logit_bias_checker(): with pytest.raises(ValueError): logit_bias = {1: 2, 3: 105, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {1: 99, 3: -101, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) logit_bias = {1: 100, 3: -100, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) # TODO(@team): it seems like the valid range is [1,vocab_size]. Double check. logit_bias = {1: 10, 3: -10, vocab_size: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {0: 10, 3: -10} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {1: 10, 3: -10, vocab_size + 100: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {1: 10, -1: -10} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) def _test_logit_bias(): @@ -83,12 +83,12 @@ def _test_logit_bias(): logits = torch.rand(shape, dtype=dtype, device=dev) logit_bias = {1: -1, 3: 1, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - sampling_metadata = get_sampling_metadata([sampling_param]) + sampling_state = get_sampling_state([sampling_param]) expected = torch.clone(logits) for idx, val in logit_bias.items(): expected[0][idx - 1] += val - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) # test multi-batch @@ -99,39 +99,39 @@ def _test_logit_bias(): sampling_params = [ SamplingParams(logit_bias=logit_bias) for logit_bias in list_logit_bias ] - sampling_metadata = get_sampling_metadata(sampling_params) + sampling_state = get_sampling_state(sampling_params) expected = torch.clone(logits) for batch_size in range(batch_size): logit_bias = list_logit_bias[batch_size] for idx, val in logit_bias.items(): expected[batch_size][idx - 1] += val - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) def _test_penalties_checker(): - get_sampling_metadata([SamplingParams(presence_penalty=-1.0)]) - get_sampling_metadata([SamplingParams(frequency_penalty=-1.0)]) - get_sampling_metadata([SamplingParams(repetition_penalty=0.7)]) + get_sampling_state([SamplingParams(presence_penalty=-1.0)]) + get_sampling_state([SamplingParams(frequency_penalty=-1.0)]) + get_sampling_state([SamplingParams(repetition_penalty=0.7)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(presence_penalty=-2.1)]) + get_sampling_state([SamplingParams(presence_penalty=-2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(frequency_penalty=-2.1)]) + get_sampling_state([SamplingParams(frequency_penalty=-2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(repetition_penalty=-2.1)]) + get_sampling_state([SamplingParams(repetition_penalty=-2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(presence_penalty=2.1)]) + get_sampling_state([SamplingParams(presence_penalty=2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(frequency_penalty=2.1)]) + get_sampling_state([SamplingParams(frequency_penalty=2.1)]) with pytest.raises(ValueError): - get_sampling_metadata( + get_sampling_state( [ SamplingParams(frequency_penalty=1.1), SamplingParams(repetition_penalty=2.1), @@ -187,10 +187,10 @@ def get_expected_result( frequency_penalty=frequency_penalties[0], ) ] - sampling_metadata = get_sampling_metadata( + sampling_state = get_sampling_state( sampling_param, past_output_tokens=past_output_tokens ) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) batch_size = 3 @@ -212,34 +212,34 @@ def get_expected_result( ) for i in range(batch_size) ] - sampling_metadata = get_sampling_metadata( + sampling_state = get_sampling_state( sampling_params, past_output_tokens=past_output_tokens ) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) def _test_top_p_top_k_checker(): - get_sampling_metadata([SamplingParams(top_p=0.8)]) - get_sampling_metadata([SamplingParams(top_k=3)]) + get_sampling_state([SamplingParams(top_p=0.8)]) + get_sampling_state([SamplingParams(top_k=3)]) - get_sampling_metadata([SamplingParams(top_k=-1)]) - get_sampling_metadata([SamplingParams(top_k=1)]) + get_sampling_state([SamplingParams(top_k=-1)]) + get_sampling_state([SamplingParams(top_k=1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_p=0.0)]) + get_sampling_state([SamplingParams(top_p=0.0)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_p=-0.8)]) + get_sampling_state([SamplingParams(top_p=-0.8)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_k=0)]) + get_sampling_state([SamplingParams(top_k=0)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_k=0.8)]) + get_sampling_state([SamplingParams(top_k=0.8)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_k=-2)]) + get_sampling_state([SamplingParams(top_k=-2)]) def _test_top_p_top_k(): @@ -293,8 +293,8 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")): sampling_params = [ SamplingParams(top_p=top_p, top_k=top_k) for _ in range(batch_size) ] - sampling_metadata = get_sampling_metadata(sampling_params) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + sampling_state = get_sampling_state(sampling_params) + new_logits = adjust_logits(logits, sampling_state, vocab_size) expected = logits.clone() expected = get_expected_result(expected, top_pks=[(top_p, top_k)]) assert torch.allclose(expected, new_logits) @@ -306,9 +306,9 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")): sampling_params = [ SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks ] - sampling_metadata = get_sampling_metadata(sampling_params) + sampling_state = get_sampling_state(sampling_params) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) expected = logits.clone() expected = get_expected_result(expected, top_pks) assert torch.allclose(expected, new_logits) From abe3d18be4da205b1597cc5b070846f5913b4415 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Feb 2024 18:30:29 +0400 Subject: [PATCH 2/9] update repetition penalties --- serve/mlc_serve/model/sampler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 36516e48b7..6f7b929d76 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -415,9 +415,11 @@ def adjust_logits( # in the right order. if apply_penalty: repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) - logits = torch.where( - logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t - ) + # RepetitionPenaltyLogitsProcessor approach from HF TGI API is used + # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22 + # where score is logits + # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92 + logits = logits / repetition_penalties_t bin_counts = torch.zeros( (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device ) From 3197fa7d56831e81b916eb88c0acd963c6865350 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 12 Feb 2024 16:55:06 +0400 Subject: [PATCH 3/9] correct repetition penalties calculation use vLLM and HF approaches --- serve/mlc_serve/model/sampler.py | 48 ++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 6f7b929d76..0013fc767d 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -372,6 +372,21 @@ def from_sampling_params( ) +def get_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + def adjust_logits( logits: torch.Tensor, sampling_state: SamplingState, @@ -414,22 +429,31 @@ def adjust_logits( # (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...) # in the right order. if apply_penalty: - repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) - # RepetitionPenaltyLogitsProcessor approach from HF TGI API is used + bin_counts, output_mask = get_bin_counts_and_mask( + past_output_tokens_t, + vocab_size, + batch_size, + ) + + _, prompt_mask = get_bin_counts_and_mask( + prompt_tokens_t, + vocab_size, + batch_size, + ) + + # Calculate repetition penalty use vLLM approach + # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177 + # and RepetitionPenaltyLogitsProcessor approach from HF TGI API # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22 # where score is logits # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92 - logits = logits / repetition_penalties_t - bin_counts = torch.zeros( - (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device - ) - bin_counts.scatter_add_( - 1, past_output_tokens_t, torch.ones_like(past_output_tokens_t) - ) - bin_counts = bin_counts[:, :vocab_size] - mask = bin_counts > 0 + repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) + repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0 + logits /= repetition_penalties_t + + # Calculate frequency and presence penalties logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts - logits -= presence_penalties_t.unsqueeze_(dim=1) * mask + logits -= presence_penalties_t.unsqueeze_(dim=1) * output_mask # Adjust temperature logits.div_(temp_t.unsqueeze(dim=1)) From 65ae93ae165a94833b19310fb0956a0ba04c5b07 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 12 Feb 2024 21:31:06 +0400 Subject: [PATCH 4/9] construct mask_prompt and transfer to repetition penalty --- serve/mlc_serve/engine/engine_common.py | 16 ++++++++++++++++ serve/mlc_serve/engine/sampling_params.py | 2 ++ serve/mlc_serve/model/sampler.py | 18 ++++++++++++------ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index d7e15cd1e9..f48fc16081 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -2,6 +2,7 @@ Common utilites for engine classes. """ +import torch import time from typing import Tuple, Deque, Dict, Optional, Union, Callable, List from collections import deque @@ -240,6 +241,18 @@ def prepare_output( return delta, out_logprob_info +def set_mask_prompt_to(state: RequestState): + # Prompt tokens + tokens=torch.tensor(state.prompt_token_ids, dtype=torch.long) + vocab_size = state.sampling_params.vocab_size + bin_counts = torch.zeros((vocab_size + 1,), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:vocab_size] + state.sampling_params.mask_prompt = bin_counts > 0 + + def get_requests_to_process( current_states: list[RequestState], cache_manager: KVCacheManager, @@ -264,6 +277,9 @@ def get_requests_to_process( if is_prompt_batch: for state in current_states: if is_evicted_parallel_sampling_request(state): + # TODO(vvchernov): we still need mask if apply_penallty = True + # if state.sampling_params.repetition_penalty != 1.0: + set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index 43c7040e6f..1d2488a89f 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -7,6 +7,7 @@ from enum import IntEnum from functools import cached_property from typing import Dict, Optional, Any +import torch _SAMPLING_EPS = 1e-5 LOGPROB_TOP_K_MAX = 5 @@ -75,6 +76,7 @@ class SamplingParams: vocab_size = 32000 json_schema: Optional[Dict[str, Any]] = None logits_processor: Optional[Any] = None + mask_prompt: Optional[torch.Tensor] = None def __post_init__(self): if self.logit_bias: diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 0013fc767d..18f84fb888 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -53,6 +53,9 @@ class SamplingTensors: mask_top_logprob: torch.Tensor Mask for requests with top_logprob. shape: (LOGPROB_TOP_K_MAX) + 1, batch_size,) + mask_prompt: torch.Tensor + Mask for request with repetition penalty (prompt part) + shape: (batch_size, vocab_size) temperatures: torch.Tensor Tensor for temperature values shape: (batch_size, ) @@ -85,6 +88,7 @@ class SamplingTensors: mask_random: torch.Tensor mask_greedy: torch.Tensor mask_top_logprob: torch.Tensor + mask_prompt: torch.Tensor temperatures: torch.Tensor top_ps: torch.Tensor top_ks: torch.Tensor @@ -102,6 +106,7 @@ def from_lists( dev, list_mask_random: List[bool], list_mask_top_logprob: List[List[bool]], + list_mask_prompt: List[torch.Tensor], list_temperatures: List[float], list_top_ps: List[float], list_top_ks: List[int], @@ -124,6 +129,7 @@ def from_lists( ) # `mask_top_logprob` will be on cpu mask_top_logprob = torch.from_numpy(list_mask_top_logprob) + mask_prompt = torch.stack(list_mask_prompt) temp = torch.tensor( list_temperatures, dtype=dtype, @@ -185,6 +191,7 @@ def from_lists( mask_random, mask_greedy, mask_top_logprob, + mask_prompt, temp.to(device=dev, non_blocking=True), top_ps.to(device=dev, non_blocking=True), top_ks.to(device=dev, non_blocking=True), @@ -250,6 +257,7 @@ def from_sampling_params( vocab_size: int, ): list_mask_random = [] + list_mask_prompt = [] list_temperatures = [] list_top_ps = [] list_top_ks = [] @@ -307,6 +315,7 @@ def from_sampling_params( list_frequency_penalties.append(param.frequency_penalty) list_presence_penalties.append(param.presence_penalty) list_repetition_penalties.append(param.repetition_penalty) + list_mask_prompt.append(param.mask_prompt) if param.logit_bias_index: assert param.logit_bias_value @@ -348,6 +357,7 @@ def from_sampling_params( dev, list_mask_random, list_mask_top_logprob, + list_mask_prompt, list_temperatures, list_top_ps, list_top_ks, @@ -404,6 +414,7 @@ def adjust_logits( sampling_state.sampling_tensors, ) ( + prompt_mask, temp_t, top_ps_t, top_ks_t, @@ -414,6 +425,7 @@ def adjust_logits( logit_bias_indices_t, logit_bias_values_t, ) = ( + sampling_tensors.mask_prompt, sampling_tensors.temperatures, sampling_tensors.top_ps, sampling_tensors.top_ks, @@ -435,12 +447,6 @@ def adjust_logits( batch_size, ) - _, prompt_mask = get_bin_counts_and_mask( - prompt_tokens_t, - vocab_size, - batch_size, - ) - # Calculate repetition penalty use vLLM approach # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177 # and RepetitionPenaltyLogitsProcessor approach from HF TGI API From f439b97f99e63c71929ba8e5144ca21e8b1b5458 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Feb 2024 11:24:44 +0400 Subject: [PATCH 5/9] fix comments --- serve/mlc_serve/engine/engine_common.py | 2 +- serve/mlc_serve/model/sampler.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index f48fc16081..2a4d46ca15 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -277,7 +277,7 @@ def get_requests_to_process( if is_prompt_batch: for state in current_states: if is_evicted_parallel_sampling_request(state): - # TODO(vvchernov): we still need mask if apply_penallty = True + # TODO(vvchernov): we still need mask if apply_penalty = True # if state.sampling_params.repetition_penalty != 1.0: set_mask_prompt_to(state) requests.append( diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 18f84fb888..fbcecbba37 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -447,6 +447,8 @@ def adjust_logits( batch_size, ) + # It was checked that vLLM and HF approaches for repetition penalty are the same + # For calculation of it their combination is used (see references below) # Calculate repetition penalty use vLLM approach # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177 # and RepetitionPenaltyLogitsProcessor approach from HF TGI API From fe8b5de46d9521b3f4764ab7c52f7b969fa0a4b9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Feb 2024 18:53:07 +0400 Subject: [PATCH 6/9] fix penalty calculation --- serve/mlc_serve/model/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index fbcecbba37..d5d443eec4 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -457,7 +457,9 @@ def adjust_logits( # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92 repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0 - logits /= repetition_penalties_t + logits = torch.where( + logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t + ) # Calculate frequency and presence penalties logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts From fa161b870009c75681bd1b2109495266cd05b5be Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Feb 2024 19:31:39 +0400 Subject: [PATCH 7/9] add mask_prompt in correct place --- serve/mlc_serve/engine/engine_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 2a4d46ca15..6bd12cf7af 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -279,7 +279,7 @@ def get_requests_to_process( if is_evicted_parallel_sampling_request(state): # TODO(vvchernov): we still need mask if apply_penalty = True # if state.sampling_params.repetition_penalty != 1.0: - set_mask_prompt_to(state) + # set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, @@ -327,6 +327,9 @@ def get_requests_to_process( else: token_ids = state.prompt_token_ids + # TODO(vvchernov): we still need mask if apply_penalty = True + # if state.sampling_params.repetition_penalty != 1.0: + set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, From 77180aa1e35646515d2a4814c1e8c6eb943993e2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Feb 2024 19:35:06 +0400 Subject: [PATCH 8/9] fix dim for scatter --- serve/mlc_serve/engine/engine_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 6bd12cf7af..1ddf6fa7fa 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -248,7 +248,7 @@ def set_mask_prompt_to(state: RequestState): bin_counts = torch.zeros((vocab_size + 1,), dtype=torch.long, device=tokens.device) - bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:vocab_size] state.sampling_params.mask_prompt = bin_counts > 0 From 52155cad48954adac82187f2bfd6ff213933eb2a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Feb 2024 20:02:57 +0400 Subject: [PATCH 9/9] fix device --- serve/mlc_serve/model/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index d5d443eec4..2f1b9d3c72 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -456,6 +456,7 @@ def adjust_logits( # where score is logits # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92 repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) + prompt_mask = prompt_mask.to(repetition_penalties_t.device) repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0 logits = torch.where( logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t