diff --git a/3rdparty/tokenizers-cpp b/3rdparty/tokenizers-cpp
index 470bbd49ff..5703f8da64 160000
--- a/3rdparty/tokenizers-cpp
+++ b/3rdparty/tokenizers-cpp
@@ -1 +1 @@
-Subproject commit 470bbd49ffb56bf5ed9a9724cff2caaf2329bd3d
+Subproject commit 5703f8da64201d03e4d8d950ebbc655b46f000aa
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatState.java b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatState.java
index 17e1046f41..c66d0abde3 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatState.java
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatState.java
@@ -91,12 +91,12 @@ void Dummy(String text, Handler handler) {
                 e.printStackTrace();
             }
         }
-        Utils.sendEnd("encode: 100.0 tok/s, decode: 100.0 tok/s", handler);
+        Utils.sendEnd("prefill: 100.0 tok/s, decode: 100.0 tok/s", handler);
     }
 
     void Generate(String prompt, Handler handler) {
         // System.err.println("Start generating");
-        backend.Encode(prompt);
+        backend.Prefill(prompt);
         // System.err.println("Encoding " + prompt);
         while (!backend.Stopped()) {
             backend.Decode();
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/LLMChat.java b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/LLMChat.java
index 13ee57f591..088963617e 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/LLMChat.java
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/LLMChat.java
@@ -8,7 +8,7 @@
 import org.apache.tvm.Module;
 
 public class LLMChat {
-    private Function encode_func_;
+    private Function prefill_func_;
     private Function decode_func_;
     private Function get_message_;
     private Function stopped_func_;
@@ -36,7 +36,7 @@ public void Init() {
         System.err.println("[INFO] Before LLM Chat create");
         llm_chat_ = fcreate.pushArg(lib).pushArg(tokenizer_path).pushArg(param_path).pushArg(Device.opencl().deviceType).pushArg(0).invoke().asModule();
         System.err.println("[INFO] LLM Chat created!");
-        encode_func_ = llm_chat_.getFunction("encode");
+        prefill_func_ = llm_chat_.getFunction("prefill");
         decode_func_ = llm_chat_.getFunction("decode");
 
         get_message_ = llm_chat_.getFunction("get_message");
@@ -45,7 +45,7 @@ public void Init() {
 
         runtime_stats_text_func_ = llm_chat_.getFunction("runtime_stats_text");
 
-        assert encode_func_ != null;
+        assert prefill_func_ != null;
         assert decode_func_ != null;
         assert stopped_func_ != null;
         assert runtime_stats_text_func_ != null;
@@ -71,8 +71,8 @@ public String GetMessage() {
         return get_message_.invoke().asString();
     }
 
-    public void Encode(String prompt) {
-        encode_func_.pushArg(prompt).invoke();
+    public void Prefill(String prompt) {
+        prefill_func_.pushArg(prompt).invoke();
     }
 
     public boolean Stopped() {
diff --git a/build.py b/build.py
index 7b983f526c..1e55706c6e 100644
--- a/build.py
+++ b/build.py
@@ -202,8 +202,8 @@ def mod_transform_before_build(
 ) -> tvm.IRModule:
     """First-stage: Legalize ops and trace"""
     model_names = [
-        "encoding",
-        "decoding",
+        "prefill",
+        "decode",
         "create_kv_cache",
         "softmax_with_temperature",
         "get_metadata",
diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc
index 315b726e45..7bb1d686bf 100644
--- a/cpp/cli_main.cc
+++ b/cpp/cli_main.cc
@@ -171,7 +171,7 @@ struct LLMChatModule {
  public:
   explicit LLMChatModule(const DLDevice& device) {
     this->chat_mod_ = mlc::llm::CreateChatModule(device);
-    this->encode_ = this->chat_mod_->GetFunction("encode");
+    this->prefill_ = this->chat_mod_->GetFunction("prefill");
     this->decode_ = this->chat_mod_->GetFunction("decode");
     this->stopped_ = this->chat_mod_->GetFunction("stopped");
     this->get_message_ = this->chat_mod_->GetFunction("get_message");
@@ -180,7 +180,7 @@ struct LLMChatModule {
     this->get_role1_ = this->chat_mod_->GetFunction("get_role1");
     this->runtime_stats_text_ = this->chat_mod_->GetFunction("runtime_stats_text");
     this->reset_chat_ = this->chat_mod_->GetFunction("reset_chat");
-    ICHECK(encode_ != nullptr);
+    ICHECK(prefill_ != nullptr);
     ICHECK(decode_ != nullptr);
     ICHECK(stopped_ != nullptr);
     ICHECK(get_message_ != nullptr);
@@ -206,7 +206,7 @@ struct LLMChatModule {
   void Reset() { reset_chat_(); }
 
   void Converse(const std::string& input, int stream_interval, std::ostream& os) {
-    this->Encode(input);
+    this->Prefill(input);
 
     std::string cur_msg = "";
     std::vector cur_utf8_chars = CountUTF8(cur_msg);
@@ -241,7 +241,7 @@ struct LLMChatModule {
 
  protected:
   // Low-level APIs
-  void Encode(const std::string& input) { encode_(input); }
+  void Prefill(const std::string& input) { prefill_(input); }
 
   void Decode() { decode_(); }
 
@@ -251,7 +251,7 @@ struct LLMChatModule {
 
   // TVM Modules and functions with TVM's calling convention
   tvm::runtime::Module chat_mod_;
-  tvm::runtime::PackedFunc encode_;
+  tvm::runtime::PackedFunc prefill_;
   tvm::runtime::PackedFunc decode_;
   tvm::runtime::PackedFunc stopped_;
   tvm::runtime::PackedFunc get_message_;
@@ -264,10 +264,14 @@
 
 std::optional TryInferMLCChatConfig(const std::string& artifact_path,
                                     const std::string& local_id) {
-  return FindFile({artifact_path + "/prebuilt/" + local_id,      //
-                   artifact_path + "/" + local_id + "/params"},  //
-                  {"mlc-chat-config"},                           //
-                  {".json"});
+  return FindFile(
+      {
+          //
+          artifact_path + "/" + local_id + "/params",  //
+          artifact_path + "/prebuilt/" + local_id,     //
+      },                                               //
+      {"mlc-chat-config"},                             //
+      {".json"});
 }
 
 std::string ReadStringFromJSONFile(const std::filesystem::path& config_path,
@@ -317,10 +321,9 @@ ModelPaths ModelPaths::Find(const std::string& artifact_path, const std::string&
   std::filesystem::path lib_path;
   if (auto path = FindFile(
           {
-              artifact_path + "/prebuilt/lib/",             // prebuild lib
-              artifact_path + "/prebuilt/" + lib_local_id,  // For prebuilts
-              artifact_path + "/" + lib_local_id,           // Usually this is the candidate
-              artifact_path + "/" + lib_local_id + "/lib/",
+              artifact_path + "/" + lib_local_id,           // Usually this is the candidate
+              artifact_path + "/prebuilt/lib/",             // prebuild lib
+              artifact_path + "/prebuilt/" + lib_local_id   // For prebuilts
           },
           {
               lib_name,
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index d7ac7d5cb3..6f7eb88acc 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -469,7 +469,7 @@ class LLMChat {
   std::string RuntimeStatsText() {
     std::ostringstream os;
     os << "prefill: " << std::setprecision(1) << std::fixed
-       << this->encode_total_tokens / this->encode_total_time << " tok/s"
+       << this->prefill_total_tokens / this->prefill_total_time << " tok/s"
        << ", decode: " << std::setprecision(1) << std::fixed
        << this->decode_total_tokens / this->decode_total_time << " tok/s";
     // os << ", sample-cost: " << std::setprecision(1) << std::fixed
@@ -492,8 +492,8 @@
         static_cast(kDLCPU), 0,
         static_cast(relax_vm::AllocatorType::kPooled));
-    encoding_func_ = vm_->GetFunction("encoding");
-    decoding_func_ = vm_->GetFunction("decoding");
+    encoding_func_ = vm_->GetFunction("prefill");
+    decoding_func_ = vm_->GetFunction("decode");
     encoding_without_cache_func_ = vm_->GetFunction("encoding_without_cache");
     softmax_func_ = vm_->GetFunction("softmax_with_temperature");
     get_metadata_func_ = vm_->GetFunction("get_metadata");
@@ -629,9 +629,9 @@ class LLMChat {
 
   /*! \brief reset the runtime stats. */
   void ResetRuntimeStats() {
-    this->encode_total_tokens = 0;
+    this->prefill_total_tokens = 0;
     this->decode_total_tokens = 0;
-    this->encode_total_time = 0;
+    this->prefill_total_time = 0;
     this->decode_total_time = 0;
     this->sample_total_time = 0;
   }
@@ -725,8 +725,8 @@ class LLMChat {
   /*!
    * \brief Generate the next token given a prompt.
    */
-  void EncodeStep(std::string inp) {
-    if (reset_stats_per_encode_) {
+  void PrefillStep(std::string inp) {
+    if (reset_stats_per_prefill_) {
       this->ResetRuntimeStats();
     }
     output_ids_.clear();
@@ -755,8 +755,8 @@ class LLMChat {
     TVMSynchronize(device_.device_type, device_.device_id, nullptr);
     auto tend = std::chrono::high_resolution_clock::now();
 
-    this->encode_total_time += static_cast((tend - tstart).count()) / 1e9;
-    this->encode_total_tokens += token_len;
+    this->prefill_total_time += static_cast((tend - tstart).count()) / 1e9;
+    this->prefill_total_tokens += token_len;
     if (temperature_ < 1e-6f) {
       next_token_ = this->SampleFromLogitsOnCPU();
     } else {
@@ -1024,12 +1024,12 @@
   //----------------------------
   // Statistics
   //----------------------------
-  bool reset_stats_per_encode_ = true;
+  bool reset_stats_per_prefill_ = true;
   double decode_total_time = 0;
   double sample_total_time = 0;
-  double encode_total_time = 0;
+  double prefill_total_time = 0;
   int64_t decode_total_tokens = 0;
-  int64_t encode_total_tokens = 0;
+  int64_t prefill_total_tokens = 0;
   //----------------------------
   // Conversation
   //----------------------------
@@ -1141,10 +1141,10 @@ class LLMChatModule : public ModuleNode {
     } else if (name == "try_tokenizer") {
       return PackedFunc(
           [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->TryTokenizer(); });
-    } else if (name == "encode") {
+    } else if (name == "prefill") {
       return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
         ICHECK_EQ(args.size(), 1);
-        GetChat()->EncodeStep(args[0]);
+        GetChat()->PrefillStep(args[0]);
       });
     } else if (name == "decode") {
       return PackedFunc(
@@ -1219,8 +1219,8 @@
         static_cast(relax_vm::AllocatorType::kPooled), static_cast(kDLCPU), 0,
         static_cast(relax_vm::AllocatorType::kPooled));
-    chat_->encoding_func_ = chat_->vm_->GetFunction("encoding");
-    chat_->decoding_func_ = chat_->vm_->GetFunction("decoding");
+    chat_->encoding_func_ = chat_->vm_->GetFunction("prefill");
+    chat_->decoding_func_ = chat_->vm_->GetFunction("decode");
     chat_->encoding_without_cache_func_ = chat_->vm_->GetFunction("encoding_without_cache");
     chat_->softmax_func_ = chat_->vm_->GetFunction("softmax_with_temperature");
     chat_->get_metadata_func_ = chat_->vm_->GetFunction("get_metadata");
diff --git a/ios/MLCChat/ChatState.swift b/ios/MLCChat/ChatState.swift
index cefed3c721..1c4d81d894 100644
--- a/ios/MLCChat/ChatState.swift
+++ b/ios/MLCChat/ChatState.swift
@@ -126,7 +126,7 @@ class ChatState : ObservableObject {
         threadWorker.push {[self] in
             self.appendMessage(role: MessageRole.user, message: prompt)
 
-            backend.encode(prompt);
+            backend.prefill(prompt);
             while (!backend.stopped()) {
                 assert(self.inProgress);
                 backend.decode();
diff --git a/ios/MLCChat/LLMChat.mm b/ios/MLCChat/LLMChat.mm
index d3dbeaf3cc..1862a77800 100644
--- a/ios/MLCChat/LLMChat.mm
+++ b/ios/MLCChat/LLMChat.mm
@@ -27,7 +27,7 @@
 
     reload_func_ = llm_chat_->GetFunction("reload");
     unload_func_ = llm_chat_->GetFunction("unload");
-    encode_func_ = llm_chat_->GetFunction("encode");
+    prefill_func_ = llm_chat_->GetFunction("prefill");
    decode_func_ = llm_chat_->GetFunction("decode");
     get_message_ = llm_chat_->GetFunction("get_message");
     stopped_func_ = llm_chat_->GetFunction("stopped");
@@ -36,7 +36,7 @@
 
    ICHECK(reload_func_ != nullptr);
    ICHECK(unload_func_ != nullptr);
-    ICHECK(encode_func_ != nullptr);
+    ICHECK(prefill_func_ != nullptr);
    ICHECK(decode_func_ != nullptr);
    ICHECK(get_message_ != nullptr);
    ICHECK(stopped_func_ != nullptr);
@@ -64,9 +64,9 @@ void Evaluate() {
    return get_message_();
  }
 
-  void Encode(std::string prompt) {
-    ICHECK(encode_func_ != nullptr);
-    encode_func_(prompt);
+  void Prefill(std::string prompt) {
+    ICHECK(prefill_func_ != nullptr);
+    prefill_func_(prompt);
  }
 
  bool Stopped() { return stopped_func_(); }
@@ -86,7 +86,7 @@ void Encode(std::string prompt) {
  Module llm_chat_;
  PackedFunc unload_func_;
  PackedFunc reload_func_;
-  PackedFunc encode_func_;
+  PackedFunc prefill_func_;
  PackedFunc decode_func_;
  PackedFunc get_message_;
  PackedFunc stopped_func_;
@@ -112,8 +112,8 @@ - (void)evaluate {
  LLMChatModuleWrapper::Global()->Evaluate();
 }
 
-- (void)encode:(NSString*)prompt {
-  LLMChatModuleWrapper::Global()->Encode(prompt.UTF8String);
+- (void)prefill:(NSString*)prompt {
+  LLMChatModuleWrapper::Global()->Prefill(prompt.UTF8String);
 }
 
 - (void)decode {
diff --git a/ios/MLCChat/MLCChat-Bridging-Header.h b/ios/MLCChat/MLCChat-Bridging-Header.h
index 2c929d74e5..1aa19c829d 100644
--- a/ios/MLCChat/MLCChat-Bridging-Header.h
+++ b/ios/MLCChat/MLCChat-Bridging-Header.h
@@ -10,7 +10,7 @@
 - (void)evaluate;
 - (void)unload;
 - (void)reload:(NSString*)model_lib modelPath:(NSString*)modelPath;
-- (void)encode:(NSString*)prompt;
+- (void)prefill:(NSString*)prompt;
 - (void)decode;
 - (void)reset;
 - (NSString*)getMessage;
diff --git a/mlc_llm/relax_model/gpt_neox.py b/mlc_llm/relax_model/gpt_neox.py
index b2b9a2306c..15d1089cce 100644
--- a/mlc_llm/relax_model/gpt_neox.py
+++ b/mlc_llm/relax_model/gpt_neox.py
@@ -469,7 +469,7 @@ def create_encoding_func(
     batch_size = tvm.tir.IntImm("int64", 1)
     seq_len = tvm.tir.Var("n", "int64")
     all_seq_len = tvm.tir.Var("m", "int64")
-    with bb.function("encoding"):
+    with bb.function("prefill"):
         model = GPTNeoXForCausalLM(config)
         input_ids = nn.Placeholder(
             (batch_size, seq_len), dtype="int32", name="input_ids"
@@ -500,7 +500,7 @@ def create_encoding_func(
         gv = bb.emit_output((logits, relax.Tuple(key_value_cache)))
         bb.emit_func_output(gv, params)
     mod = bb.get()
-    gv = mod.get_global_var("encoding")
+    gv = mod.get_global_var("prefill")
     bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
@@ -512,7 +512,7 @@ def create_decoding_func(
     batch_size = tvm.tir.IntImm("int64", 1)
     seq_len = tvm.tir.IntImm("int64", 1)
     all_seq_len = tvm.tir.Var("n", "int64")
-    with bb.function("decoding"):
+    with bb.function("decode"):
         model = GPTNeoXForCausalLM(config)
         input_ids = nn.Placeholder(
             (batch_size, seq_len), dtype="int32", name="input_ids"
@@ -544,7 +544,7 @@ def create_decoding_func(
         gv = bb.emit_output((logits, relax.Tuple(key_value_cache)))
         bb.emit_func_output(gv, params)
     mod = bb.get()
-    gv = mod.get_global_var("decoding")
+    gv = mod.get_global_var("decode")
     bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 047bfb2b22..a350781e27 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -559,7 +559,7 @@ def create_encoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
     bsz = 1
     seq_len = tvm.tir.Var("n", "int64")
     all_seq_len = tvm.tir.Var("m", "int64")
-    with bb.function("encoding"):
+    with bb.function("prefill"):
         model = LlamaForCausalLM(config)
         input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
         all_seq_len_shape = relax.Var(
@@ -584,7 +584,7 @@ def create_encoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
         bb.emit_func_output(gv, params)
 
     mod = bb.get()
-    gv = mod.get_global_var("encoding")
+    gv = mod.get_global_var("prefill")
     bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
@@ -592,7 +592,7 @@
 def create_decoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
     bsz = 1
     all_seq_len = tvm.tir.Var("n", "int64")
-    with bb.function("decoding"):
+    with bb.function("decode"):
         model = LlamaForCausalLM(config)
         input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids")
         all_seq_len_shape = relax.Var(
@@ -617,7 +617,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
         bb.emit_func_output(gv, params)
 
     mod = bb.get()
-    gv = mod.get_global_var("decoding")
+    gv = mod.get_global_var("decode")
     bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
diff --git a/mlc_llm/relax_model/moss.py b/mlc_llm/relax_model/moss.py
index 7b16556a26..f0c3f10445 100644
--- a/mlc_llm/relax_model/moss.py
+++ b/mlc_llm/relax_model/moss.py
@@ -500,7 +500,7 @@ def create_encoding_func(
     batch_size = tvm.tir.IntImm("int64", 1)
     seq_len = tvm.tir.Var("n", "int64")
     all_seq_len = tvm.tir.Var("m", "int64")
-    with bb.function("encoding"):
+    with bb.function("prefill"):
         model = MossForCausalLM(config)
         input_ids = nn.Placeholder(
             (batch_size, seq_len), dtype="int32", name="input_ids"
@@ -532,7 +532,7 @@ def create_encoding_func(
         gv = bb.emit_output((logits, relax.Tuple(key_value_cache)))
         bb.emit_func_output(gv, params)
     mod = bb.get()
-    gv = mod.get_global_var("encoding")
+    gv = mod.get_global_var("prefill")
     bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
@@ -544,7 +544,7 @@ def create_decoding_func(
     batch_size = tvm.tir.IntImm("int64", 1)
     seq_len = tvm.tir.IntImm("int64", 1)
     all_seq_len = tvm.tir.Var("n", "int64")
-    with bb.function("decoding"):
+    with bb.function("decode"):
         model = MossForCausalLM(config)
         input_ids = nn.Placeholder(
             (batch_size, seq_len), dtype="int32", name="input_ids"
@@ -577,7 +577,7 @@ def create_decoding_func(
         gv = bb.emit_output((logits, relax.Tuple(key_value_cache)))
         bb.emit_func_output(gv, params)
     mod = bb.get()
-    gv = mod.get_global_var("decoding")
+    gv = mod.get_global_var("decode")
     bb.update_func(gv, mod[gv].with_attr("num_input", 3))
 
 
diff --git a/tests/chat.py b/tests/chat.py
index 78f520d282..9891533825 100644
--- a/tests/chat.py
+++ b/tests/chat.py
@@ -194,11 +194,11 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         self.tot_seq_len += inputs.shape[1]
         seq_len_shape = tvm.runtime.ShapeTuple([self.tot_seq_len])
         if inputs.shape[1] > 1:
-            logits, kv_cache = vm["encoding"](
+            logits, kv_cache = vm["prefill"](
                 inputs, seq_len_shape, self.kv_cache, const_params
             )
         else:
-            logits, kv_cache = vm["decoding"](
+            logits, kv_cache = vm["decode"](
                 inputs, seq_len_shape, self.kv_cache, const_params
             )
         self.kv_cache = kv_cache
diff --git a/tests/debug/compare_lib.py b/tests/debug/compare_lib.py
index 6c976d2bb8..e796059750 100644
--- a/tests/debug/compare_lib.py
+++ b/tests/debug/compare_lib.py
@@ -153,7 +153,7 @@ def deploy_to_pipeline(args) -> None:
 
     print("Running inference...")
     print("======================= Starts Encoding =======================")
-    logits, kv_caches = state.vm["encoding"](
+    logits, kv_caches = state.vm["prefill"](
         inputs, seq_len_shape, kv_caches, const_params
     )
     print_as_table(
@@ -165,7 +165,7 @@ def deploy_to_pipeline(args) -> None:
     state.cmp_instrument.time_eval_results.clear()
     state.cmp_instrument.visited.clear()
     print("======================= Starts Decoding =======================")
-    logits, kv_caches = state.vm["decoding"](
+    logits, kv_caches = state.vm["decode"](
         first_sampled_token, second_seq_len_shape, kv_caches, const_params
     )
     print_as_table(
diff --git a/tests/evaluate.py b/tests/evaluate.py
index 3f066cb22b..4f02e510d8 100644
--- a/tests/evaluate.py
+++ b/tests/evaluate.py
@@ -106,8 +106,8 @@ def deploy_to_pipeline(args) -> None:
 
     kv_caches = vm["create_kv_cache"]()
     # skip warm up
-    logits, kv_caches = vm["encoding"](inputs, seq_len_shape, kv_caches, const_params)
-    logits, kv_caches = vm["decoding"](
+    logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
+    logits, kv_caches = vm["decode"](
         first_sampled_token, second_seq_len_shape, kv_caches, const_params
     )
     device.sync()
@@ -115,10 +115,10 @@ def deploy_to_pipeline(args) -> None:
     kv_caches = vm["create_kv_cache"]()
     print("Running inference...")
     start = time.time()
-    logits, kv_caches = vm["encoding"](inputs, seq_len_shape, kv_caches, const_params)
+    logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
     device.sync()
     encoding_end = time.time()
-    logits, kv_caches = vm["decoding"](
+    logits, kv_caches = vm["decode"](
         first_sampled_token, second_seq_len_shape, kv_caches, const_params
     )
     device.sync()
@@ -139,7 +139,7 @@ def deploy_to_pipeline(args) -> None:
     print("Profiling...")
     kv_caches = vm["create_kv_cache"]()
 
-    logits, kv_caches = vm["encoding"](
+    logits, kv_caches = vm["prefill"](
        inputs, seq_len_shape, kv_caches, const_params
     )
     print("======================= Encoding Profiling =======================")
@@ -151,7 +151,7 @@ def deploy_to_pipeline(args) -> None:
     )
     cmp_instrument.time_eval_results.clear()
 
-    logits, kv_caches = vm["decoding"](
+    logits, kv_caches = vm["decode"](
        first_sampled_token, second_seq_len_shape, kv_caches, const_params
     )
     print("======================= Decoding Profiling =======================")