[Param] Recheck and update repetition penalty parameter #202

Merged: 9 commits, Feb 14, 2024
Changes from 5 commits
16 changes: 16 additions & 0 deletions serve/mlc_serve/engine/engine_common.py
@@ -2,6 +2,7 @@
Common utilities for engine classes.
"""

import torch
import time
from typing import Tuple, Deque, Dict, Optional, Union, Callable, List
from collections import deque
@@ -240,6 +241,18 @@ def prepare_output(
return delta, out_logprob_info


def set_mask_prompt_to(state: RequestState):
    # Build a boolean mask over the vocabulary marking which token ids appear
    # in the prompt; it is consumed later by the repetition penalty.
    tokens = torch.tensor(state.prompt_token_ids, dtype=torch.long)
    vocab_size = state.sampling_params.vocab_size
    bin_counts = torch.zeros((vocab_size + 1,),
                             dtype=torch.long,
                             device=tokens.device)
    # bin_counts is 1-D here, so scatter-add along dim 0.
    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:vocab_size]
    state.sampling_params.mask_prompt = bin_counts > 0


def get_requests_to_process(
current_states: list[RequestState],
cache_manager: KVCacheManager,
@@ -264,6 +277,9 @@ def get_requests_to_process(
if is_prompt_batch:
for state in current_states:
if is_evicted_parallel_sampling_request(state):
# TODO(vvchernov): we still need mask if apply_penalty = True
# if state.sampling_params.repetition_penalty != 1.0:
set_mask_prompt_to(state)
requests.append(
PrefillRequest(
request_id=state.request_id,
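As a quick illustration of what set_mask_prompt_to computes, here is a minimal standalone sketch (the helper name and toy token ids are illustrative, not part of the PR): it counts how often each vocab id occurs in the prompt and keeps a boolean mask of ids that occur at least once.

import torch

def build_prompt_mask(prompt_token_ids, vocab_size):
    # Count how often each vocab id occurs in the prompt, then keep a boolean
    # mask of ids that occur at least once.
    tokens = torch.tensor(prompt_token_ids, dtype=torch.long)
    bin_counts = torch.zeros(vocab_size + 1, dtype=torch.long)  # extra slot, as in the diff
    bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
    return bin_counts[:vocab_size] > 0

mask = build_prompt_mask([2, 5, 5, 7], vocab_size=10)
print(mask.nonzero(as_tuple=True)[0].tolist())  # [2, 5, 7]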
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/sampling_params.py
@@ -7,6 +7,7 @@
from enum import IntEnum
from functools import cached_property
from typing import Dict, Optional, Any
import torch

_SAMPLING_EPS = 1e-5
LOGPROB_TOP_K_MAX = 5
@@ -75,6 +76,7 @@ class SamplingParams:
vocab_size = 32000
json_schema: Optional[Dict[str, Any]] = None
logits_processor: Optional[Any] = None
mask_prompt: Optional[torch.Tensor] = None

def __post_init__(self):
if self.logit_bias:
12 changes: 6 additions & 6 deletions serve/mlc_serve/model/model_common.py
@@ -83,7 +83,7 @@ def sample_from_logits(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sequence_ids: List[SequenceId],
requests: Sequence[RequestType],
sampling_metadata: SamplingState,
sampling_state: SamplingState,
vocab_size: int,
copy_stream: torch.cuda.Stream,
torch_dtype: torch.dtype,
@@ -110,13 +110,13 @@
sequence_id, cs_input_ids, logits[i]
)

logits = adjust_logits(logits, sampling_metadata, vocab_size)
logits = adjust_logits(logits, sampling_state, vocab_size)
outputs: List[TextGenerationResult] = []

try:
sampling_output: Optional[SamplingOutput] = sample(
logits,
sampling_metadata,
sampling_state,
)

for i, (new_token, logprob_info) in enumerate(
@@ -142,13 +142,13 @@
for i in range(batch_size):
sequence_id = sequence_ids[i]
logits_per_token = logits[i]
sampling_param = sampling_metadata.sampling_params[i]
sampling_param = sampling_state.sampling_params[i]
past_decode_tokens_per_request = past_decode_tokens[i]
# NOTE: Rerun the preparation for simplicity.
# Assume this code path is taken rarely and the recomputation overhead is
# marginal.
with torch.cuda.stream(copy_stream):
new_sampling_metadata = SamplingState.from_sampling_params(
new_sampling_state = SamplingState.from_sampling_params(
[sampling_param],
[past_decode_tokens_per_request],
torch_dtype,
@@ -158,7 +158,7 @@
torch.cuda.current_stream().wait_stream(copy_stream)
maybe_sampling_output: Optional[SamplingOutput] = sample(
torch.unsqueeze(logits_per_token, 0),
new_sampling_metadata,
new_sampling_state,
check_safety=True,
)

91 changes: 64 additions & 27 deletions serve/mlc_serve/model/sampler.py
@@ -53,6 +53,9 @@ class SamplingTensors:
mask_top_logprob: torch.Tensor
Mask for requests with top_logprob.
shape: (LOGPROB_TOP_K_MAX + 1, batch_size,)
mask_prompt: torch.Tensor
Mask for requests with repetition penalty (prompt part)
shape: (batch_size, vocab_size)
temperatures: torch.Tensor
Tensor for temperature values
shape: (batch_size, )
@@ -85,6 +88,7 @@ class SamplingTensors:
mask_random: torch.Tensor
mask_greedy: torch.Tensor
mask_top_logprob: torch.Tensor
mask_prompt: torch.Tensor
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
@@ -102,6 +106,7 @@ def from_lists(
dev,
list_mask_random: List[bool],
list_mask_top_logprob: List[List[bool]],
list_mask_prompt: List[torch.Tensor],
list_temperatures: List[float],
list_top_ps: List[float],
list_top_ks: List[int],
@@ -124,6 +129,7 @@
)
# `mask_top_logprob` will be on cpu
mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
mask_prompt = torch.stack(list_mask_prompt)
temp = torch.tensor(
list_temperatures,
dtype=dtype,
@@ -185,6 +191,7 @@ def from_lists(
mask_random,
mask_greedy,
mask_top_logprob,
mask_prompt,
temp.to(device=dev, non_blocking=True),
top_ps.to(device=dev, non_blocking=True),
top_ks.to(device=dev, non_blocking=True),
@@ -250,6 +257,7 @@ def from_sampling_params(
vocab_size: int,
):
list_mask_random = []
list_mask_prompt = []
list_temperatures = []
list_top_ps = []
list_top_ks = []
@@ -307,6 +315,7 @@
list_frequency_penalties.append(param.frequency_penalty)
list_presence_penalties.append(param.presence_penalty)
list_repetition_penalties.append(param.repetition_penalty)
list_mask_prompt.append(param.mask_prompt)

if param.logit_bias_index:
assert param.logit_bias_value
@@ -348,6 +357,7 @@ def from_sampling_params(
dev,
list_mask_random,
list_mask_top_logprob,
list_mask_prompt,
list_temperatures,
list_top_ps,
list_top_ks,
@@ -372,20 +382,39 @@ def from_sampling_params(
)


def adjust_logits(logits, sampling_metadata, vocab_size):
def get_bin_counts_and_mask(
Review thread attached to this line:

Member: Same comment as above. Would it make sense to move this to the SamplingState prep and perform it on the CPU? If the perf impact is not bad, it seems better to prepare it there, since we are looking for an option to make the prep process more async.

Member: Valery talked about this in the issue; repasting some of it here:

> In my view the right place is SamplingState, but that is recalculated for every new request. That does not look good, since in general it can be done once, but redoing it properly would take a lot of time. Maybe SamplingParams in Request should be replaced by SamplingState, with some sharing for the parameter n > 1. What do you think about it? For now I plan to save it in SamplingParams as a quick fix, but we need a better place.

@sunggg (Member, Feb 12, 2024): Ah, thanks @binarybana, and this is a great point. @vvchernov, can we follow up on this? We can separate what does and does not need to be updated each iteration, as you raised.

@vvchernov (Author): Oh, thanks @binarybana for sharing my ideas! I repeat them again in other words here.

> Can we follow up on this? We can separate what does and does not need to be updated each iteration, as you raised.

First of all, this looks like a separate task; if we need to support correct repetition penalty with high priority, it is better to do that first. I see the task as a redesign of the API: it should be done carefully, and we need more detailed discussion. I can prepare a separate draft PR with some rough code that implements my idea, and we can discuss there how it should be done. I am also still thinking about logits_processor. Right now it looks like it was added for JSON mode, but in general it does the same thing as the sampler. It would be better to rethink that as part of this redesign.

tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=tokens.device)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0

return bin_counts, mask


def adjust_logits(
logits: torch.Tensor,
sampling_state: SamplingState,
vocab_size: int):
batch_size = logits.shape[0]
(
apply_top_p_top_k,
apply_penalty,
apply_bias,
sampling_tensors,
) = (
sampling_metadata.apply_top_p_top_k,
sampling_metadata.apply_penalty,
sampling_metadata.apply_bias,
sampling_metadata.sampling_tensors,
sampling_state.apply_top_p_top_k,
sampling_state.apply_penalty,
sampling_state.apply_bias,
sampling_state.sampling_tensors,
)
(
prompt_mask,
temp_t,
top_ps_t,
top_ks_t,
@@ -396,6 +425,7 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
logit_bias_indices_t,
logit_bias_values_t,
) = (
sampling_tensors.mask_prompt,
sampling_tensors.temperatures,
sampling_tensors.top_ps,
sampling_tensors.top_ks,
@@ -411,20 +441,27 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
# (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...)
# in the right order.
if apply_penalty:
repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size)
logits = torch.where(
logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t
)
bin_counts = torch.zeros(
(batch_size, vocab_size + 1), dtype=torch.long, device=logits.device
bin_counts, output_mask = get_bin_counts_and_mask(
past_output_tokens_t,
vocab_size,
batch_size,
)
bin_counts.scatter_add_(
1, past_output_tokens_t, torch.ones_like(past_output_tokens_t)
)
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0

# It was checked that the vLLM and HF approaches to repetition penalty are the same;
# a combination of the two is used for the calculation (see the references below).
# The repetition penalty follows the vLLM approach
# https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177
# combined with the RepetitionPenaltyLogitsProcessor approach from the HF TGI API
# https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22
# where "score" means logits
# https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92
repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size)
repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0
logits /= repetition_penalties_t

# Calculate frequency and presence penalties
logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts
logits -= presence_penalties_t.unsqueeze_(dim=1) * mask
logits -= presence_penalties_t.unsqueeze_(dim=1) * output_mask

# Adjust temperature
logits.div_(temp_t.unsqueeze(dim=1))
@@ -447,7 +484,7 @@ class SamplingOutput:

def sample(
logits: torch.Tensor,
sampling_metadata: SamplingState,
sampling_state: SamplingState,
check_safety: bool = False,
) -> SamplingOutput:
def _is_safe_to_sample(prob_like):
Expand All @@ -457,7 +494,7 @@ def _is_safe_to_sample(prob_like):
)

res_greedy, res_random = None, None
sampling_tensors = sampling_metadata.sampling_tensors
sampling_tensors = sampling_state.sampling_tensors

batch_size = logits.shape[0]
mask_greedy_t, mask_random_t = (
Expand All @@ -466,13 +503,13 @@ def _is_safe_to_sample(prob_like):
)

next_tokens = np.empty((batch_size,), dtype=np.int64)
if sampling_metadata.has_greedy:
if sampling_state.has_greedy:
res_greedy = torch.argmax(logits[mask_greedy_t], -1)
np_mask_greedy = mask_greedy_t.cpu().numpy()
next_tokens[np_mask_greedy] = res_greedy.cpu().numpy()

probs_random = None
if sampling_metadata.has_random:
if sampling_state.has_random:
probs_random = torch.softmax(logits[mask_random_t], dim=-1)
if check_safety and not _is_safe_to_sample(probs_random):
return None
@@ -481,9 +518,9 @@ def _is_safe_to_sample(prob_like):
next_tokens[np_mask_random] = res_random.cpu().numpy()

logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size
if sampling_metadata.has_logprob:
if sampling_state.has_logprob:
# If everything is random sampling, save one extra softmax
if not sampling_metadata.has_greedy:
if not sampling_state.has_greedy:
assert probs_random is not None
logprobs = torch.log(probs_random)
else:
@@ -494,13 +531,13 @@ def _is_safe_to_sample(prob_like):
all_top_logprobs, all_top_tokens = torch.topk(
extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
)
mask = sampling_metadata.sampling_tensors.mask_top_logprob
mask = sampling_state.sampling_tensors.mask_top_logprob
top_tokens = all_top_tokens[mask]
top_logprobs = all_top_logprobs[mask]
for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices):
for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices):
next_token = next_tokens[batch_idx]
assert sampling_metadata.sampling_params[batch_idx].logprobs
top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
assert sampling_state.sampling_params[batch_idx].logprobs
top_k = sampling_state.sampling_params[batch_idx].top_logprobs
logprob_infos[batch_idx] = RawLogprobsInfo(
current_token_id=next_token,
current_logprob=logprobs[batch_idx][next_token],
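To make the new penalty path concrete, here is a minimal standalone sketch (not part of the diff; the helper name and toy values are illustrative) of how the rewritten adjust_logits applies the repetition penalty: only vocab ids seen in the prompt or in previously generated output are divided by the penalty, while all other positions are divided by 1.0 and therefore left unchanged.

import torch

def apply_repetition_penalty(logits, prompt_mask, output_mask, penalties):
    # logits: (batch, vocab); masks: (batch, vocab) bool; penalties: (batch,)
    penalties = penalties[:, None].repeat(1, logits.shape[1])
    penalties[~(prompt_mask | output_mask)] = 1.0  # leave unseen tokens untouched
    return logits / penalties

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
prompt_mask = torch.tensor([[True, False, False, False]])
output_mask = torch.tensor([[False, False, True, False]])
print(apply_repetition_penalty(logits, prompt_mask, output_mask, torch.tensor([2.0])))
# tensor([[ 1.0000, -1.0000,  0.2500,  3.0000]])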
8 changes: 4 additions & 4 deletions serve/mlc_serve/model/tvm_model.py
@@ -259,7 +259,7 @@ def generate_multi_query(
# Prepare sampling tensors in another stream to overlap
# CPU<->GPU data transfer with GPU computation in forward pass.
with torch.cuda.stream(self._copy_stream):
sampling_metadata = SamplingState.from_sampling_params(
sampling_state = SamplingState.from_sampling_params(
sampling_params,
past_decode_tokens,
self.torch_dtype,
@@ -318,7 +318,7 @@ def generate_multi_query(
last_query_logits,
sequence_ids,
requests,
sampling_metadata,
sampling_state,
self.vocab_size,
self._copy_stream,
self.torch_dtype,
@@ -381,7 +381,7 @@ def generate(
# Prepare sampling tensors in another stream to overlap
# CPU<->GPU data transfer with GPU computation in forward pass.
with torch.cuda.stream(self._copy_stream):
sampling_metadata = SamplingState.from_sampling_params(
sampling_state = SamplingState.from_sampling_params(
sampling_params,
past_decode_tokens,
self.torch_dtype,
@@ -502,7 +502,7 @@ def generate(
logits,
sequence_ids,
requests,
sampling_metadata,
sampling_state,
self.vocab_size,
self._copy_stream,
self.torch_dtype,
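The tvm_model.py changes above only rename sampling_metadata to sampling_state and keep the existing copy-stream pattern described in the comments: sampling tensors are prepared on a side CUDA stream so host-to-device copies overlap with the forward pass, and the default stream waits on the copy before sampling. A rough sketch of that pattern, assuming a CUDA device (function and argument names are illustrative, not the module's API):

import torch

def prepare_and_sample(build_state, run_forward, sample_fn):
    # Build sampling tensors on a separate stream so their host-to-device
    # copies overlap with the forward pass running on the default stream.
    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        state = build_state()      # e.g. SamplingState.from_sampling_params(...)
    logits = run_forward()         # forward pass on the default stream
    # Make sure the copies have finished before sampling reads the tensors.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return sample_fn(logits, state)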