diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index d7e15cd1e9..1ddf6fa7fa 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -2,6 +2,7 @@ Common utilites for engine classes. """ +import torch import time from typing import Tuple, Deque, Dict, Optional, Union, Callable, List from collections import deque @@ -240,6 +241,18 @@ def prepare_output( return delta, out_logprob_info +def set_mask_prompt_to(state: RequestState): + # Prompt tokens + tokens=torch.tensor(state.prompt_token_ids, dtype=torch.long) + vocab_size = state.sampling_params.vocab_size + bin_counts = torch.zeros((vocab_size + 1,), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:vocab_size] + state.sampling_params.mask_prompt = bin_counts > 0 + + def get_requests_to_process( current_states: list[RequestState], cache_manager: KVCacheManager, @@ -264,6 +277,9 @@ def get_requests_to_process( if is_prompt_batch: for state in current_states: if is_evicted_parallel_sampling_request(state): + # TODO(vvchernov): we still need mask if apply_penalty = True + # if state.sampling_params.repetition_penalty != 1.0: + # set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, @@ -311,6 +327,9 @@ def get_requests_to_process( else: token_ids = state.prompt_token_ids + # TODO(vvchernov): we still need mask if apply_penalty = True + # if state.sampling_params.repetition_penalty != 1.0: + set_mask_prompt_to(state) requests.append( PrefillRequest( request_id=state.request_id, diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py index 43c7040e6f..1d2488a89f 100644 --- a/serve/mlc_serve/engine/sampling_params.py +++ b/serve/mlc_serve/engine/sampling_params.py @@ -7,6 +7,7 @@ from enum import IntEnum from functools import cached_property from typing import Dict, Optional, Any +import torch _SAMPLING_EPS = 1e-5 LOGPROB_TOP_K_MAX = 5 @@ -75,6 +76,7 @@ class SamplingParams: vocab_size = 32000 json_schema: Optional[Dict[str, Any]] = None logits_processor: Optional[Any] = None + mask_prompt: Optional[torch.Tensor] = None def __post_init__(self): if self.logit_bias: diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 9ccea2ba1d..54414f45da 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -83,7 +83,7 @@ def sample_from_logits( logits: Union[tvm.nd.NDArray, torch.Tensor], sequence_ids: List[SequenceId], requests: Sequence[RequestType], - sampling_metadata: SamplingState, + sampling_state: SamplingState, vocab_size: int, copy_stream: torch.cuda.Stream, torch_dtype: torch.dtype, @@ -110,13 +110,13 @@ def sample_from_logits( sequence_id, cs_input_ids, logits[i] ) - logits = adjust_logits(logits, sampling_metadata, vocab_size) + logits = adjust_logits(logits, sampling_state, vocab_size) outputs: List[TextGenerationResult] = [] try: sampling_output: Optional[SamplingOutput] = sample( logits, - sampling_metadata, + sampling_state, ) for i, (new_token, logprob_info) in enumerate( @@ -142,13 +142,13 @@ def sample_from_logits( for i in range(batch_size): sequence_id = sequence_ids[i] logits_per_token = logits[i] - sampling_param = sampling_metadata.sampling_params[i] + sampling_param = sampling_state.sampling_params[i] past_decode_tokens_per_request = past_decode_tokens[i] # NOTE: Rerun the preparation for simplicity. 
# Assume this code path is taken rarely and the recomputation overhead is # marginal. with torch.cuda.stream(copy_stream): - new_sampling_metadata = SamplingState.from_sampling_params( + new_sampling_state = SamplingState.from_sampling_params( [sampling_param], [past_decode_tokens_per_request], torch_dtype, @@ -158,7 +158,7 @@ def sample_from_logits( torch.cuda.current_stream().wait_stream(copy_stream) maybe_sampling_output: Optional[SamplingOutput] = sample( torch.unsqueeze(logits_per_token, 0), - new_sampling_metadata, + new_sampling_state, check_safety=True, ) diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py index 0562121b8f..2f1b9d3c72 100644 --- a/serve/mlc_serve/model/sampler.py +++ b/serve/mlc_serve/model/sampler.py @@ -53,6 +53,9 @@ class SamplingTensors: mask_top_logprob: torch.Tensor Mask for requests with top_logprob. shape: (LOGPROB_TOP_K_MAX) + 1, batch_size,) + mask_prompt: torch.Tensor + Mask for request with repetition penalty (prompt part) + shape: (batch_size, vocab_size) temperatures: torch.Tensor Tensor for temperature values shape: (batch_size, ) @@ -85,6 +88,7 @@ class SamplingTensors: mask_random: torch.Tensor mask_greedy: torch.Tensor mask_top_logprob: torch.Tensor + mask_prompt: torch.Tensor temperatures: torch.Tensor top_ps: torch.Tensor top_ks: torch.Tensor @@ -102,6 +106,7 @@ def from_lists( dev, list_mask_random: List[bool], list_mask_top_logprob: List[List[bool]], + list_mask_prompt: List[torch.Tensor], list_temperatures: List[float], list_top_ps: List[float], list_top_ks: List[int], @@ -124,6 +129,7 @@ def from_lists( ) # `mask_top_logprob` will be on cpu mask_top_logprob = torch.from_numpy(list_mask_top_logprob) + mask_prompt = torch.stack(list_mask_prompt) temp = torch.tensor( list_temperatures, dtype=dtype, @@ -185,6 +191,7 @@ def from_lists( mask_random, mask_greedy, mask_top_logprob, + mask_prompt, temp.to(device=dev, non_blocking=True), top_ps.to(device=dev, non_blocking=True), top_ks.to(device=dev, non_blocking=True), @@ -250,6 +257,7 @@ def from_sampling_params( vocab_size: int, ): list_mask_random = [] + list_mask_prompt = [] list_temperatures = [] list_top_ps = [] list_top_ks = [] @@ -307,6 +315,7 @@ def from_sampling_params( list_frequency_penalties.append(param.frequency_penalty) list_presence_penalties.append(param.presence_penalty) list_repetition_penalties.append(param.repetition_penalty) + list_mask_prompt.append(param.mask_prompt) if param.logit_bias_index: assert param.logit_bias_value @@ -348,6 +357,7 @@ def from_sampling_params( dev, list_mask_random, list_mask_top_logprob, + list_mask_prompt, list_temperatures, list_top_ps, list_top_ks, @@ -372,7 +382,25 @@ def from_sampling_params( ) -def adjust_logits(logits, sampling_metadata, vocab_size): +def get_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def adjust_logits( + logits: torch.Tensor, + sampling_state: SamplingState, + vocab_size: int): batch_size = logits.shape[0] ( apply_top_p_top_k, @@ -380,12 +408,13 @@ def adjust_logits(logits, sampling_metadata, vocab_size): apply_bias, sampling_tensors, ) = ( - sampling_metadata.apply_top_p_top_k, - sampling_metadata.apply_penalty, - sampling_metadata.apply_bias, - 
sampling_metadata.sampling_tensors, + sampling_state.apply_top_p_top_k, + sampling_state.apply_penalty, + sampling_state.apply_bias, + sampling_state.sampling_tensors, ) ( + prompt_mask, temp_t, top_ps_t, top_ks_t, @@ -396,6 +425,7 @@ def adjust_logits(logits, sampling_metadata, vocab_size): logit_bias_indices_t, logit_bias_values_t, ) = ( + sampling_tensors.mask_prompt, sampling_tensors.temperatures, sampling_tensors.top_ps, sampling_tensors.top_ks, @@ -411,20 +441,30 @@ def adjust_logits(logits, sampling_metadata, vocab_size): # (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...) # in the right order. if apply_penalty: + bin_counts, output_mask = get_bin_counts_and_mask( + past_output_tokens_t, + vocab_size, + batch_size, + ) + + # It was checked that vLLM and HF approaches for repetition penalty are the same + # For calculation of it their combination is used (see references below) + # Calculate repetition penalty use vLLM approach + # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177 + # and RepetitionPenaltyLogitsProcessor approach from HF TGI API + # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22 + # where score is logits + # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92 repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size) + prompt_mask = prompt_mask.to(repetition_penalties_t.device) + repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0 logits = torch.where( logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t ) - bin_counts = torch.zeros( - (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device - ) - bin_counts.scatter_add_( - 1, past_output_tokens_t, torch.ones_like(past_output_tokens_t) - ) - bin_counts = bin_counts[:, :vocab_size] - mask = bin_counts > 0 + + # Calculate frequency and presence penalties logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts - logits -= presence_penalties_t.unsqueeze_(dim=1) * mask + logits -= presence_penalties_t.unsqueeze_(dim=1) * output_mask # Adjust temperature logits.div_(temp_t.unsqueeze(dim=1)) @@ -447,7 +487,7 @@ class SamplingOutput: def sample( logits: torch.Tensor, - sampling_metadata: SamplingState, + sampling_state: SamplingState, check_safety: bool = False, ) -> SamplingOutput: def _is_safe_to_sample(prob_like): @@ -457,7 +497,7 @@ def _is_safe_to_sample(prob_like): ) res_greedy, res_random = None, None - sampling_tensors = sampling_metadata.sampling_tensors + sampling_tensors = sampling_state.sampling_tensors batch_size = logits.shape[0] mask_greedy_t, mask_random_t = ( @@ -466,13 +506,13 @@ def _is_safe_to_sample(prob_like): ) next_tokens = np.empty((batch_size,), dtype=np.int64) - if sampling_metadata.has_greedy: + if sampling_state.has_greedy: res_greedy = torch.argmax(logits[mask_greedy_t], -1) np_mask_greedy = mask_greedy_t.cpu().numpy() next_tokens[np_mask_greedy] = res_greedy.cpu().numpy() probs_random = None - if sampling_metadata.has_random: + if sampling_state.has_random: probs_random = torch.softmax(logits[mask_random_t], dim=-1) if check_safety and not _is_safe_to_sample(probs_random): return None @@ -481,9 +521,9 @@ def _is_safe_to_sample(prob_like): next_tokens[np_mask_random] = res_random.cpu().numpy() logprob_infos: 
List[Optional[RawLogprobsInfo]] = [None] * batch_size - if sampling_metadata.has_logprob: + if sampling_state.has_logprob: # If everything is random sampling, save one extra softmax - if not sampling_metadata.has_greedy: + if not sampling_state.has_greedy: assert probs_random is not None logprobs = torch.log(probs_random) else: @@ -494,13 +534,13 @@ def _is_safe_to_sample(prob_like): all_top_logprobs, all_top_tokens = torch.topk( extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True ) - mask = sampling_metadata.sampling_tensors.mask_top_logprob + mask = sampling_state.sampling_tensors.mask_top_logprob top_tokens = all_top_tokens[mask] top_logprobs = all_top_logprobs[mask] - for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices): + for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices): next_token = next_tokens[batch_idx] - assert sampling_metadata.sampling_params[batch_idx].logprobs - top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs + assert sampling_state.sampling_params[batch_idx].logprobs + top_k = sampling_state.sampling_params[batch_idx].top_logprobs logprob_infos[batch_idx] = RawLogprobsInfo( current_token_id=next_token, current_logprob=logprobs[batch_idx][next_token], diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 98a52ecbf8..f20370275b 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -259,7 +259,7 @@ def generate_multi_query( # Prepare sampling tensors in another stream to overlap # CPU<->GPU data transfer with GPU computation in forward pass. with torch.cuda.stream(self._copy_stream): - sampling_metadata = SamplingState.from_sampling_params( + sampling_state = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, self.torch_dtype, @@ -318,7 +318,7 @@ def generate_multi_query( last_query_logits, sequence_ids, requests, - sampling_metadata, + sampling_state, self.vocab_size, self._copy_stream, self.torch_dtype, @@ -381,7 +381,7 @@ def generate( # Prepare sampling tensors in another stream to overlap # CPU<->GPU data transfer with GPU computation in forward pass. 
with torch.cuda.stream(self._copy_stream): - sampling_metadata = SamplingState.from_sampling_params( + sampling_state = SamplingState.from_sampling_params( sampling_params, past_decode_tokens, self.torch_dtype, @@ -502,7 +502,7 @@ def generate( logits, sequence_ids, requests, - sampling_metadata, + sampling_state, self.vocab_size, self._copy_stream, self.torch_dtype, diff --git a/serve/tests/unittest/test_sampler.py b/serve/tests/unittest/test_sampler.py index b526a2fa39..5017220617 100644 --- a/serve/tests/unittest/test_sampler.py +++ b/serve/tests/unittest/test_sampler.py @@ -8,13 +8,13 @@ vocab_size = 32000 -def get_sampling_metadata(sampling_params, past_output_tokens=None): +def get_sampling_state(sampling_params, past_output_tokens=None): batch_size = len(sampling_params) if past_output_tokens is None: past_output_tokens = [[] for _ in range(batch_size)] _copy_stream: torch.cuda.Stream = torch.cuda.Stream() with torch.cuda.stream(_copy_stream): - sampling_metadata = SamplingState.from_sampling_params( + sampling_state = SamplingState.from_sampling_params( sampling_params, list_past_output_tokens=past_output_tokens, dtype=dtype, @@ -22,7 +22,7 @@ def get_sampling_metadata(sampling_params, past_output_tokens=None): vocab_size=vocab_size, ) torch.cuda.current_stream().wait_stream(_copy_stream) - return sampling_metadata + return sampling_state def _test_temperature(temp=0, batch_size=1): @@ -32,10 +32,10 @@ def _test_temperature(temp=0, batch_size=1): temperature=temp, ) - sampling_metadata = get_sampling_metadata([sampling_param]) + sampling_state = get_sampling_state([sampling_param]) expected = logits / temp if abs(temp) > SAMPLING_EPS else logits - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) @@ -44,36 +44,36 @@ def _test_logit_bias_checker(): with pytest.raises(ValueError): logit_bias = {1: 2, 3: 105, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {1: 99, 3: -101, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) logit_bias = {1: 100, 3: -100, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) # TODO(@team): it seems like the valid range is [1,vocab_size]. Double check. 
logit_bias = {1: 10, 3: -10, vocab_size: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {0: 10, 3: -10} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {1: 10, 3: -10, vocab_size + 100: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) with pytest.raises(ValueError): logit_bias = {1: 10, -1: -10} sampling_param = SamplingParams(logit_bias=logit_bias) - get_sampling_metadata([sampling_param]) + get_sampling_state([sampling_param]) def _test_logit_bias(): @@ -83,12 +83,12 @@ def _test_logit_bias(): logits = torch.rand(shape, dtype=dtype, device=dev) logit_bias = {1: -1, 3: 1, 2: 2} sampling_param = SamplingParams(logit_bias=logit_bias) - sampling_metadata = get_sampling_metadata([sampling_param]) + sampling_state = get_sampling_state([sampling_param]) expected = torch.clone(logits) for idx, val in logit_bias.items(): expected[0][idx - 1] += val - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) # test multi-batch @@ -99,39 +99,39 @@ def _test_logit_bias(): sampling_params = [ SamplingParams(logit_bias=logit_bias) for logit_bias in list_logit_bias ] - sampling_metadata = get_sampling_metadata(sampling_params) + sampling_state = get_sampling_state(sampling_params) expected = torch.clone(logits) for batch_size in range(batch_size): logit_bias = list_logit_bias[batch_size] for idx, val in logit_bias.items(): expected[batch_size][idx - 1] += val - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) def _test_penalties_checker(): - get_sampling_metadata([SamplingParams(presence_penalty=-1.0)]) - get_sampling_metadata([SamplingParams(frequency_penalty=-1.0)]) - get_sampling_metadata([SamplingParams(repetition_penalty=0.7)]) + get_sampling_state([SamplingParams(presence_penalty=-1.0)]) + get_sampling_state([SamplingParams(frequency_penalty=-1.0)]) + get_sampling_state([SamplingParams(repetition_penalty=0.7)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(presence_penalty=-2.1)]) + get_sampling_state([SamplingParams(presence_penalty=-2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(frequency_penalty=-2.1)]) + get_sampling_state([SamplingParams(frequency_penalty=-2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(repetition_penalty=-2.1)]) + get_sampling_state([SamplingParams(repetition_penalty=-2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(presence_penalty=2.1)]) + get_sampling_state([SamplingParams(presence_penalty=2.1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(frequency_penalty=2.1)]) + get_sampling_state([SamplingParams(frequency_penalty=2.1)]) with pytest.raises(ValueError): - get_sampling_metadata( + get_sampling_state( [ SamplingParams(frequency_penalty=1.1), SamplingParams(repetition_penalty=2.1), @@ -187,10 +187,10 @@ def get_expected_result( frequency_penalty=frequency_penalties[0], ) ] - sampling_metadata = get_sampling_metadata( + sampling_state = 
get_sampling_state( sampling_param, past_output_tokens=past_output_tokens ) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) batch_size = 3 @@ -212,34 +212,34 @@ def get_expected_result( ) for i in range(batch_size) ] - sampling_metadata = get_sampling_metadata( + sampling_state = get_sampling_state( sampling_params, past_output_tokens=past_output_tokens ) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) assert torch.allclose(expected, new_logits) def _test_top_p_top_k_checker(): - get_sampling_metadata([SamplingParams(top_p=0.8)]) - get_sampling_metadata([SamplingParams(top_k=3)]) + get_sampling_state([SamplingParams(top_p=0.8)]) + get_sampling_state([SamplingParams(top_k=3)]) - get_sampling_metadata([SamplingParams(top_k=-1)]) - get_sampling_metadata([SamplingParams(top_k=1)]) + get_sampling_state([SamplingParams(top_k=-1)]) + get_sampling_state([SamplingParams(top_k=1)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_p=0.0)]) + get_sampling_state([SamplingParams(top_p=0.0)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_p=-0.8)]) + get_sampling_state([SamplingParams(top_p=-0.8)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_k=0)]) + get_sampling_state([SamplingParams(top_k=0)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_k=0.8)]) + get_sampling_state([SamplingParams(top_k=0.8)]) with pytest.raises(ValueError): - get_sampling_metadata([SamplingParams(top_k=-2)]) + get_sampling_state([SamplingParams(top_k=-2)]) def _test_top_p_top_k(): @@ -293,8 +293,8 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")): sampling_params = [ SamplingParams(top_p=top_p, top_k=top_k) for _ in range(batch_size) ] - sampling_metadata = get_sampling_metadata(sampling_params) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + sampling_state = get_sampling_state(sampling_params) + new_logits = adjust_logits(logits, sampling_state, vocab_size) expected = logits.clone() expected = get_expected_result(expected, top_pks=[(top_p, top_k)]) assert torch.allclose(expected, new_logits) @@ -306,9 +306,9 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")): sampling_params = [ SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks ] - sampling_metadata = get_sampling_metadata(sampling_params) + sampling_state = get_sampling_state(sampling_params) - new_logits = adjust_logits(logits, sampling_metadata, vocab_size) + new_logits = adjust_logits(logits, sampling_state, vocab_size) expected = logits.clone() expected = get_expected_result(expected, top_pks) assert torch.allclose(expected, new_logits)
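
Illustration (not part of the patch): the comment block added to adjust_logits above describes, in prose, the combined vLLM / HF RepetitionPenaltyLogitsProcessor rule that the new mask_prompt feeds into. The standalone sketch below restates that rule on its own; the function name apply_repetition_penalty_sketch and its argument names are made up for illustration, and only torch is assumed.

import torch

def apply_repetition_penalty_sketch(
    logits: torch.Tensor,       # (batch_size, vocab_size) raw logits
    prompt_mask: torch.Tensor,  # bool (batch_size, vocab_size): tokens present in the prompt
    output_mask: torch.Tensor,  # bool (batch_size, vocab_size): tokens already generated
    penalties: torch.Tensor,    # float (batch_size,): repetition penalty per request
) -> torch.Tensor:
    vocab_size = logits.shape[1]
    penalties = penalties[:, None].repeat(1, vocab_size)
    # Tokens never seen in the prompt or in the output keep an effective
    # penalty of 1.0, i.e. they are left untouched; this is what the
    # prompt/output masks are for.
    penalties[~(prompt_mask | output_mask)] = 1.0
    # vLLM / HF rule: divide positive logits, multiply negative ones, so a
    # penalty > 1.0 always makes repeated tokens less likely.
    return torch.where(logits > 0, logits / penalties, logits * penalties)

This mirrors the computation adjust_logits performs with sampling_tensors.mask_prompt and the mask derived from past_output_tokens_t.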
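
A possible unit test for the new path, in the style of serve/tests/unittest/test_sampler.py (a sketch only: the test name and token values are made up, while get_sampling_state, adjust_logits, SamplingParams, vocab_size, dtype and dev are the module's existing names). It builds mask_prompt the same way set_mask_prompt_to does on the engine side, since SamplingState stacks the per-request masks, and checks that only tokens seen in the prompt or in past output are penalized.

def _test_repetition_penalty_with_prompt_mask():
    repetition_penalty = 1.3
    prompt_token_ids = [1, 5, 5, 7]  # token 5 repeats inside the prompt
    past_output_tokens = [[2, 7]]    # single request; token 7 also appears in the prompt

    # Build the prompt mask exactly as set_mask_prompt_to does.
    tokens = torch.tensor(prompt_token_ids, dtype=torch.long)
    bin_counts = torch.zeros((vocab_size + 1,), dtype=torch.long)
    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
    mask_prompt = bin_counts[:vocab_size] > 0

    logits = torch.rand((1, vocab_size), dtype=dtype, device=dev)

    # Expected: only tokens from the prompt or past output are penalized.
    expected = logits.clone()
    for tok in set(prompt_token_ids) | set(past_output_tokens[0]):
        val = expected[0][tok]
        expected[0][tok] = (
            val / repetition_penalty if val > 0 else val * repetition_penalty
        )

    sampling_params = [
        SamplingParams(
            repetition_penalty=repetition_penalty,
            mask_prompt=mask_prompt,
        )
    ]
    sampling_state = get_sampling_state(
        sampling_params, past_output_tokens=past_output_tokens
    )
    new_logits = adjust_logits(logits, sampling_state, vocab_size)
    assert torch.allclose(expected, new_logits)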