Add sampling penalties and logit bias #125

Merged · 6 commits · Dec 20, 2023
2 changes: 2 additions & 0 deletions serve/benchmarks/benchmark_throughput.py
@@ -73,6 +73,8 @@ def run_mlc(
sampling_params = SamplingParams(
temperature=1.0,
top_p=1.0,
frequency_penalty=-1,
logit_bias={1: -1, 3: 1, 2: 2}
)

engine.add(
2 changes: 2 additions & 0 deletions serve/mlc_serve/api/handler.py
@@ -60,6 +60,8 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
sampling_params.temperature = request.temperature
if request.top_p is not None:
sampling_params.top_p = request.top_p
if request.logit_bias is not None:
sampling_params.logit_bias = request.logit_bias
return sampling_params


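For context, a request body along these lines exercises the new path end to end: when `logit_bias` is present on the `ChatCompletionRequest`, the handler now copies it into `SamplingParams` alongside the existing penalty fields. Token ids and the model name below are placeholders, not values from this PR.

```python
# Hypothetical OpenAI-style request body; token ids and model name are placeholders.
import json

body = {
    "model": "llama-2-7b-chat",
    "messages": [{"role": "user", "content": "Write a haiku about caching."}],
    "temperature": 0.7,
    "frequency_penalty": 0.5,
    "presence_penalty": 0.2,
    # JSON object keys are always strings; see the protocol.py change below
    # for how they end up as integer token ids.
    "logit_bias": {"1003": -100, "42": 5},
}
print(json.dumps(body, indent=2))
```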
2 changes: 1 addition & 1 deletion serve/mlc_serve/api/protocol.py
@@ -67,7 +67,7 @@ class ChatCompletionRequest(BaseModel):
stream: bool = False
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
logit_bias: Optional[Dict[str, float]] = None
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
ignore_eos: Optional[bool] = False

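The key type change from `Dict[str, float]` to `Dict[int, float]` matters because the biases are later used to index the logits tensor directly. Clients still send string keys in JSON; pydantic coerces them to `int` during validation. A minimal standalone sketch (assuming pydantic v1-style `parse_obj`; pydantic v2 would use `model_validate`):

```python
# Standalone sketch of the key coercion; not part of the PR.
from typing import Dict, Optional

from pydantic import BaseModel


class RequestSketch(BaseModel):
    logit_bias: Optional[Dict[int, float]] = None


req = RequestSketch.parse_obj({"logit_bias": {"1003": -100, "42": 5}})
assert req.logit_bias == {1003: -100.0, 42: 5.0}  # string keys coerced to int
```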
17 changes: 17 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -3,9 +3,11 @@

based on https://github.com/vllm-project/vllm/blob/ac5cf86aa6aebbf9e42df51f7e377fbee85bc703/vllm/sampling_params.py
"""
from collections import defaultdict
from dataclasses import dataclass
from enum import IntEnum
from functools import cached_property
from typing import Dict, Optional


_SAMPLING_EPS = 1e-5
@@ -37,15 +39,24 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
logit_bias: The bias applied on the logit before sampling. Must be in
[-100, 100].
"""

presence_penalty: float = 0.0
frequency_penalty: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
logit_bias: Optional[Dict[int, float]] = None
appeared_tokens_freq: Optional[Dict[int, int]] = None
logit_bias_index: Optional[list[int]] = None
logit_bias_value: Optional[list[float]] = None

def __post_init__(self):
self.appeared_tokens_freq = {}
self.logit_bias_index = list(self.logit_bias.keys()) if self.logit_bias else None
self.logit_bias_value = list(self.logit_bias.values()) if self.logit_bias else None
self._verify_args()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
@@ -71,6 +82,12 @@ def _verify_args(self) -> None:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
)
if self.logit_bias:
for token, bias in self.logit_bias.items():
if not -100 <= bias <= 100:
raise ValueError(
f"logit bias must be in [-100, 100], got {bias} for token {token}."
)

def _verify_greedy_sampling(self) -> None:
if self.top_p < 1.0 - _SAMPLING_EPS:
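A short usage sketch of the new fields (token ids are placeholders; assumes `serve/` is on `PYTHONPATH`): `logit_bias_index` and `logit_bias_value` are precomputed in `__post_init__` so the bias can later be applied with a single indexed add, and out-of-range biases are rejected by `_verify_args`.

```python
from mlc_serve.engine.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.7,
    frequency_penalty=0.5,
    presence_penalty=0.2,
    logit_bias={1003: -100.0, 42: 5.0},  # token id -> additive bias in [-100, 100]
)
print(params.logit_bias_index, params.logit_bias_value)  # [1003, 42] [-100.0, 5.0]

try:
    SamplingParams(logit_bias={7: 250.0})  # outside [-100, 100]
except ValueError as err:
    print(err)
```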
22 changes: 19 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -94,6 +94,7 @@ def _is_safe_to_sample(prob_like):

for i in range(num_seq):
param = sampling_params[i]
freq = param.appeared_tokens_freq

if param.sampling_type == SamplingType.RANDOM:
temperatures.append(param.temperature)
@@ -104,6 +105,15 @@
do_top_p |= top_ps[-1] < 1.0
do_top_k |= top_ks[-1] != vocab_size

if (param.presence_penalty != 0.0 or param.frequency_penalty != 0.0) and bool(freq):
index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device)
src = torch.from_numpy(np.array(list(freq.values()))).type_as(logits).to(device=logits.device)
logits[i][index] -= src * param.frequency_penalty + param.presence_penalty

if param.logit_bias:
logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device)


logits_random = logits[mask_random]

if divide_by_temperature:
@@ -462,11 +472,13 @@ def generate(
try:
next_tokens = sample(logits, sampling_params, self.vocab_size)
assert next_tokens is not None

outputs = []
for i, (sequence_id, new_token) in enumerate(
zip(sequence_ids, next_tokens)
):
if new_token not in requests[i].sampling_params.appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
for seq_id in range(num_sequences[i]):
outputs.append(
@@ -505,22 +517,26 @@ def generate(
)

if maybe_new_token is not None:
new_token = maybe_new_token[0]
if new_token not in requests[i].sampling_params.appeared_tokens_freq:
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0
requests[i].sampling_params.appeared_tokens_freq[new_token] += 1
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
for seq_id in range(num_sequences[i]):
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(
sequence_id.request_id, seq_id
),
generated_tokens=[maybe_new_token[0]], # type: ignore
generated_tokens=[new_token], # type: ignore
error=None,
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[maybe_new_token[0]], # type: ignore
generated_tokens=[new_token], # type: ignore
error=None,
)
)
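The per-sequence arithmetic added above follows the usual OpenAI-style semantics: the frequency penalty is scaled by how many times a token has already appeared, the presence penalty is a flat subtraction for any token that has appeared at least once, and `logit_bias` is an additive offset. A minimal NumPy sketch of that computation (standalone illustration, not the code in this diff; token ids and values are made up):

```python
import numpy as np


def apply_penalties_and_bias(logits, appeared, presence_penalty, frequency_penalty, logit_bias):
    """Return a biased copy of a 1-D logits vector for one sequence."""
    logits = logits.copy()
    if appeared and (presence_penalty != 0.0 or frequency_penalty != 0.0):
        ids = np.fromiter(appeared.keys(), dtype=np.int64)
        counts = np.fromiter(appeared.values(), dtype=logits.dtype)
        # frequency penalty scales with the count; presence penalty is flat
        logits[ids] -= counts * frequency_penalty + presence_penalty
    if logit_bias:
        idx = list(logit_bias.keys())
        logits[idx] += np.asarray(list(logit_bias.values()), dtype=logits.dtype)
    return logits


biased = apply_penalties_and_bias(
    np.zeros(8, dtype=np.float32),
    appeared={2: 3},                 # token 2 generated three times so far
    presence_penalty=0.5,
    frequency_penalty=0.5,
    logit_bias={1: -100.0, 3: 100.0},
)
assert biased[2] == -2.0 and biased[1] == -100.0 and biased[3] == 100.0
```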
56 changes: 50 additions & 6 deletions serve/tests/unittest/test_engine_with_samplers.py
@@ -53,12 +53,15 @@ def create_engine(
))
return engine

def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
def create_request(idx, prompt, temp, freq_pen, pre_pen, max_tokens, stop, ignore_eos, logit_bias=None):
return Request(
request_id = str(idx),
messages = [ChatMessage(role="user", content=prompt)],
sampling_params = SamplingParams(
temperature=0.0,
temperature=temp,
frequency_penalty=freq_pen,
presence_penalty=pre_pen,
logit_bias=logit_bias,
),
stopping_criteria = StoppingCriteria(
max_tokens=max_tokens,
@@ -83,7 +86,7 @@ def _test_max_tokens(
max_input_len,
)

requests = [create_request(idx=str(n-1), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=ignore_eos) for n in range(1, num_requests)]
requests = [create_request(idx=str(n-1), prompt=prompt, temp=0, freq_pen=0, pre_pen=0, max_tokens=n, stop=None, ignore_eos=ignore_eos) for n in range(1, num_requests)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -122,7 +125,7 @@ def _test_max_context_length(
)
prompt = "hi " * (max_context_length - 15)

requests = [create_request(idx=str(n), prompt=prompt, temp=0, max_tokens=None, stop=None, ignore_eos=ignore_eos) for n in range(num_requests)]
requests = [create_request(idx=str(n), prompt=prompt, temp=0, freq_pen=0, pre_pen=0, max_tokens=None, stop=None, ignore_eos=ignore_eos) for n in range(num_requests)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -157,7 +160,7 @@ def _test_ignore_eos(
max_input_len,
)
s = 113
requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True) for n in range(s, s+num_requests)]
requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, freq_pen=0, pre_pen=0, max_tokens=n, stop=None, ignore_eos=True) for n in range(s, s+num_requests)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -194,7 +197,7 @@ def _test_stop(
)
requests = []
for n, stop in enumerate(["\n", ["\n"], "\n\n", "!", ["n", "!"]]):
requests.append(create_request(idx=str(n), prompt=prompt, temp=0, max_tokens=300, stop=stop, ignore_eos=False))
requests.append(create_request(idx=str(n), prompt=prompt, temp=0, freq_pen=0, pre_pen=0, max_tokens=300, stop=stop, ignore_eos=False))
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -221,6 +224,44 @@ def _test_stop(
engine.stop()


def test_penalty(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
ignore_eos=False
):
prompt = "Write a merge sort program in Python."
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,
)

random_requests = [create_request(idx=str(n-1), prompt=prompt, temp=0.5, freq_pen=0.5, pre_pen=-0.5, max_tokens=n, stop=None, ignore_eos=ignore_eos, logit_bias={123: -100, 456: 100}) for n in range(1, num_requests)]
greedy_requests = [create_request(idx=str(n-1), prompt=prompt, temp=0, freq_pen=0, pre_pen=0, max_tokens=n, stop=None, ignore_eos=ignore_eos) for n in range(num_requests, num_requests << 1)]
requests = random_requests + greedy_requests
engine.add(requests)

generated = ["" for _ in range(num_requests << 1)]

while engine.has_pending_requests():
results = engine.step()
for res in results.outputs:
assert len(res.sequences) == 1
seq = res.sequences[0]

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
assert seq.finish_reason == FinishReason.Length
else:
generated[int(res.request_id)] += seq.delta

if use_staging_engine:
engine.stop()

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
@@ -238,3 +279,6 @@ def _test_stop(
# if max_tokens = None. The tests do not finish in a reasonable time.
# _test_max_context_length(model_artifact_path, use_staging_engine=True)
# _test_max_context_length(model_artifact_path, use_staging_engine=False)
test_penalty(model_artifact_path, use_staging_engine=True)
test_penalty(model_artifact_path, use_staging_engine=False)