Recover from sampling error (#67)
* check for sampling error

* wip

* fix sync engine when there is an errored request

* clean
masahi authored Nov 16, 2023
1 parent 8c16b18 commit f8609cd
Showing 2 changed files with 72 additions and 20 deletions.
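
For context on the failure this commit recovers from: torch.multinomial raises a RuntimeError when the probability tensor it receives contains NaN, inf, or negative entries, which is exactly the condition the new _is_safe_to_sample helper in the diff checks. A minimal standalone sketch of that failure mode and check (illustrative only, not code from this repository):

import torch

def is_safe_to_sample(probs: torch.Tensor) -> bool:
    # Same condition as the commit's _is_safe_to_sample:
    # the tensor must contain no NaN, inf, or negative entries.
    return bool(torch.sum(torch.isnan(probs) | torch.isinf(probs) | (probs < 0)) == 0)

logits = torch.randn(4, 32000)
logits[2, 0] = float("nan")              # simulate corrupted logits for one sequence
probs = torch.softmax(logits, dim=-1)    # the NaN propagates across that row

print(is_safe_to_sample(probs))          # False
try:
    torch.multinomial(probs, 1, True)
except RuntimeError as err:
    print(err)  # "probability tensor contains either `inf`, `nan` or element < 0"
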
12 changes: 8 additions & 4 deletions serve/mlc_serve/engine/sync_engine.py
@@ -141,6 +141,8 @@ def step(self) -> InferenceStepResult:
results = self.text_generator.generate(requests, self.cache_manager.get_cache())
logger.debug("Finished text generation.")

valid_results = []

for res in results:
# For now we only support single sequence per request
request_id = res.sequence_id.request_id
@@ -154,11 +156,15 @@ def step(self) -> InferenceStepResult:
error=res.error,
)
)
continue
else:
valid_results.append(res)

for res in valid_results:
request_id = res.sequence_id.request_id
state = self.current_batch[request_id]
state.next_start_position = len(state.token_ids)
new_token_ids = res.generated_tokens

for i, token_id in enumerate(new_token_ids):
if (
token_id == self.tokenizer.eos_token_id
@@ -167,11 +173,9 @@ def step(self) -> InferenceStepResult:
new_token_ids = new_token_ids[:i]
state.is_ended = True
break

state.token_ids.extend(new_token_ids)

for res in results:
request_id = res.sequence_id.request_id
state = self.current_batch[request_id]
delta = self._decode_last_output(state)
state.output_text += delta

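The sync-engine change above boils down to a two-pass loop: errored results are reported back to their requests and skipped, and only the surviving results go on to extend token state and decode output text. A simplified, self-contained sketch of that split (the Result type and helper names are illustrative, not the engine's actual classes):

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Result:
    request_id: str
    generated_tokens: List[int]
    error: Optional[str] = None

def partition_results(results: List[Result]) -> Tuple[List[Result], List[Result]]:
    # First pass: peel off errored results so the later state-update loop
    # never touches a request that failed to produce tokens this step.
    errored = [r for r in results if r.error is not None]
    valid = [r for r in results if r.error is None]
    return errored, valid

results = [
    Result("req-0", [42, 7]),
    Result("req-1", [], error="probability tensor contains either `inf`, `nan` or element < 0"),
]
errored, valid = partition_results(results)
assert [r.request_id for r in valid] == ["req-0"]
assert [r.request_id for r in errored] == ["req-1"]
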
80 changes: 64 additions & 16 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -14,11 +14,10 @@
from tvm.runtime import disco as di

from mlc_llm import utils
from mlc_llm.relax_model.llama import LlamaConfig

from .base import get_model_artifact_config
from .tokenizer import HfTokenizerModule
from ..engine import ChatMessage, RequestId, SamplingType, MLCServeEngineConfig
from ..engine import RequestId, SamplingType, MLCServeEngineConfig, SamplingParams
from ..engine.model_module import (
DecodeRequest,
PrefillRequest,
@@ -244,7 +243,18 @@ def _apply_top_p_top_k(logits, top_ps, top_ks):
return logits


def sample(logits, sampling_params, vocab_size):
def sample(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sampling_params: List[SamplingParams],
vocab_size: int,
check_safety=False,
) -> Optional[np.ndarray]:
def _is_safe_to_sample(prob_like):
return (
torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
== 0
)

logits = torch.from_dlpack(logits)
num_seq = len(sampling_params)

Expand Down Expand Up @@ -291,6 +301,10 @@ def sample(logits, sampling_params, vocab_size):
logits = _apply_top_p_top_k(logits_random, top_ps, top_ks)

probs = torch.softmax(logits_random, dim=-1)

if check_safety and not _is_safe_to_sample(probs):
return None

res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]

if logits_random.shape[0] == num_seq:
Expand Down Expand Up @@ -597,16 +611,50 @@ def generate(
torch.cuda.synchronize()
torch.cuda.nvtx.range_pop()

next_tokens = sample(logits, sampling_params, self.vocab_size)
try:
next_tokens = sample(logits, sampling_params, self.vocab_size)

return [
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[new_token],
error=None,
)
for sequence_id, new_token in zip(sequence_ids, next_tokens)
]
return [
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[new_token],
error=None,
)
for sequence_id, new_token in zip(sequence_ids, next_tokens)
]
except RuntimeError:
# Fallback to per-token sampling in case some logits values are corrupted.
outputs = []
err_msg = "Error from sampling: probability tensor contains either `inf`, `nan` or element < 0"

for sequence_id, logits_per_token, sampling_param in zip(
sequence_ids, torch.from_dlpack(logits), sampling_params
):
maybe_new_token = sample(
torch.unsqueeze(logits_per_token, 0),
[sampling_param],
self.vocab_size,
check_safety=True,
)

if maybe_new_token is not None:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[maybe_new_token[0]],
error=None,
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=[],
error=err_msg,
)
)

return outputs


def get_gpu_memory(gpu: int = 0) -> int:
@@ -653,12 +701,12 @@ def generate(
class PagedCacheModelModule:
def __init__(
self,
model_artifact_path: str,
engine_config: MLCServeEngineConfig,
model_artifact_path: str,
engine_config: MLCServeEngineConfig,
):
max_num_batched_tokens, max_input_len = engine_config.max_num_batched_tokens, engine_config.max_input_len
model_artifact_config = get_model_artifact_config(model_artifact_path)
model_artifact_config = get_model_artifact_config(model_artifact_path)

dev = tvm.device("cuda", 0)

model = Model(model_artifact_config, dev)
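
The new generate() error path above amounts to: sample the whole batch once, and if multinomial rejects the batch, retry each sequence on its own with check_safety so only the corrupted rows are reported as errors while healthy rows still get a token. A condensed standalone sketch of that fallback (sample_batch and sample_row are hypothetical helpers, not this module's API):

import torch
from typing import List, Optional

def sample_batch(probs: torch.Tensor) -> List[int]:
    # Raises RuntimeError if any row of probs contains NaN, inf, or negative values.
    return torch.multinomial(probs, 1, True)[:, 0].tolist()

def sample_row(row: torch.Tensor) -> Optional[int]:
    # Per-row retry with an explicit safety check instead of an exception.
    if bool(torch.any(torch.isnan(row) | torch.isinf(row) | (row < 0))):
        return None
    return int(torch.multinomial(row, 1, True).item())

def sample_with_fallback(probs: torch.Tensor) -> List[Optional[int]]:
    try:
        return sample_batch(probs)
    except RuntimeError:
        # Only the corrupted rows come back as None; the rest still sample.
        return [sample_row(row) for row in probs]

probs = torch.softmax(torch.randn(3, 8), dim=-1)
probs[1, :] = float("nan")               # corrupt one sequence's distribution
print(sample_with_fallback(probs))       # e.g. [5, None, 2]
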
