Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLS-272] Fix Special Token Encode Difference #201

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/model_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def generate(
class Tokenizer(Protocol):
_tokenizer: Any
eos_token_id: int
skip_special_tokens: bool
skip_special_tokens: bool # for decoder
add_special_tokens: bool # for encoder
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
all_special_ids: List[int]
is_fast: bool

Expand Down
12 changes: 10 additions & 2 deletions serve/mlc_serve/model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@


class Tokenizer:
def __init__(
    self,
    hf_tokenizer,
    skip_special_tokens=True,  # for decoder: drop special tokens from output text
    add_special_tokens=False,  # for encoder: add BOS/EOS etc. when tokenizing
):
    """Wrap a HuggingFace tokenizer with explicit special-token policy.

    Args:
        hf_tokenizer: the underlying HuggingFace tokenizer instance.
        skip_special_tokens: when True, ``decode`` omits special tokens
            from the produced text.
        add_special_tokens: when True, ``encode`` lets the underlying
            tokenizer insert its special tokens (disabled by default so
            encoding a fragment does not prepend e.g. a BOS token).
    """
    self._tokenizer = hf_tokenizer
    # Cache commonly-read tokenizer attributes on the wrapper.
    self.eos_token_id = self._tokenizer.eos_token_id
    self.add_special_tokens = add_special_tokens
    self.skip_special_tokens = skip_special_tokens
    self.all_special_ids = self._tokenizer.all_special_ids
    self.is_fast = self._tokenizer.is_fast

def encode(self, text: str) -> List[int]:
    """Tokenize *text* into a list of token ids.

    Special tokens (e.g. BOS/EOS) are inserted by the underlying
    tokenizer only when ``self.add_special_tokens`` is True; the
    default is to encode the raw text without them.
    """
    return self._tokenizer.encode(
        text, add_special_tokens=self.add_special_tokens
    )

def decode(self, token_ids: List[int]) -> str:
return self._tokenizer.decode(
Expand Down
Loading