
Commit

apply code review suggestions
cyx-6 committed Dec 20, 2023
1 parent 69eb2fd commit 941f2e4
Showing 3 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion serve/benchmarks/benchmark_throughput.py
@@ -74,7 +74,7 @@ def run_mlc(
         temperature=1.0,
         top_p=1.0,
         frequency_penalty=-1,
-        logit_bias={1: -1}
+        logit_bias={1: -1, 3: 1, 2: 2}
     )
 
     engine.add(

4 changes: 4 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -50,9 +50,13 @@ class SamplingParams:
     top_k: int = -1
     logit_bias: Optional[Dict[int, float]] = None
     appeared_tokens_freq: Dict[int, int] = None
+    logit_bias_index: list[int] = None
+    logit_bias_value: list[float] = None
 
     def __post_init__(self):
         self.appeared_tokens_freq = {}
+        self.logit_bias_index = list(self.logit_bias.keys())
+        self.logit_bias_value = list(self.logit_bias.values())
         self._verify_args()
         if self.temperature < _SAMPLING_EPS:
             # Zero temperature means greedy sampling.
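
Aside (not part of the commit): splitting `logit_bias` into parallel `logit_bias_index` / `logit_bias_value` lists at construction time lets the per-step sampling code apply the bias with one indexed add instead of unpacking the dict on every decode step. Below is a minimal sketch of the pattern using an illustrative dataclass, not the real `SamplingParams`; note the sketch guards against `logit_bias` being `None`, whereas the committed `__post_init__` reads the dict unconditionally.

# Illustrative sketch only -- not the real mlc_serve SamplingParams.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class BiasParams:
    logit_bias: Optional[Dict[int, float]] = None  # token id -> additive bias
    logit_bias_index: List[int] = field(default_factory=list)
    logit_bias_value: List[float] = field(default_factory=list)

    def __post_init__(self):
        # Split the dict once, at request-construction time, so the hot
        # sampling loop can index into logits directly.
        if self.logit_bias:  # guard: logit_bias defaults to None
            self.logit_bias_index = list(self.logit_bias.keys())
            self.logit_bias_value = list(self.logit_bias.values())


p = BiasParams(logit_bias={1: -1.0, 3: 1.0, 2: 2.0})
assert p.logit_bias_index == [1, 3, 2]  # dicts preserve insertion order,
assert p.logit_bias_value == [-1.0, 1.0, 2.0]  # so the lists stay aligned
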
15 changes: 7 additions & 8 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -106,16 +106,12 @@ def _is_safe_to_sample(prob_like):
         do_top_k |= top_ks[-1] != vocab_size
 
         if not param.presence_penalty == 0.0 or not param.frequency_penalty == 0 and bool(freq):
-            freq_tensor = np.array(list(freq.items()))
-            index = torch.from_numpy(freq_tensor[..., 0]).to(device=logits.device)
-            src = torch.from_numpy(freq_tensor[..., 1]).type_as(logits).to(device=logits.device)
-            logits[i] = torch.scatter_add(logits[i], dim=0, index=index, src=-(src * param.frequency_penalty + param.presence_penalty))
+            index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device)
+            src = torch.from_numpy(np.array(list(freq.values()))).type_as(logits).to(device=logits.device)
+            logits[i][index] -= src * param.frequency_penalty + param.presence_penalty
 
         if param.logit_bias:
-            bias_tensor = np.array(list(param.logit_bias.items()))
-            index = torch.from_numpy(bias_tensor[..., 0]).to(device=logits.device)
-            src = torch.from_numpy(bias_tensor[..., 1]).type_as(logits).to(device=logits.device)
-            logits[i] = torch.scatter_add(logits[i], dim=0, index=index, src=src)
+            logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device)
 
 
     logits_random = logits[mask_random]
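
Aside (not part of the commit): the hunk above replaces `torch.scatter_add` with in-place fancy indexing, which is equivalent for a 1-D logits row as long as the token indices are unique (they are here, since they come from dict keys). A standalone check of that equivalence:

# Verify in-place fancy indexing matches torch.scatter_add for a 1-D
# logits row with unique token indices.
import torch

logits = torch.zeros(8)
index = torch.tensor([1, 3, 2])        # unique token ids
src = torch.tensor([-1.0, 1.0, 2.0])   # per-token additive bias

# Old style: build a new tensor via scatter_add.
scattered = torch.scatter_add(logits, dim=0, index=index, src=src)

# New style: add in place through fancy indexing (no intermediate tensors,
# but correct only when `index` has no duplicates).
logits[index] += src

assert torch.equal(logits, scattered)
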
@@ -521,6 +517,9 @@ def generate(
             )
 
             if maybe_new_token is not None:
+                if not maybe_new_token in requests[i].sampling_params.appeared_tokens_freq:
+                    requests[i].sampling_params.appeared_tokens_freq[maybe_new_token] = 0
+                requests[i].sampling_params.appeared_tokens_freq[maybe_new_token] += 1
                 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
                     for seq_id in range(num_sequences[i]):
                         outputs.append(
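
Aside (not part of the commit): the three lines added to `generate` are a plain counting dict, maintaining the `appeared_tokens_freq` that the frequency/presence-penalty code above presumably consumes. The same pattern in isolation, with an equivalent `dict.get` one-liner:

# The counting pattern from the diff, in isolation.
appeared_tokens_freq = {}
for token in [5, 7, 5, 5, 9]:  # illustrative token ids
    if token not in appeared_tokens_freq:
        appeared_tokens_freq[token] = 0
    appeared_tokens_freq[token] += 1

assert appeared_tokens_freq == {5: 3, 7: 1, 9: 1}

# An equivalent, more compact idiom:
counts = {}
for token in [5, 7, 5, 5, 9]:
    counts[token] = counts.get(token, 0) + 1
assert counts == appeared_tokens_freq
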
