diff --git a/.github/scripts/run-test.sh b/.github/scripts/run-test.sh index 8e3350fc..b12763fe 100755 --- a/.github/scripts/run-test.sh +++ b/.github/scripts/run-test.sh @@ -52,6 +52,21 @@ for wave in ${waves[@]}; do done done +log "Start testing ${repo_url} with hotwords" + +time $EXE \ + $repo/tokens.txt \ + $repo/encoder_jit_trace-pnnx.ncnn.param \ + $repo/encoder_jit_trace-pnnx.ncnn.bin \ + $repo/decoder_jit_trace-pnnx.ncnn.param \ + $repo/decoder_jit_trace-pnnx.ncnn.bin \ + $repo/joiner_jit_trace-pnnx.ncnn.param \ + $repo/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1.wav \ + 2 \ + modified_beam_search \ + $repo/test_wavs/hotwords.txt + rm -rf $repo log "------------------------------------------------------------" @@ -588,4 +603,4 @@ time $EXE \ modified_beam_search \ $repo/hotwords.txt 1.6 -rm -rf $repo \ No newline at end of file +rm -rf $repo diff --git a/sherpa-ncnn/csrc/CMakeLists.txt b/sherpa-ncnn/csrc/CMakeLists.txt index 41ed79f0..5b28488f 100644 --- a/sherpa-ncnn/csrc/CMakeLists.txt +++ b/sherpa-ncnn/csrc/CMakeLists.txt @@ -77,4 +77,6 @@ endif() if(SHERPA_NCNN_ENABLE_TEST) add_executable(test-resample test-resample.cc) target_link_libraries(test-resample sherpa-ncnn-core) + add_executable(test-context-graph test-context-graph.cc) + target_link_libraries(test-context-graph sherpa-ncnn-core) endif() diff --git a/sherpa-ncnn/csrc/context-graph.cc b/sherpa-ncnn/csrc/context-graph.cc index 78c08bc6..481d51d1 100644 --- a/sherpa-ncnn/csrc/context-graph.cc +++ b/sherpa-ncnn/csrc/context-graph.cc @@ -4,22 +4,57 @@ #include "sherpa-ncnn/csrc/context-graph.h" +#include #include #include +#include +#include #include namespace sherpa_ncnn { -void ContextGraph::Build( - const std::vector> &token_ids) const { +void ContextGraph::Build(const std::vector> &token_ids, + const std::vector &scores, + const std::vector &phrases, + const std::vector &ac_thresholds) const { + if (!scores.empty()) { + assert(token_ids.size() == scores.size()); + } + if (!phrases.empty()) { + assert(token_ids.size() == phrases.size()); + } + if (!ac_thresholds.empty()) { + assert(token_ids.size() == ac_thresholds.size()); + } for (int32_t i = 0; i < token_ids.size(); ++i) { auto node = root_.get(); + float score = scores.empty() ? 0.0f : scores[i]; + score = score == 0.0f ? context_score_ : score; + float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i]; + ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold; + std::string phrase = phrases.empty() ? std::string() : phrases[i]; + for (int32_t j = 0; j < token_ids[i].size(); ++j) { int32_t token = token_ids[i][j]; if (0 == node->next.count(token)) { bool is_end = j == token_ids[i].size() - 1; node->next[token] = std::make_unique( - token, context_score_, node->node_score + context_score_, - is_end ? node->node_score + context_score_ : 0, is_end); + token, score, node->node_score + score, + is_end ? node->node_score + score : 0, j + 1, + is_end ? ac_threshold : 0.0f, is_end, + is_end ? phrase : std::string()); + } else { + float token_score = std::max(score, node->next[token]->token_score); + node->next[token]->token_score = token_score; + float node_score = node->node_score + token_score; + node->next[token]->node_score = node_score; + bool is_end = + (j == token_ids[i].size() - 1) || node->next[token]->is_end; + node->next[token]->output_score = is_end ? node_score : 0.0f; + node->next[token]->is_end = is_end; + if (j == token_ids[i].size() - 1) { + node->next[token]->phrase = phrase; + node->next[token]->ac_threshold = ac_threshold; + } } node = node->next[token].get(); } @@ -27,8 +62,9 @@ void ContextGraph::Build( FillFailOutput(); } -std::pair ContextGraph::ForwardOneStep( - const ContextState *state, int32_t token) const { +std::tuple +ContextGraph::ForwardOneStep(const ContextState *state, int32_t token, + bool strict_mode /*= true*/) const { const ContextState *node; float score; if (1 == state->next.count(token)) { @@ -45,7 +81,22 @@ std::pair ContextGraph::ForwardOneStep( } score = node->node_score - state->node_score; } - return std::make_pair(score + node->output_score, node); + + assert(nullptr != node); + + const ContextState *matched_node = + node->is_end ? node : (node->output != nullptr ? node->output : nullptr); + + if (!strict_mode && node->output_score != 0) { + assert(nullptr != matched_node); + float output_score = + node->is_end ? node->node_score + : (node->output != nullptr ? node->output->node_score + : node->node_score); + return std::make_tuple(score + output_score - node->node_score, root_.get(), + matched_node); + } + return std::make_tuple(score + node->output_score, node, matched_node); } std::pair ContextGraph::Finalize( @@ -54,6 +105,22 @@ std::pair ContextGraph::Finalize( return std::make_pair(score, root_.get()); } +std::pair ContextGraph::IsMatched( + const ContextState *state) const { + bool status = false; + const ContextState *node = nullptr; + if (state->is_end) { + status = true; + node = state; + } else { + if (state->output != nullptr) { + status = true; + node = state->output; + } + } + return std::make_pair(status, node); +} + void ContextGraph::FillFailOutput() const { std::queue node_queue; for (auto &kv : root_->next) { diff --git a/sherpa-ncnn/csrc/context-graph.h b/sherpa-ncnn/csrc/context-graph.h index 0002fa52..326c4fa0 100644 --- a/sherpa-ncnn/csrc/context-graph.h +++ b/sherpa-ncnn/csrc/context-graph.h @@ -6,11 +6,12 @@ #define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ #include +#include +#include #include #include #include - namespace sherpa_ncnn { class ContextGraph; @@ -21,34 +22,55 @@ struct ContextState { float token_score; float node_score; float output_score; + int32_t level; + float ac_threshold; bool is_end; + std::string phrase; std::unordered_map> next; const ContextState *fail = nullptr; const ContextState *output = nullptr; ContextState() = default; ContextState(int32_t token, float token_score, float node_score, - float output_score, bool is_end) + float output_score, int32_t level = 0, float ac_threshold = 0.0f, + bool is_end = false, const std::string &phrase = {}) : token(token), token_score(token_score), node_score(node_score), output_score(output_score), - is_end(is_end) {} + level(level), + ac_threshold(ac_threshold), + is_end(is_end), + phrase(phrase) {} }; class ContextGraph { public: ContextGraph() = default; ContextGraph(const std::vector> &token_ids, - float hotwords_score) - : context_score_(hotwords_score) { - root_ = std::make_unique(-1, 0, 0, 0, false); + float context_score, float ac_threshold, + const std::vector &scores = {}, + const std::vector &phrases = {}, + const std::vector &ac_thresholds = {}) + : context_score_(context_score), ac_threshold_(ac_threshold) { + root_ = std::make_unique(-1, 0, 0, 0); root_->fail = root_.get(); - Build(token_ids); + Build(token_ids, scores, phrases, ac_thresholds); } - std::pair ForwardOneStep( - const ContextState *state, int32_t token_id) const; + ContextGraph(const std::vector> &token_ids, + float context_score, const std::vector &scores = {}, + const std::vector &phrases = {}) + : ContextGraph(token_ids, context_score, 0.0f, scores, phrases, + std::vector()) {} + + std::tuple ForwardOneStep( + const ContextState *state, int32_t token_id, + bool strict_mode = true) const; + + std::pair IsMatched( + const ContextState *state) const; + std::pair Finalize( const ContextState *state) const; @@ -56,8 +78,12 @@ class ContextGraph { private: float context_score_; + float ac_threshold_; std::unique_ptr root_; - void Build(const std::vector> &token_ids) const; + void Build(const std::vector> &token_ids, + const std::vector &scores, + const std::vector &phrases, + const std::vector &ac_thresholds) const; void FillFailOutput() const; }; diff --git a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc index 17c61ff4..898b10df 100644 --- a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc +++ b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc @@ -117,82 +117,7 @@ ncnn::Mat ModifiedBeamSearchDecoder::BuildDecoderInput( void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, DecoderResult *result) { - int32_t context_size = model_->ContextSize(); - Hypotheses cur = std::move(result->hyps); - /* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */ - for (int32_t t = 0; t != encoder_out.h; ++t) { - std::vector prev = cur.GetTopK(num_active_paths_, true); - cur.Clear(); - - ncnn::Mat decoder_input = BuildDecoderInput(prev); - ncnn::Mat decoder_out; - if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size && - !result->decoder_out.empty()) { - // When an endpoint is detected, we keep the decoder_out - decoder_out = result->decoder_out; - } else { - decoder_out = RunDecoder2D(model_, decoder_input); - } - - // decoder_out.w == decoder_dim - // decoder_out.h == num_active_paths - ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t)); - // Note: encoder_out_t.h == 1, we rely on the binary op broadcasting - // in ncnn - // See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting - // broadcast B for outer axis, type 14 - ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out); - - // joiner_out.w == vocab_size - // joiner_out.h == num_active_paths - LogSoftmax(&joiner_out); - - float *p_joiner_out = joiner_out; - - for (int32_t i = 0; i != joiner_out.h; ++i) { - float prev_log_prob = prev[i].log_prob; - for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) { - *p_joiner_out += prev_log_prob; - } - } - - auto topk = TopkIndex(static_cast(joiner_out), - joiner_out.w * joiner_out.h, num_active_paths_); - - int32_t frame_offset = result->frame_offset; - for (auto i : topk) { - int32_t hyp_index = i / joiner_out.w; - int32_t new_token = i % joiner_out.w; - - const float *p = joiner_out.row(hyp_index); - - Hypothesis new_hyp = prev[hyp_index]; - - // blank id is fixed to 0 - if (new_token != 0 && new_token != 2) { - new_hyp.ys.push_back(new_token); - new_hyp.num_trailing_blanks = 0; - new_hyp.timestamps.push_back(t + frame_offset); - } else { - ++new_hyp.num_trailing_blanks; - } - // We have already added prev[hyp_index].log_prob to p[new_token] - new_hyp.log_prob = p[new_token]; - - cur.Add(std::move(new_hyp)); - } - } - - result->hyps = std::move(cur); - result->frame_offset += encoder_out.h; - auto hyp = result->hyps.GetMostProbable(true); - - // set decoder_out in case of endpointing - ncnn::Mat decoder_input = BuildDecoderInput({hyp}); - result->decoder_out = model_->RunDecoder(decoder_input); - - result->tokens = std::move(hyp.ys); - result->num_trailing_blanks = hyp.num_trailing_blanks; + Decode(encoder_out, nullptr, result); } void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, @@ -252,10 +177,10 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, new_hyp.num_trailing_blanks = 0; new_hyp.timestamps.push_back(t + frame_offset); if (s && s->GetContextGraph()) { - auto context_res = - s->GetContextGraph()->ForwardOneStep(context_state, new_token); - context_score = context_res.first; - new_hyp.context_state = context_res.second; + auto context_res = s->GetContextGraph()->ForwardOneStep( + context_state, new_token, false /*strict_mode*/); + context_score = std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); } } else { ++new_hyp.num_trailing_blanks; diff --git a/sherpa-ncnn/csrc/recognizer.cc b/sherpa-ncnn/csrc/recognizer.cc index 12b373e6..7884f9ff 100644 --- a/sherpa-ncnn/csrc/recognizer.cc +++ b/sherpa-ncnn/csrc/recognizer.cc @@ -25,6 +25,7 @@ #include #include +#include "sherpa-ncnn/csrc/context-graph.h" #include "sherpa-ncnn/csrc/decoder.h" #include "sherpa-ncnn/csrc/greedy-search-decoder.h" #include "sherpa-ncnn/csrc/modified-beam-search-decoder.h" @@ -225,7 +226,11 @@ class Recognizer::Impl { } RecognitionResult GetResult(Stream *s) const { + if (IsEndpoint(s)) { + s->Finalize(); + } DecoderResult decoder_result = s->GetResult(); + decoder_->StripLeadingBlanks(&decoder_result); // Those 2 parameters are figured out from sherpa source code @@ -272,23 +277,35 @@ class Recognizer::Impl { std::vector tmp; std::string line; std::string word; - + // The format of each line in hotwords_file looks like: + // ▁HE LL O ▁WORLD :1.5 + // the first several items are tokens of the hotword, the item starts with + // ":" is the customize boosting score for this hotword, if there is no + // customize score it will use the score from configuration (i.e. + // config_.hotwords_score). while (std::getline(is, line)) { std::istringstream iss(line); + float tmp_score = 0.0; // MUST be 0.0, meaning if no customize score use + // the global one. while (iss >> word) { if (sym_.contains(word)) { int32_t number = sym_[word]; tmp.push_back(number); } else { - NCNN_LOGE( - "Cannot find ID for hotword %s at line: %s. (Hint: words on the " - "same line are separated by spaces)", - word.c_str(), line.c_str()); - exit(-1); + if (word[0] == ':') { + tmp_score = std::stof(word.substr(1)); + } else { + NCNN_LOGE( + "Cannot find ID for hotword %s at line: %s. (Hint: words on " + "the " + "same line are separated by spaces)", + word.c_str(), line.c_str()); + exit(-1); + } } } - hotwords_.push_back(std::move(tmp)); + boost_scores_.push_back(tmp_score); } } @@ -299,6 +316,7 @@ class Recognizer::Impl { Endpoint endpoint_; SymbolTable sym_; std::vector> hotwords_; + std::vector boost_scores_; }; Recognizer::Recognizer(const RecognizerConfig &config) diff --git a/sherpa-ncnn/csrc/sherpa-ncnn-alsa.cc b/sherpa-ncnn/csrc/sherpa-ncnn-alsa.cc index c7bff976..fd29f684 100644 --- a/sherpa-ncnn/csrc/sherpa-ncnn-alsa.cc +++ b/sherpa-ncnn/csrc/sherpa-ncnn-alsa.cc @@ -47,7 +47,7 @@ int main(int32_t argc, char *argv[]) { /path/to/joiner.ncnn.param \ /path/to/joiner.ncnn.bin \ device_name \ - [num_threads] [decode_method, can be greedy_search/modified_beam_search] + [num_threads] [decode_method, can be greedy_search/modified_beam_search] [hotwords_file] [hotwords_score] Please refer to https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html @@ -108,6 +108,14 @@ as the device_name. } } + if (argc >= 11) { + config.hotwords_file = argv[10]; + } + + if (argc == 12) { + config.hotwords_score = atof(argv[11]); + } + int32_t expected_sampling_rate = 16000; config.enable_endpoint = true; @@ -148,6 +156,10 @@ as the device_name. } bool is_endpoint = recognizer.IsEndpoint(s.get()); + + if (is_endpoint) { + s->Finalize(); + } auto text = recognizer.GetResult(s.get()).text; if (!text.empty() && last_text != text) { diff --git a/sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc b/sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc index c450e24a..01dc8f38 100644 --- a/sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc +++ b/sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc @@ -60,7 +60,7 @@ int32_t main(int32_t argc, char *argv[]) { /path/to/decoder.ncnn.bin \ /path/to/joiner.ncnn.param \ /path/to/joiner.ncnn.bin \ - [num_threads] [decode_method, can be greedy_search/modified_beam_search] + [num_threads] [decode_method, can be greedy_search/modified_beam_search] [hotwords_file] [hotwords_score] Please refer to https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html @@ -97,6 +97,14 @@ for a list of pre-trained models to download. } } + if (argc >= 11) { + config.hotwords_file = argv[10]; + } + + if (argc == 12) { + config.hotwords_score = atof(argv[11]); + } + config.enable_endpoint = true; config.endpoint_config.rule1.min_trailing_silence = 2.4; @@ -166,6 +174,10 @@ for a list of pre-trained models to download. } bool is_endpoint = recognizer.IsEndpoint(s.get()); + + if (is_endpoint) { + s->Finalize(); + } auto text = recognizer.GetResult(s.get()).text; if (!text.empty() && last_text != text) { diff --git a/sherpa-ncnn/csrc/sherpa-ncnn.cc b/sherpa-ncnn/csrc/sherpa-ncnn.cc index ccc79ece..ca367fdb 100644 --- a/sherpa-ncnn/csrc/sherpa-ncnn.cc +++ b/sherpa-ncnn/csrc/sherpa-ncnn.cc @@ -40,7 +40,7 @@ int32_t main(int32_t argc, char *argv[]) { /path/to/decoder.ncnn.bin \ /path/to/joiner.ncnn.param \ /path/to/joiner.ncnn.bin \ - /path/to/foo.wav [num_threads] [decode_method, can be greedy_search/modified_beam_search] + /path/to/foo.wav [num_threads] [decode_method, can be greedy_search/modified_beam_search] [hotwords_file] [hotwords_score] Please refer to https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html @@ -112,10 +112,10 @@ for a list of pre-trained models to download. static_cast(0.3 * expected_sampling_rate)); stream->AcceptWaveform(expected_sampling_rate, tail_paddings.data(), tail_paddings.size()); - while (recognizer.IsReady(stream.get())) { recognizer.DecodeStream(stream.get()); } + stream->Finalize(); auto result = recognizer.GetResult(stream.get()); std::cout << "Done!\n"; diff --git a/sherpa-ncnn/csrc/stream.cc b/sherpa-ncnn/csrc/stream.cc index 7b7af4d0..9e6fa4d9 100644 --- a/sherpa-ncnn/csrc/stream.cc +++ b/sherpa-ncnn/csrc/stream.cc @@ -18,6 +18,8 @@ #include "sherpa-ncnn/csrc/stream.h" +#include + namespace sherpa_ncnn { class Stream::Impl { @@ -49,6 +51,18 @@ class Stream::Impl { num_processed_frames_ = 0; } + void Finalize() { + if (!context_graph_) return; + auto &cur = result_.hyps; + for (auto iter = cur.begin(); iter != cur.end(); ++iter) { + auto context_res = context_graph_->Finalize(iter->second.context_state); + iter->second.log_prob += context_res.first; + iter->second.context_state = context_res.second; + } + auto hyp = result_.hyps.GetMostProbable(true); + result_.tokens = std::move(hyp.ys); + } + int32_t &GetNumProcessedFrames() { return num_processed_frames_; } void SetResult(const DecoderResult &r) { @@ -99,6 +113,8 @@ ncnn::Mat Stream::GetFrames(int32_t frame_index, int32_t n) const { void Stream::Reset() { impl_->Reset(); } +void Stream::Finalize() { impl_->Finalize(); } + int32_t &Stream::GetNumProcessedFrames() { return impl_->GetNumProcessedFrames(); } diff --git a/sherpa-ncnn/csrc/stream.h b/sherpa-ncnn/csrc/stream.h index 9b3f4248..45f7a71b 100644 --- a/sherpa-ncnn/csrc/stream.h +++ b/sherpa-ncnn/csrc/stream.h @@ -70,6 +70,15 @@ class Stream { void Reset(); + /** + * Finalize the decoding result. This is mainly for decoding with hotwords + * (i.e. providing context_graph). It will cancel the boosting score of the + * partial matching paths. For example, the hotword is "BCD", the path "ABC" + * gets boosting score of "BC" but it fails to match the whole hotword "BCD", + * so we have to cancel the scores of "BC" at the end. + */ + void Finalize(); + // Return a reference to the number of processed frames so far // before subsampling.. // Initially, it is 0. It is always less than NumFramesReady(). diff --git a/sherpa-ncnn/csrc/test-context-graph.cc b/sherpa-ncnn/csrc/test-context-graph.cc new file mode 100644 index 00000000..024e127a --- /dev/null +++ b/sherpa-ncnn/csrc/test-context-graph.cc @@ -0,0 +1,105 @@ +// sherpa-ncnn/csrc/test-context-graph.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include +#include // NOLINT +#include +#include +#include +#include +#include + +#include "sherpa-ncnn/csrc/context-graph.h" + +static void TestHelper(const std::map &queries, float score, + bool strict_mode) { + std::vector contexts_str( + {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); + std::vector> contexts; + std::vector scores; + for (int32_t i = 0; i < contexts_str.size(); ++i) { + contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); + scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100); + } + auto context_graph = sherpa_ncnn::ContextGraph(contexts, 1, scores); + + for (const auto &iter : queries) { + float total_scores = 0; + auto state = context_graph.Root(); + for (auto q : iter.first) { + auto res = context_graph.ForwardOneStep(state, q, strict_mode); + total_scores += std::get<0>(res); + state = std::get<1>(res); + } + auto res = context_graph.Finalize(state); + assert(res.second->token == -1); + total_scores += res.first; + assert(total_scores == iter.second); + } +} + +static void TestBasic() { + auto queries = std::map{ + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; + TestHelper(queries, 0, true); +} + +static void TestBasicNonStrict() { + auto queries = std::map{ + {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3}, + {"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}}; + TestHelper(queries, 0, false); +} + +static void TestCustomize() { + auto queries = std::map{ + {"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18}, + {"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5}, + {"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}}; + TestHelper(queries, 5, true); +} + +static void TestCustomizeNonStrict() { + auto queries = std::map{ + {"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84}, + {"SHED", 10}, {"SHELF", 10}, {"HELL", 5}, + {"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}}; + TestHelper(queries, 5, false); +} + +static void Benchmark() { + std::random_device rd; + std::mt19937 mt(rd()); + std::uniform_int_distribution char_dist(0, 25); + std::uniform_int_distribution len_dist(3, 8); + for (int32_t num = 10; num <= 10000; num *= 10) { + std::vector> contexts; + for (int32_t i = 0; i < num; ++i) { + std::vector tmp; + int32_t word_len = len_dist(mt); + for (int32_t j = 0; j < word_len; ++j) { + tmp.push_back(char_dist(mt)); + } + contexts.push_back(std::move(tmp)); + } + auto start = std::chrono::high_resolution_clock::now(); + auto context_graph = sherpa_ncnn::ContextGraph(contexts, 1); + auto stop = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(stop - start); + fprintf(stderr, "Construct context graph for %d item takes %d us.\n", num, + static_cast(duration.count())); + } +} + +int32_t main() { + TestBasic(); + TestBasicNonStrict(); + TestCustomize(); + TestCustomizeNonStrict(); + Benchmark(); + return 0; +}