Revert "Parallel sampling eviction" #189

Merged (1 commit) on Feb 2, 2024
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/base.py
@@ -19,8 +19,8 @@
 class RawLogprobsInfo:
     current_token_id: int
     current_logprob: float
-    top_token_ids: Optional[np.ndarray]
-    top_logprobs: Optional[np.ndarray]
+    top_token_ids: Optional[np.array]
+    top_logprobs: Optional[np.array]

 RawLogprobsInfos = List[Optional[RawLogprobsInfo]]

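Note on the base.py change: RawLogprobsInfo is the per-token record for logprob output, holding the sampled token's id and logprob plus the optional top-k alternatives. The sketch below is illustrative only and is not code from this PR; the make_logprobs_info helper and its top_k parameter are assumptions. It keeps the np.ndarray annotations for clarity (np.ndarray is the array type, while np.array is the constructor function), whereas the revert above restores the older np.array spelling along with the rest of the reverted commit.

from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class RawLogprobsInfo:
    # Same fields as in the diff above.
    current_token_id: int
    current_logprob: float
    top_token_ids: Optional[np.ndarray]
    top_logprobs: Optional[np.ndarray]


RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


def make_logprobs_info(logits: np.ndarray, token_id: int, top_k: int = 5) -> RawLogprobsInfo:
    # Hypothetical helper: build a RawLogprobsInfo from a single-step logits vector.
    shifted = logits - logits.max()                       # subtract max for numerical stability
    logprobs = shifted - np.log(np.sum(np.exp(shifted)))  # log-softmax
    top_ids = np.argsort(-logprobs)[:top_k]               # indices of the top-k logprobs
    return RawLogprobsInfo(
        current_token_id=token_id,
        current_logprob=float(logprobs[token_id]),
        top_token_ids=top_ids,
        top_logprobs=logprobs[top_ids],
    )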
119 changes: 25 additions & 94 deletions serve/mlc_serve/engine/engine_common.py
@@ -22,8 +22,6 @@
 from .model_module import (
     DecodeRequest,
     PrefillRequest,
-    EvalMultiQueryRequest,
-    EvictedTokens,
     ConversationTemplate,
     KVCacheManager,
     ModelModule,
@@ -228,70 +226,26 @@ def update_sequence(

 def get_requests_to_process(
     current_states: list[RequestState], cache_manager: KVCacheManager
-) -> Tuple[
-    list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int
-]:
-    requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = []
+) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]:
+    requests: list[Union[PrefillRequest, DecodeRequest]] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
     # mixing prefill and decode.
     is_prompt_batch = any(not state.is_prefilled for state in current_states)

     token_counts = 0

-    is_evicted_parallel_sampling_request = (
-        lambda state: not state.is_prefilled
-        and state.num_sequences > 1
-        and any(
-            len(gen_seq.generated_token_ids) > 0
-            for gen_seq in state.generation_sequences
-        )
-    )
-
     if is_prompt_batch:
         for state in current_states:
-            if is_evicted_parallel_sampling_request(state):
-                requests.append(
-                    PrefillRequest(
-                        request_id=state.request_id,
-                        token_ids=state.prompt_token_ids,
-                        num_sequence=state.num_sequences,
-                        sampling_params=state.sampling_params,
-                    )
-                )
-
-                token_counts += len(state.prompt_token_ids)
-
-                for gen_seq in state.generation_sequences:
-                    requests.append(
-                        EvalMultiQueryRequest(
-                            sequence_id=gen_seq.seq_id,
-                            num_past_tokens=state.prompt_len,
-                            queries=EvictedTokens(gen_seq.generated_token_ids),
-                            sampling_params=state.sampling_params,
-                        )
-                    )
-                    cache_manager.extend(
-                        gen_seq.seq_id,
-                        len(gen_seq.generated_token_ids) + 1,
-                    )
-
-                # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
-                # Prometheus metric?
-            elif not state.is_prefilled:
-                token_ids = state.prompt_token_ids
-                # generated_token_ids is added for the case where the request is
-                # recovering from cache eviction.
-
-                if (
-                    state.num_sequences == 1
-                    and state.generation_sequences[0].generated_token_ids
-                ):
-                    token_ids += state.generation_sequences[0].generated_token_ids
-
+            if not state.is_prefilled:
                 requests.append(
+                    # generated_token_ids is added for the case where the request is
+                    # recovering from cache eviction.
+                    # TODO(masahi): This needs an update when we support evicting
+                    # a parallel-sampling request.
                     PrefillRequest(
                         request_id=state.request_id,
-                        token_ids=token_ids,
+                        token_ids=state.prompt_token_ids
+                        + state.generation_sequences[0].generated_token_ids,
                         num_sequence=state.num_sequences,
                         sampling_params=state.sampling_params,
                     )
@@ -438,28 +392,16 @@ def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int:
                 candidate_victims = parallel_sample_requests

             request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens)
-            victim_state = self.current_batch[request_to_remove.request_id]
-
-            if victim_state.num_sequences != 1:
-                prev_generated_token_counts = sum(
-                    [
-                        len(gen_seq.generated_token_ids)
-                        for gen_seq in victim_state.generation_sequences
-                    ]
+
+            # TODO(masahi): Properly support evicting a multi-sequence request
+            if self.current_batch[request_to_remove.request_id].num_sequences != 1:
+                cancell_callback(request_to_remove.request_id)
+                self.remove_request_from_batch(request_to_remove.request_id)
+                LOG.warn(
+                    "Preempting a multi-sequence request is currently not supported,"
+                    f" cancelling request '{request_to_remove.request_id}'",
                 )
-                # We could allow evicting and restoring a parallel-sampling request whose prev_generated_token_counts
-                # is > max_num_batched_tokens, by making the model split a list of EvalMultiQuery requests into parts,
-                # so that an inference on each part can be done with the max_num_batched_tokens budget.
-                # But this introduces an undesirable coupling between the engine and the model.
-                if prev_generated_token_counts >= self.max_num_batched_tokens:
-                    cancell_callback(request_to_remove.request_id)
-                    self.remove_request_from_batch(request_to_remove.request_id)
-                    LOG.warn(
-                        f"Cancelling a parallel-sampling request '{request_to_remove.request_id}'"
-                        f"since it has generated more than {self.max_num_batched_tokens} tokens in total"
-                        "and currently we do not support preempting such request.",
-                    )
-                    continue
+                continue

             self.remove_request_from_batch(request_to_remove.request_id)
             request_to_remove.is_prefilled = False
@@ -504,27 +446,14 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
                 gen_seq.next_start_position = (
                     num_new_batched_tokens
                 ) = num_tokens = self.max_num_batched_tokens
-
-            num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
         else:
-            prev_generated_token_counts = sum(
-                [
-                    len(gen_seq.generated_token_ids)
-                    for gen_seq in state.generation_sequences
-                ]
+            # Evicting and recovering multi-sequence requests is not supported for now.
+            assert all(
+                gen_seq.next_start_position == state.prompt_len
+                for gen_seq in state.generation_sequences
             )
-
-            # Restoring an evicted parallel-sampling request with sliding-window attention is
-            # difficult to reason about, so we use crude upper bounds below for now.
             num_tokens = state.prompt_len
-            num_kv_slots_needed = state.prompt_len + prev_generated_token_counts
-            # Restoring an evicted parallel-sampling request is done by separate
-            # Prefill and MultiQuery requests. The maximum below is an upper bound on the
-            # batch size increase due to this request.
-            # TODO(masahi): Prefill and EvalMultiQuery requests are handled separately by the model.
-            # So comparing the sum of their batched token counts against max_num_batched_tokens
-            # is not optimal.
-            num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts)
+            num_new_batched_tokens += num_tokens

         if num_new_batched_tokens > self.max_num_batched_tokens:
             LOG.debug(
@@ -536,6 +465,7 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
         # We make sure that the KV cache will have enough free space for this request to proceed
         # decoding for at least self.max_decode_steps steps.
         # See the comment in check_prompt_too_long for the optimization involving the window size.
+        num_kv_slots_needed = min(num_tokens, self.model_context_window_size)
         if (self.cache_manager.get_free_space() - num_kv_slots_needed) / (
             len(self.current_batch) + 1
         ) < self.max_decode_steps * state.num_sequences:
@@ -547,6 +477,7 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
             return None

         self.queue.popleft()
+        # TODO parallel sampling: Need update here when evicting multi-sequence requests is supported.
         self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
         self.current_batch[state.request_id] = state

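The engine_common.py hunks above are the core of the revert: recovery from cache eviction again goes through a single PrefillRequest whose token_ids are the prompt plus everything the request had generated before eviction, and a parallel-sampling request (num_sequences > 1) is cancelled on eviction instead of being preempted and replayed through EvalMultiQueryRequest. Below is a minimal sketch of that recovery path, assuming simplified stand-ins for the engine's RequestState, GenerationSequence, and PrefillRequest types (sampling_params is omitted for brevity); it is not code from the repository.

from dataclasses import dataclass, field
from typing import List


@dataclass
class GenerationSequence:  # simplified stand-in
    generated_token_ids: List[int] = field(default_factory=list)


@dataclass
class RequestState:  # simplified stand-in
    request_id: str
    prompt_token_ids: List[int]
    num_sequences: int
    generation_sequences: List[GenerationSequence]
    is_prefilled: bool = False


@dataclass
class PrefillRequest:  # simplified stand-in
    request_id: str
    token_ids: List[int]
    num_sequence: int


def recover_after_eviction(state: RequestState) -> PrefillRequest:
    # After the revert, only single-sequence requests are recovered; the engine
    # cancels evicted parallel-sampling requests (see the LOG.warn branch above).
    assert state.num_sequences == 1
    return PrefillRequest(
        request_id=state.request_id,
        # Prompt tokens plus everything generated before eviction are re-prefilled.
        token_ids=state.prompt_token_ids
        + state.generation_sequences[0].generated_token_ids,
        num_sequence=state.num_sequences,
    )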
30 changes: 2 additions & 28 deletions serve/mlc_serve/engine/model_module.py
@@ -35,37 +35,11 @@ class PrefillRequest:
 class DecodeRequest:
     sequence_id: SequenceId
     prompt_token_counts: int
-    # Decoded tokens for this sequence
+    # All tokens for this request, including prompt
     token_ids: List[int]
     sampling_params: SamplingParams


-@dataclass
-class DraftTokens:
-    token_ids: List[int]
-
-    @property
-    def num_tokens(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class EvictedTokens:
-    token_ids: List[int]
-
-    @property
-    def num_tokens(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class EvalMultiQueryRequest:
-    sequence_id: SequenceId
-    num_past_tokens: int
-    queries: Union[DraftTokens, EvictedTokens]
-    sampling_params: SamplingParams
-
-
 @dataclass
 class TextGenerationResult:
     """
@@ -151,7 +125,7 @@ class TextGenerator(Protocol):

     def generate(
         self,
-        requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]],
+        requests: Sequence[Union[PrefillRequest, DecodeRequest]],
         kv_cache,
     ) -> List[TextGenerationResult]:
         """
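With DraftTokens, EvictedTokens, and EvalMultiQueryRequest removed, TextGenerator.generate is back to accepting only PrefillRequest and DecodeRequest. The snippet below is an illustrative helper rather than code from this PR: it assumes the two dataclasses are importable from mlc_serve.engine.model_module as shown in the diff, and that a decode step contributes one new token per sequence, to show how a backend might estimate the batched token count for one step.

from typing import Sequence, Union

from mlc_serve.engine.model_module import DecodeRequest, PrefillRequest


def estimate_batched_tokens(
    requests: Sequence[Union[PrefillRequest, DecodeRequest]],
) -> int:
    # Illustrative only: a prefill processes every token it carries, while a
    # decode request is assumed to add a single new token for its sequence.
    total = 0
    for req in requests:
        if isinstance(req, PrefillRequest):
            total += len(req.token_ids)
        else:
            total += 1
    return total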