Add regex loading from tokenizer.json and code refinement (#863)
* Add regex loading from tokenizer.json and code refinement

* minor refinement

* fix test failures

* add more pretoken types

* mark some TODO items
wenbingl authored Jan 7, 2025
1 parent 4e10ee0 commit 641930d
Showing 14 changed files with 169 additions and 107 deletions.
2 changes: 1 addition & 1 deletion cgmanifest.json
@@ -62,7 +62,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "9406a60c7839052e4944ea4dbc8344762a89f9bd",
"commitHash": "e39786088138f2749d64e9e90e0f9902daa77c40",
"repositoryUrl": "https://github.com/google/googletest.git"
}
}
4 changes: 2 additions & 2 deletions cmake/externals/googletest.cmake
@@ -1,7 +1,7 @@
FetchContent_Declare(
googletest
URL https://github.com/google/googletest/archive/9406a60c7839052e4944ea4dbc8344762a89f9bd.zip
URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106
URL https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip
URL_HASH SHA1=9d2d0af8d77ac726ea55d44a8fa727ec98311349
EXCLUDE_FROM_ALL
)

4 changes: 3 additions & 1 deletion onnxruntime_extensions/pp_api.py
@@ -49,10 +49,12 @@ def __init__(self, tokenizer_dir):
self.tokenizer = create_tokenizer(tokenizer_dir)

def tokenize(self, text):
if isinstance(text, (list, tuple)):
return batch_tokenize(self.tokenizer, text)
return batch_tokenize(self.tokenizer, [text])[0]

def detokenize(self, tokens):
return batch_detokenize(self.tokenizer, [tokens])[0]
return batch_detokenize(self.tokenizer, [tokens])

def __del__(self):
if delete_object and self.tokenizer:
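For context, a minimal usage sketch of the updated wrapper, assuming the class above is exposed as onnxruntime_extensions.pp_api.Tokenizer and that a local tokenizer directory exists (the path is illustrative):

from onnxruntime_extensions.pp_api import Tokenizer

tok = Tokenizer("./my_tokenizer_dir")  # hypothetical directory containing tokenizer.json

# A single string returns one list of token ids ...
ids = tok.tokenize("hello world")
# ... while a list or tuple of strings is batch-tokenized.
batch_ids = tok.tokenize(["hello world", "goodbye"])

# With this change, detokenize returns the batch result instead of indexing element 0.
text = tok.detokenize(ids)
print(ids, batch_ids, text)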
20 changes: 7 additions & 13 deletions operators/tokenizer/bpe_kernels.cc
@@ -262,7 +262,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,

// Parse input
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
bpe::TokenWithRegularExp regcmp;
bpe::PreTokenizerWithRegEx reg_splitter;

for (auto& seg_id : special_token_split_res) {
if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -274,7 +274,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,

// Note: keep ptr to make sure the string_view is valid in the following process
std::u32string str(seg_id.first);
regcmp.Set(str.c_str());
reg_splitter.Set(str.c_str());

size_t offset = 0;
OffsetMappingType offset_mapping;
@@ -287,14 +287,8 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}

while (static_cast<int64_t>(res.size()) < max_length) {
std::string regex_expr = "";
if (ModelName() == kModel_Llama){
regex_expr = regcmp.LLAMA_REGEX_PATTERN;
} else {
// default to GPT2 regex
regex_expr = regcmp.GPT2_REGEX_PATTERN;
}
auto [b, tok] = regcmp.GetNextToken(regex_expr);
std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
auto [b, tok] = reg_splitter.GetNextToken(regex_expr);

if (!b) break;

@@ -742,9 +736,9 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
bpe_conf_.get().spm_model_);
status = bbpe_tokenizer_->Load(*model_node, tok_json,
bpe_conf_.get().GetSpecialTokens().c_str(),
bpe_conf_.get().spm_model_);
if (status.IsOk()) {
UpdateTokenizer(config, tok_json);
}
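The kernel no longer hard-codes the GPT-2/Llama choice; it asks the BPE model for the pre-tokenizer regex, which may have been loaded from tokenizer.json. A rough Python sketch of that control flow, using the third-party regex package as a stand-in for the C++ matcher (helper and variable names are illustrative, not the real API):

import regex  # third-party 'regex' package; supports \p{L} and \p{N} like the C++ matcher

# Stand-ins for the defaults defined in PreTokenizerWithRegEx.
GPT2_DEFAULT = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
LLAMA_DEFAULT = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

def get_pretokenizer_regex(model_name, regex_from_json=None):
    # Mirrors BpeModel::GetPreTokenizerRegex: a pattern read from
    # tokenizer.json wins; otherwise fall back to a per-model default.
    if regex_from_json:
        return regex_from_json
    return LLAMA_DEFAULT if model_name == "Llama" else GPT2_DEFAULT

def tokenize_segment(segment, model_name, regex_from_json, bpe_fn, max_length):
    # Rough analogue of the loop in KernelBpeTokenizer::Tokenize: take the
    # next pre-token, run BPE on it, stop once the length budget is spent.
    ids = []
    pattern = get_pretokenizer_regex(model_name, regex_from_json)
    for piece in regex.findall(pattern, segment):
        if len(ids) >= max_length:
            break
        ids.extend(bpe_fn(piece))
    return ids

# Toy "BPE" that just maps bytes, enough to exercise the flow.
print(tokenize_segment("It's 2025!", "GPT2", None, lambda s: list(s.encode()), 64))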
66 changes: 64 additions & 2 deletions operators/tokenizer/bpe_tokenizer_model.hpp
@@ -8,6 +8,7 @@
#include "string_utils.h"
#include "string_tensor.h"

#include <set>
#include <list>
#include <unordered_map>
#include <iostream>
@@ -20,13 +21,25 @@
#include "trietree.hpp"
#include "tokenizer_common.h"

#define ORTX_JSON_RETURN_IF_NULL(node_iter, name, var) \
auto var = (node_iter)->find(name); \
if (var == (node_iter)->end() || var->is_null()) { \
return {}; \
}

namespace ort_extensions {

class BpeModel {
using json = nlohmann::json;
const std::array<const char*, 12> kPreTokenizerType = {
"BertPreTokenizer", "ByteLevel", "CharDelimiterSplit", "Digits", "Metaspace",
"PreTokenizer", "Punctuation", "Sequence", "Split", "UnicodeScripts",
"Whitespace", "WhitespaceSplit",
};

public:
BpeModel() = default;
BpeModel()
: pre_tokenizer_types_(kPreTokenizerType.begin(), kPreTokenizerType.end()) {};

static void UpdateSpmByteToken(std::unordered_map<std::string, uint32_t>& vocab_map) {
static const char* hex = "0123456789ABCDEF";
@@ -44,6 +57,37 @@ class BpeModel {
}
}

OrtxStatus LoadPreTokenizer(const json& bpe_model) {
auto root_node = &bpe_model;
ORTX_JSON_RETURN_IF_NULL(root_node, "pre_tokenizer", node_pre_tokenizer);
auto iter_type = node_pre_tokenizer->find("type");
if (iter_type != node_pre_tokenizer->end() && !iter_type->is_null()) {
auto pre_token_type = iter_type->get<std::string>();
if (pre_tokenizer_types_.count(pre_token_type) == 0) {
return {kOrtxErrorNotImplemented, std::string("Unsupported pretokenizer type!") + pre_token_type};
}
}

ORTX_JSON_RETURN_IF_NULL(node_pre_tokenizer, "pretokenizers", iter_node_list);

for (const auto& node : *iter_node_list) {
ORTX_JSON_RETURN_IF_NULL(&node, "type", iter_type);
auto pre_type = iter_type->get<std::string>();
if (pre_type == "Split") {
ORTX_JSON_RETURN_IF_NULL(&node, "pattern", iter_pattern);
ORTX_JSON_RETURN_IF_NULL(iter_pattern, "Regex", regex_str);
pre_tokenizer_regex_ = regex_str->get<std::string>();
} else {
if (pre_tokenizer_types_.count(pre_type) == 0) {
return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
}
; // TODO: implement other pretokenizer types
}
}

return {};
}

OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
const char* special_tokens, bool spm_converted) {
nlohmann::json tok_json;
@@ -120,7 +164,9 @@ class BpeModel {
return {};
}

OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
OrtxStatus Load(const json& bpe_model, const json& tok_json, const char* /* special_tokens */, bool spm_converted) {
ORTX_RETURN_IF_ERROR(LoadPreTokenizer(tok_json));

const json& vocab_json = bpe_model["vocab"];
const json& merges_json = bpe_model["merges"];
vocab_json.get_to(vocab_map_);
@@ -358,6 +404,19 @@ class BpeModel {

const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }

std::string GetPreTokenizerRegex(const std::string& model_name) const {
if (!pre_tokenizer_regex_.empty()) {
return pre_tokenizer_regex_;
}

if (model_name == "Llama") {
return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
}

// by default, use the GPT2 pretokenizer regex
return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
}

private:
struct BpeNode {
uint32_t id;
@@ -379,6 +438,9 @@ class BpeModel {
uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
bpe::SpecialTokenMap special_tokens_;
TrieTree<char32_t> added_tokens_;
std::string pre_tokenizer_regex_;

std::set<std::string_view> pre_tokenizer_types_;
};

} // namespace ort_extensions
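For context, LoadPreTokenizer walks the pre_tokenizer node of a Hugging Face tokenizer.json: it tolerates a missing node, rejects unknown pre-tokenizer types, and picks up a regex from a Split entry. A Python sketch of the same lookup, reduced to the happy path; the sample document is illustrative, not taken from any particular model:

# Minimal illustrative slice of a tokenizer.json "pre_tokenizer" node.
sample = {
    "pre_tokenizer": {
        "type": "Sequence",
        "pretokenizers": [
            {"type": "Split",
             "pattern": {"Regex": r"\p{N}{1,3}| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+"},
             "behavior": "Isolated"},
            {"type": "ByteLevel", "add_prefix_space": False},
        ],
    }
}

KNOWN_TYPES = {"BertPreTokenizer", "ByteLevel", "CharDelimiterSplit", "Digits",
               "Metaspace", "PreTokenizer", "Punctuation", "Sequence", "Split",
               "UnicodeScripts", "Whitespace", "WhitespaceSplit"}

def load_pretokenizer_regex(tok_json):
    # Same traversal as BpeModel::LoadPreTokenizer, simplified.
    node = tok_json.get("pre_tokenizer")
    if not node:
        return None                       # no pre_tokenizer node: nothing to do
    if node.get("type") not in KNOWN_TYPES:
        raise NotImplementedError("Unsupported pretokenizer type: " + str(node.get("type")))
    for sub in node.get("pretokenizers", []):
        if sub.get("type") == "Split":
            return sub.get("pattern", {}).get("Regex")
        if sub.get("type") not in KNOWN_TYPES:
            raise NotImplementedError("Unsupported pretokenizer type: " + sub.get("type"))
    return None

print(load_pretokenizer_regex(sample))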
10 changes: 5 additions & 5 deletions operators/tokenizer/bpe_utils.hpp
@@ -97,8 +97,12 @@ class SpecialTokenMap {
std::unordered_map<ustring, int> token_map_;
};

class TokenWithRegularExp {
class PreTokenizerWithRegEx {
public:
static constexpr const char* GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
static constexpr const char* LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
static constexpr const char* LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

void Set(std::u32string_view val) {
m_text = val;
}
@@ -115,10 +119,6 @@ class TokenWithRegularExp {
return {false, {}};
}

const std::string GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
const std::string LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
const std::string LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

public:

// Although we have RegexMatchGeneral which performs regex matching given any general regex string
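To see how the two built-in patterns differ, a small demo with the third-party regex package; the outputs in the comments are what the patterns are expected to produce (digit handling is the visible difference, since the Llama pattern caps numeric runs at three digits):

import regex

GPT2 = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
LLAMA = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

text = "It's 123456 tokens"
print(regex.findall(GPT2, text))   # ['It', "'s", ' 123456', ' tokens']
print(regex.findall(LLAMA, text))  # ['It', "'s", ' ', '123', '456', ' tokens']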
1 change: 1 addition & 0 deletions operators/tokenizer/tokenizer_jsconfig.hpp
@@ -26,6 +26,7 @@ constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
{"GPT2Tokenizer", TokenType::kBPE},
{"Qwen2Tokenizer", TokenType::kBPE},
{"BaichuanTokenizer", TokenType::kBPE},
{"GPTNeoXTokenizer", TokenType::kBPE},

{"", TokenType::kUnigram},
{"T5Tokenizer", TokenType::kUnigram},
1 change: 1 addition & 0 deletions operators/tokenizer/trietree.hpp
@@ -97,6 +97,7 @@ class TrieTree {
tok_idx += 1;
if (tok_id == invalid_id) {
if (tok_idx < input.length()) {
tok_idx -= tok_len; // backtrack to the last token
continue;
} else {
tok_idx += 1; // Assign tok_idx to input.length()
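The added line rewinds the scan position so that, after a longer candidate fails to match, matching resumes right after the last committed token. A simplified Python analogue of greedy longest-match with that backtrack (this is not the actual TrieTree code, just the idea behind the fix):

def longest_match_split(text, vocab):
    # Greedy longest-match: walk forward while some vocab entry could still
    # extend, remember where the last complete token ended, and on a dead end
    # back up to just past that token before continuing.
    pieces, start = [], 0
    while start < len(text):
        last_end = None
        for end in range(start + 1, len(text) + 1):
            prefix = text[start:end]
            if prefix in vocab:
                last_end = end                            # complete token here
            if not any(v.startswith(prefix) for v in vocab):
                break                                     # cannot extend further
        if last_end is None:
            pieces.append(text[start])                    # no token: emit one char
            start += 1
        else:
            pieces.append(text[start:last_end])           # backtrack to last token
            start = last_end
    return pieces

print(longest_match_split("ababc", {"ab", "abab", "abc"}))   # ['abab', 'c']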
2 changes: 1 addition & 1 deletion pyop/py_c_api.cc
@@ -121,7 +121,7 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
OrtxTokenizer* tokenizer = nullptr;
auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
if (err != kOrtxOK) {
throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
throw std::runtime_error(std::string("Failed to create tokenizer\n") + OrtxGetLastErrorMessage());
}
return reinterpret_cast<std::uintptr_t>(tokenizer);
},
7 changes: 2 additions & 5 deletions pyop/pyfunc.cc
@@ -258,10 +258,9 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
inputs.push_back(InputInformation{input_X, i_dtype, i_dimensions});
}

/* Acquire GIL before calling Python code, due to it was released in sess.run */
py::gil_scoped_acquire acquire;

{
/* Acquire GIL before calling Python C API, due to it was released in sess.run */
py::gil_scoped_acquire acquire;
py::list pyinputs;
for (auto it = inputs.begin(); it != inputs.end(); ++it) {
py::object input0 = PyCustomOpDefImpl::BuildPyArrayFromTensor(
@@ -349,8 +348,6 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
memcpy(out, retval.data(), size * retval.size());
}
}

py::gil_scoped_release release;
}
}

77 changes: 28 additions & 49 deletions shared/api/tokenizer_impl.cc
@@ -11,33 +11,15 @@

namespace ort_extensions {

std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
"PreTrainedTokenizerFast",
"CLIPTokenizer",
"WhisperTokenizer",
"GemmaTokenizer",
"LlamaTokenizer",
"Phi3Tokenizer",
"CodeLlamaTokenizer",
"CodeGenTokenizer",
"GPT2Tokenizer",
"Qwen2Tokenizer",
"BaichuanTokenizer"
};

std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
"XLMRobertaTokenizer",
"T5Tokenizer",
"ChatGLMTokenizer"
};

TokenizerImpl::TokenizerImpl()
: OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
TokenizerImpl::~TokenizerImpl() {};

OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
if (tok_config_->tokenizer_class_.empty() ||
supported_ugm_models_.count(tok_config_->tokenizer_class_)) {

auto type = TokenJsonConfig::GetTokenType(tok_config_->tokenizer_class_);
if (type == TokenType::kUnigram) {
auto tokenizer = std::make_unique<SpmUgmTokenizer>();
auto status = tokenizer->Load(*tok_config_);
if (!status.IsOk()) {
@@ -53,42 +35,39 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
}

return status;
}

if (!supported_bpe_models_.count(tok_config_->tokenizer_class_)) {
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
}

auto tokenizer = std::make_unique<JsonFastTokenizer>();
auto fx_load = &JsonFastTokenizer::Load;
if (blob == nullptr) {
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
// vocab file is checked in TokenJsonConfig::Load
if (vocab_file_path.extension() != ".json") {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
} else if (type == TokenType::kBPE) {
auto tokenizer = std::make_unique<JsonFastTokenizer>();
auto fx_load = &JsonFastTokenizer::Load;
if (blob == nullptr) {
auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
// vocab file is checked in TokenJsonConfig::Load
if (vocab_file_path.extension() != ".json") {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
}
} else {
if (blob->raw_model_blob_len > 0) {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
}
}
} else {
if (blob->raw_model_blob_len > 0) {
fx_load = &JsonFastTokenizer::LoadTikTokenBase64;

auto status = (tokenizer.get()->*fx_load)(*tok_config_);
if (!status.IsOk()) {
return status;
}
}

auto status = (tokenizer.get()->*fx_load)(*tok_config_);
if (!status.IsOk()) {
return status;
}
auto detok = std::make_unique<BpeStreamingDecoder>();
status = detok->Load(tok_config_, *tokenizer);

auto detok = std::make_unique<BpeStreamingDecoder>();
status = detok->Load(tok_config_, *tokenizer);
if (status.IsOk()) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
}

if (status.IsOk()) {
tokenizer_ = std::move(tokenizer);
detokenizer_ = std::move(detok);
return status;
}

return status;
return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
}

OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
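The refactor drops the two hard-coded class-name sets in favor of a single dispatch on TokenJsonConfig::GetTokenType. A rough Python rendering of the new shape; the type table and loader stubs are illustrative:

# Illustrative stand-ins; the real classification lives in TokenJsonConfig.
TOKEN_TYPES = {"GPT2Tokenizer": "BPE", "GPTNeoXTokenizer": "BPE", "T5Tokenizer": "Unigram"}

def load_tokenizer(tokenizer_class):
    # Classify once, then branch on token type instead of consulting
    # per-family name sets, as the rewritten LoadTokenizer now does.
    token_type = TOKEN_TYPES.get(tokenizer_class)
    if token_type == "Unigram":
        return "SpmUgmTokenizer()"          # stub for the Unigram loading path
    if token_type == "BPE":
        return "JsonFastTokenizer()"        # stub for the BPE loading path
    raise NotImplementedError(f"Unsupported tokenizer class: {tokenizer_class}")

print(load_tokenizer("GPTNeoXTokenizer"))   # JsonFastTokenizer()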
3 changes: 0 additions & 3 deletions shared/api/tokenizer_impl.h
@@ -15,9 +15,6 @@ namespace ort_extensions {

class TokenizerImpl : public OrtxObjectImpl {
public:
static std::set<std::string> supported_bpe_models_;
static std::set<std::string> supported_ugm_models_;

TokenizerImpl();
virtual ~TokenizerImpl();
