Skip to content

Commit

Permalink
Fix sync engine when there is an errored request
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Nov 16, 2023
1 parent f4c2de8 commit f26d409
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions serve/mlc_serve/engine/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def step(self) -> InferenceStepResult:
results = self.text_generator.generate(requests, self.cache_manager.get_cache())
logger.debug("Finished text generation.")

valid_results = []

for res in results:
# For now we only support single sequence per request
request_id = res.sequence_id.request_id
Expand All @@ -154,11 +156,15 @@ def step(self) -> InferenceStepResult:
error=res.error,
)
)
continue
else:
valid_results.append(res)

for res in valid_results:
request_id = res.sequence_id.request_id
state = self.current_batch[request_id]
state.next_start_position = len(state.token_ids)
new_token_ids = res.generated_tokens

for i, token_id in enumerate(new_token_ids):
if (
token_id == self.tokenizer.eos_token_id
Expand All @@ -167,11 +173,9 @@ def step(self) -> InferenceStepResult:
new_token_ids = new_token_ids[:i]
state.is_ended = True
break

state.token_ids.extend(new_token_ids)

for res in results:
request_id = res.sequence_id.request_id
state = self.current_batch[request_id]
delta = self._decode_last_output(state)
state.output_text += delta

Expand Down

0 comments on commit f26d409

Please sign in to comment.