diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py
index 53a9a5325c..60d3d31ff8 100644
--- a/serve/benchmarks/benchmark_throughput.py
+++ b/serve/benchmarks/benchmark_throughput.py
@@ -74,7 +74,7 @@ def run_mlc(
         temperature=1.0,
         top_p=1.0,
         frequency_penalty=-1,
-        logit_bias={1: -1}
+        logit_bias={1: -1, 3: 1, 2: 2}
     )
 
     engine.add(
diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py
index 5df6bd8023..5012b78487 100644
--- a/serve/mlc_serve/engine/sampling_params.py
+++ b/serve/mlc_serve/engine/sampling_params.py
@@ -50,9 +50,15 @@ class SamplingParams:
     top_k: int = -1
     logit_bias: Optional[Dict[int, float]] = None
     appeared_tokens_freq: Dict[int, int] = None
+    logit_bias_index: list[int] = None
+    logit_bias_value: list[float] = None
 
     def __post_init__(self):
         self.appeared_tokens_freq = {}
+        # Precompute once per request; logit_bias defaults to None, so guard it.
+        if self.logit_bias:
+            self.logit_bias_index = list(self.logit_bias.keys())
+            self.logit_bias_value = list(self.logit_bias.values())
         self._verify_args()
         if self.temperature < _SAMPLING_EPS:
             # Zero temperature means greedy sampling.
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 1f9a6fd4f6..bab0a78ebb 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -106,16 +106,12 @@ def _is_safe_to_sample(prob_like):
         do_top_k |= top_ks[-1] != vocab_size
 
         if not param.presence_penalty == 0.0 or not param.frequency_penalty == 0 and bool(freq):
-            freq_tensor = np.array(list(freq.items()))
-            index = torch.from_numpy(freq_tensor[..., 0]).to(device=logits.device)
-            src = torch.from_numpy(freq_tensor[..., 1]).type_as(logits).to(device=logits.device)
-            logits[i] = torch.scatter_add(logits[i], dim=0, index=index, src=-(src * param.frequency_penalty + param.presence_penalty))
+            index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device)
+            src = torch.from_numpy(np.array(list(freq.values()))).type_as(logits).to(device=logits.device)
+            logits[i][index] -= src * param.frequency_penalty + param.presence_penalty
 
         if param.logit_bias:
-            bias_tensor = np.array(list(param.logit_bias.items()))
-            index = torch.from_numpy(bias_tensor[..., 0]).to(device=logits.device)
-            src = torch.from_numpy(bias_tensor[..., 1]).type_as(logits).to(device=logits.device)
-            logits[i] = torch.scatter_add(logits[i], dim=0, index=index, src=src)
+            logits[i][param.logit_bias_index] += torch.tensor(param.logit_bias_value).type_as(logits).to(device=logits.device)
 
     logits_random = logits[mask_random]
 
@@ -521,6 +517,9 @@ def generate(
                 )
 
                 if maybe_new_token is not None:
+                    if maybe_new_token not in requests[i].sampling_params.appeared_tokens_freq:
+                        requests[i].sampling_params.appeared_tokens_freq[maybe_new_token] = 0
+                    requests[i].sampling_params.appeared_tokens_freq[maybe_new_token] += 1
                     if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
                         for seq_id in range(num_sequences[i]):
                             outputs.append(
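
For review context, a minimal standalone sketch (not from the repo; `vocab_size`, `logits`, and the bias values are made up) of the equivalence the `paged_cache_model.py` rewrite relies on: applying a logit bias via fancy indexing with precomputed index/value lists gives the same result as the old `torch.scatter_add` path.

```python
import torch

vocab_size = 8
logits = torch.randn(vocab_size)
logit_bias = {1: -1.0, 3: 1.0, 2: 2.0}

# Old path: rebuild index/src tensors from the dict on every decode step.
index = torch.tensor(list(logit_bias.keys()))   # int64 token ids
src = torch.tensor(list(logit_bias.values()))   # float bias values
old = torch.scatter_add(logits, 0, index, src)

# New path: the index/value lists are built once per request
# (SamplingParams.__post_init__), so the hot loop is one indexed add.
bias_index = list(logit_bias.keys())
bias_value = list(logit_bias.values())
new = logits.clone()
new[bias_index] += torch.tensor(bias_value)

assert torch.allclose(old, new)
```

The win is that `list(logit_bias.keys())` / `list(logit_bias.values())` now runs once at request setup rather than on every sampling step.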
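The penalty branch follows the same pattern. A hedged sketch of the update it performs (token counts and penalty values invented here): every token that has appeared pays `presence_penalty` once plus `frequency_penalty` scaled by its occurrence count, i.e. the usual OpenAI-style penalty formula.

```python
import torch

vocab_size = 8
logits = torch.randn(vocab_size)

# token id -> number of times it has been generated so far
appeared_tokens_freq = {4: 3, 6: 1}
frequency_penalty = 0.5
presence_penalty = 0.2

index = torch.tensor(list(appeared_tokens_freq.keys()))
counts = torch.tensor(list(appeared_tokens_freq.values())).type_as(logits)

# Appeared tokens are penalized in a single indexed op, no scatter needed.
logits[index] -= counts * frequency_penalty + presence_penalty
```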
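The bookkeeping added in `generate()` is plain dict counting; for reference, the three added lines are equivalent to a `dict.get` one-liner (`freq` here stands in for `requests[i].sampling_params.appeared_tokens_freq`, and the token ids are made up):

```python
freq = {}
for maybe_new_token in [5, 7, 5]:  # freshly sampled token ids
    freq[maybe_new_token] = freq.get(maybe_new_token, 0) + 1
assert freq == {5: 2, 7: 1}
```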