diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000000..478f75e8fd --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,87 @@ +name: Lint +on: [push, pull_request] +env: + IMAGE: 'mlcaidev/ci-cpu:caab922' + +jobs: + isort: + name: Python / isort + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: '' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/isort.sh + + black: + name: Python / black + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: '' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/black.sh + + mypy: + name: Python / mypy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: '' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh + + pylint: + name: Python / pylint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: '' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh + + clang-format: + name: C++ / clang-format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: '' + ref: ${{ github.event.pull_request.head.sha }} + fetch-depth: 0 + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/clang-format.sh diff --git a/3rdparty/tvm b/3rdparty/tvm index e5ca38dd73..30b4fa3c13 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit e5ca38dd735ba4d30782a4a58bf6195861642eb0 +Subproject commit 30b4fa3c13fc80d5c9151a9dc445d22c57ced3e0 diff --git a/README.md b/README.md index bb52c1c735..f20d1c8a93 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,18 @@ [Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url] -Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. +**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. 
The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. **Universal deployment.** MLC LLM supports the following platforms and hardware: - - - - - + + + + + @@ -28,21 +28,18 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo - + - + - - - - + - + @@ -52,8 +49,25 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo
AMD GPUNVIDIA GPUApple M1/M2 GPUIntel GPU AMD GPUNVIDIA GPUApple GPUIntel GPU
macOS✅ Metal✅ Metal (dGPU) N/A ✅ Metal✅ Metal✅ Metal (iGPU)
Web Browser✅ WebGPU✅ WebGPU✅ WebGPU✅ WebGPU✅ WebGPU and WASM
iOS / iPadOS✅ Metal on Apple M1/M2 GPU✅ Metal on Apple A-series GPU
Android
+ +**Scalable.** MLC LLM scales universally on NVIDIA and AMD GPUs, across both cloud and gaming GPUs. The figures below +showcase our single-batch decoding performance with prefilling = 1 and decoding = 256. + +Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX: +


+ +Scaling of fp16 and 4-bit CodeLlama-34B and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs: +


+ ## News +* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm is official. +* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package). * [08/25/2023] CodeLlama support is up. * [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. * [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. @@ -66,7 +80,55 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo ## Getting Started -Please visit our [this page](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. +Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. + +## Model Support + +MLC LLM supports a wide range of model architectures and variants. We provide the following prebuilt models, which you can +use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list. +
| Architecture | Prebuilt Model Variants |
| ------------ | ----------------------- |
| Llama | Llama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored |
| GPT-NeoX | RedPajama |
| GPT-J | |
| RWKV | RWKV-raven |
| MiniGPT | |
| GPTBigCode | WizardCoder |
| ChatGLM | |
| StableLM | |
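To make the table above concrete, here is a minimal, hypothetical sketch of loading one of these prebuilt models with the chat CLI; the `--model-lib-path` flag is the optional argument introduced later in this patch, and the library file name shown is illustrative rather than a guaranteed artifact name.

```bash
# Hypothetical example: chat with a prebuilt Llama-2 variant by its local id.
# The CLI searches dist/prebuilt and dist/<local_id> for the weights and model library.
mlc_chat_cli --model Llama-2-7b-chat-hf-q4f16_1

# Optionally bypass the search by pointing at a model library directly
# (the path below is an example only).
mlc_chat_cli --model Llama-2-7b-chat-hf-q4f16_1 \
             --model-lib-path dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-vulkan.so
```

Either form reaches the same code path in `cli_main.cc`; the explicit flag simply skips the directory search that the CLI otherwise performs.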
## Universal Deployment APIs diff --git a/android/MLCChat/app/src/main/assets/app-config.json b/android/MLCChat/app/src/main/assets/app-config.json index ddbcb793ae..fb7c4546b3 100644 --- a/android/MLCChat/app/src/main/assets/app-config.json +++ b/android/MLCChat/app/src/main/assets/app-config.json @@ -1,9 +1,14 @@ { "model_libs": [ + "Llama-2-7b-chat-hf-q4f16_0", "Llama-2-7b-chat-hf-q4f16_1", "RedPajama-INCITE-Chat-3B-v1-q4f16_1" ], "model_list": [ + { + "model_url": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_0/", + "local_id": "Llama-2-7b-chat-hf-q4f16_0" + }, { "model_url": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/", "local_id": "Llama-2-7b-chat-hf-q4f16_1" diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt index e1b5928019..f51d56ec10 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt @@ -2,6 +2,9 @@ package ai.mlc.mlcchat import ai.mlc.mlcllm.ChatModule import android.app.Application +import android.content.ClipData +import android.content.ClipboardManager +import android.content.Context import android.os.Environment import android.widget.Toast import androidx.compose.runtime.mutableStateOf @@ -23,6 +26,8 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { val modelList = emptyList().toMutableStateList() val chatState = ChatState() val modelSampleList = emptyList().toMutableStateList() + private var showAlert = mutableStateOf(false) + private var alertMessage = mutableStateOf("") private var appConfig = AppConfig( emptyList(), emptyList().toMutableList(), @@ -44,13 +49,38 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { loadAppConfig() } + fun supportedModelLibs(): List { + return appConfig.modelLibs + } + + fun isShowingAlert(): Boolean { + return showAlert.value + } + + fun errorMessage(): String { + return alertMessage.value + } + + fun dismissAlert() { + require(showAlert.value) + showAlert.value = false + } + + fun copyError() { + require(showAlert.value) + val clipboard = + application.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager + clipboard.setPrimaryClip(ClipData.newPlainText("MLCChat", errorMessage())) + } + + private fun issueAlert(error: String) { + showAlert.value = true + alertMessage.value = error + } + fun requestAddModel(url: String, localId: String?) 
{ if (localId != null && localIdSet.contains(localId)) { - Toast.makeText( - application, - "localId: $localId has been occupied", - Toast.LENGTH_SHORT - ).show() + issueAlert("localId: $localId has been occupied") } else { downloadModelConfig(if (url.endsWith("/")) url else "$url/", localId, false) } @@ -58,11 +88,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { fun requestDeleteModel(localId: String) { deleteModel(localId) - Toast.makeText( - application, - "Model: $localId has been deleted", - Toast.LENGTH_SHORT - ).show() + issueAlert("Model: $localId has been deleted") } @@ -133,11 +159,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { private fun isModelConfigAllowed(modelConfig: ModelConfig): Boolean { if (appConfig.modelLibs.contains(modelConfig.modelLib)) return true; viewModelScope.launch { - Toast.makeText( - application, - "Model lib ${modelConfig.modelLib} is not supported.", - Toast.LENGTH_SHORT - ).show() + issueAlert("Model lib ${modelConfig.modelLib} is not supported.") } return false } @@ -169,11 +191,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { } if (localIdSet.contains(modelConfig.localId)) { tempFile.delete() - Toast.makeText( - application, - "${modelConfig.localId} has been used, please consider another local ID", - Toast.LENGTH_SHORT - ).show() + issueAlert("${modelConfig.localId} has been used, please consider another local ID") return@launch } if (!isModelConfigAllowed(modelConfig)) { @@ -188,21 +206,13 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { addModelConfig(modelConfig, modelUrl, isBuiltin) } catch (e: Exception) { viewModelScope.launch { - Toast.makeText( - application, - "Add model failed: ${e.localizedMessage}", - Toast.LENGTH_SHORT - ).show() + issueAlert("Add model failed: ${e.localizedMessage}") } } } } catch (e: Exception) { viewModelScope.launch { - Toast.makeText( - application, - "Download model config failed: ${e.localizedMessage}", - Toast.LENGTH_SHORT - ).show() + issueAlert("Download model config failed: ${e.localizedMessage}") } } diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt index ee2833fca0..87fba77a05 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt @@ -20,6 +20,7 @@ import androidx.compose.material.icons.outlined.Delete import androidx.compose.material.icons.outlined.Download import androidx.compose.material.icons.outlined.Pause import androidx.compose.material.icons.outlined.Schedule +import androidx.compose.material3.AlertDialog import androidx.compose.material3.Divider import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon @@ -94,9 +95,14 @@ fun StartView( } } if (isAddingModel) { - Text( - text = "Add Model Variant", modifier = Modifier.padding(top = 10.dp) - ) + Text(text = "Supported Base Model Libs", modifier = Modifier.padding(top = 10.dp)) + for (lib in appViewModel.supportedModelLibs()) { + Text( + text = lib, + style = MaterialTheme.typography.bodyMedium + ) + } + Text(text = "Add Model Variant", modifier = Modifier.padding(top = 10.dp)) LazyColumn() { items( items = appViewModel.modelSampleList @@ -148,10 +154,36 @@ fun StartView( } } } - + if (appViewModel.isShowingAlert()) { + AlertDialog( + onDismissRequest = { appViewModel.dismissAlert() }, + onConfirmation 
= { appViewModel.copyError() }, + error = appViewModel.errorMessage() + ) + } } } +@ExperimentalMaterial3Api +@Composable +fun AlertDialog( + onDismissRequest: () -> Unit, + onConfirmation: () -> Unit, + error: String, +) { + AlertDialog( + title = { Text(text = "Error") }, + text = { Text(text = error) }, + onDismissRequest = { onDismissRequest() }, + confirmButton = { + TextButton(onClick = { onConfirmation() }) { Text("Copy") } + }, + dismissButton = { + TextButton(onClick = { onDismissRequest() }) { Text("Dismiss") } + } + ) +} + @Composable fun ModelView( navController: NavController, diff --git a/android/prepare_libs.sh b/android/prepare_libs.sh index 72457954c0..938ffd5cd8 100755 --- a/android/prepare_libs.sh +++ b/android/prepare_libs.sh @@ -9,7 +9,10 @@ python prepare_model_lib.py cd build touch config.cmake -echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake +if [ ${TVM_HOME-0} -ne 0 ]; then + echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake +fi + cmake .. \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ diff --git a/android/src/cpp/tvm_runtime.h b/android/src/cpp/tvm_runtime.h index 5a1267119d..2caaaaeb1a 100644 --- a/android/src/cpp/tvm_runtime.h +++ b/android/src/cpp/tvm_runtime.h @@ -1,13 +1,12 @@ #define DMLC_USE_LOGGING_LIBRARY #define TVM_USE_LIBBACKTRACE 0 +#include #include #include #include #include -#include - static_assert(TVM_LOG_CUSTOMIZE == 1, "TVM_LOG_CUSTOMIZE must be 1"); namespace tvm { diff --git a/ci/bash.sh b/ci/bash.sh new file mode 100755 index 0000000000..d54eae48ef --- /dev/null +++ b/ci/bash.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +# +# Start a bash, mount /workspace to be current directory. +# +# Usage: docker/bash.sh +# Starts an interactive session +# +# Usage2: docker/bash.sh [COMMAND] +# Execute command in the docker image, non-interactive +# +if [ "$#" -lt 1 ]; then + echo "Usage: docker/bash.sh [--no-gpu] [COMMAND]" + exit -1 +fi + +if [ "$1" == "--no-gpu" ]; then + ENABLE_NV_DOCKER=0 + shift +else + ENABLE_NV_DOCKER=1 +fi + +DOCKER_IMAGE_NAME=("$1") + + +if [ "$#" -eq 1 ]; then + COMMAND="bash" + if [[ $(uname) == "Darwin" ]]; then + # Docker's host networking driver isn't supported on macOS. + # Use default bridge network and expose port for jupyter notebook. + DOCKER_EXTRA_PARAMS=("-it -p 8888:8888") + else + DOCKER_EXTRA_PARAMS=("-it --net=host") + fi +else + shift 1 + COMMAND=("$@") +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +WORKSPACE="$(pwd)" + +# Use nvidia-docker if the container is GPU. +if [[ ! -z $CUDA_VISIBLE_DEVICES ]]; then + CUDA_ENV="-e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" +else + CUDA_ENV="" +fi + +# If this is an wheel test command then pass the env var to docker. +if [[ ! -z $WHEEL_TEST ]]; then + WHEEL_TEST="-e WHEEL_TEST=${WHEEL_TEST}" +fi + +if [[ "${DOCKER_IMAGE_NAME}" == *"cu"* ]]; then + if [ "$ENABLE_NV_DOCKER" -eq 1 ]; then + if ! type "nvidia-docker" 1> /dev/null 2> /dev/null + then + DOCKER_BINARY="docker" + CUDA_ENV=" --gpus all "${CUDA_ENV} + else + DOCKER_BINARY="nvidia-docker" + fi + else + DOCKER_BINARY="docker" + fi +else + DOCKER_BINARY="docker" +fi + +# Print arguments. +echo "WORKSPACE: ${WORKSPACE}" +echo "DOCKER CONTAINER NAME: ${DOCKER_IMAGE_NAME}" +echo "" + +echo "Running '${COMMAND[@]}' inside ${DOCKER_IMAGE_NAME}..." 
+ +# By default we cleanup - remove the container once it finish running (--rm) +# and share the PID namespace (--pid=host) so the process inside does not have +# pid 1 and SIGKILL is propagated to the process inside (jenkins can kill it). + +${DOCKER_BINARY} run --rm --pid=host\ + -v ${WORKSPACE}:/workspace \ + -v ${SCRIPT_DIR}:/docker \ + -w /workspace \ + ${CUDA_ENV} \ + ${WHEEL_TEST} \ + ${DOCKER_EXTRA_PARAMS[@]} \ + ${DOCKER_IMAGE_NAME} \ + ${COMMAND[@]} diff --git a/ci/task/black.sh b/ci/task/black.sh new file mode 100755 index 0000000000..dcc4b42555 --- /dev/null +++ b/ci/task/black.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" + +set -x + +black --check --workers $NUM_THREADS \ + ./python/ \ + ./tests/python diff --git a/ci/task/clang-format.sh b/ci/task/clang-format.sh new file mode 100755 index 0000000000..54780ec4f9 --- /dev/null +++ b/ci/task/clang-format.sh @@ -0,0 +1,67 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" + +set -x +git config --global --add safe.directory '*' + +INPLACE_FORMAT=${INPLACE_FORMAT:=false} +LINT_ALL_FILES=true +REVISION=$(git rev-list --max-parents=0 HEAD) + +while (($#)); do + case "$1" in + -i) + INPLACE_FORMAT=true + shift 1 + ;; + --rev) + LINT_ALL_FILES=false + REVISION=$2 + shift 2 + ;; + *) + echo "Usage: clang-format.sh [-i] [--rev ]" + echo "" + echo "Run clang-format on files that changed since or on all files in the repo" + echo "Examples:" + echo "- Compare last one commit: clang-format.sh --rev HEAD~1" + echo "- Compare against upstream/main: clang-format.sh --rev upstream/main" + echo "The -i will format files in-place instead of checking them." + exit 1 + ;; + esac +done + +cleanup() { + if [ -f /tmp/$$.clang-format.txt ]; then + echo "" + echo "---------clang-format log----------" + cat /tmp/$$.clang-format.txt + fi + rm -rf /tmp/$$.clang-format.txt +} +trap cleanup 0 + +if [[ "$INPLACE_FORMAT" == "true" ]]; then + echo "Running inplace git-clang-format against $REVISION" + git-clang-format --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" + exit 0 +fi + +if [[ "$LINT_ALL_FILES" == "true" ]]; then + echo "Running git-clang-format against all C++ files" + git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" 1>/tmp/$$.clang-format.txt +else + echo "Running git-clang-format against $REVISION" + git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" 1>/tmp/$$.clang-format.txt +fi + +if grep --quiet -E "diff" (); } -ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id) { +ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id, + const std::string& user_lib_path) { // Step 1. Find config path std::filesystem::path config_path; if (auto path = TryInferMLCChatConfig(local_id)) { @@ -368,26 +370,36 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l } std::cout << "Use model weights: " << params_json << std::endl; // Step 3. 
Find model lib path - std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib"); - std::string lib_name = lib_local_id + "-" + device_name; std::filesystem::path lib_path; - if (auto path = FindFile({lib_local_id, - "dist/prebuilt/lib", // Using prebuilt workflow - "dist/" + local_id, "dist/prebuilt/" + lib_local_id}, - { - lib_name + GetArchSuffix(), - lib_name, - }, - GetLibSuffixes())) { - lib_path = path.value(); + if (!user_lib_path.empty()) { + lib_path = user_lib_path; + if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) { + LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n"; + exit(1); + } } else { - LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n" - << "We searched over the following possible paths: \n" - << "- " + lib_local_id << "\n" - << "- dist/prebuilt/lib \n" - << "- dist/" + local_id << "\n" - << "- dist/prebuilt/" + lib_local_id; - exit(1); + std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib"); + std::string lib_name = lib_local_id + "-" + device_name; + if (auto path = FindFile({lib_local_id, + "dist/prebuilt/lib", // Using prebuilt workflow + "dist/" + local_id, "dist/prebuilt/" + lib_local_id}, + { + lib_name + GetArchSuffix(), + lib_name, + }, + GetLibSuffixes())) { + lib_path = path.value(); + } else { + LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n" + << "We searched over the following possible paths: \n" + << "- " + lib_local_id << "\n" + << "- dist/prebuilt/lib \n" + << "- dist/" + local_id << "\n" + << "- dist/prebuilt/" + lib_local_id << "\n" + << "If you would like to directly specify the full model library path, you may " + << "consider passing in the `--model-lib-path` argument.\n"; + exit(1); + } } std::cout << "Use model library: " << lib_path << std::endl; return ModelPaths{config_path, params_json, lib_path}; @@ -427,8 +439,8 @@ void Converse(ChatModule* chat, const std::string& input, int stream_interval, * \param stream_interval The interval that should be used for streaming the response. */ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id, - int stream_interval = 2) { - ModelPaths model = ModelPaths::Find(device_name, local_id); + std::string lib_path, int stream_interval = 2) { + ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path); PrintSpecialCommands(); chat->Reload(model); chat->ProcessSystemPrompts(); @@ -456,7 +468,7 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id if (new_local_id.empty()) { new_local_id = local_id; } - model = ModelPaths::Find(device_name, new_local_id); + model = ModelPaths::Find(device_name, new_local_id, lib_path); chat->Reload(model); local_id = new_local_id; } else if (input.substr(0, 5) == "/help") { @@ -468,9 +480,19 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id } int main(int argc, char* argv[]) { - argparse::ArgumentParser args("mlc_chat"); - - args.add_argument("--model"); + argparse::ArgumentParser args("mlc_chat_cli"); + + args.add_description( + "MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n" + "Note: the --model argument is required. It can either be the model name with its " + "quantization scheme or a full path to the model folder. In the former case, the " + "provided name will be used to search for the model folder over possible paths. 
" + "--model-lib-path argument is optional. If unspecified, the --model argument will be used " + "to search for the library file over possible paths."); + + args.add_argument("--model").help("[required] the model to use"); + args.add_argument("--model-lib-path") + .help("[optional] the full path to the model library file to use"); args.add_argument("--device").default_value("auto"); args.add_argument("--evaluate").default_value(false).implicit_value(true); args.add_argument("--eval-prompt-len").default_value(128).scan<'i', int>(); @@ -485,6 +507,10 @@ int main(int argc, char* argv[]) { } std::string local_id = args.get("--model"); + std::string lib_path; + if (args.present("--model-lib-path")) { + lib_path = args.get("--model-lib-path"); + } auto [device_name, device_id] = DetectDevice(args.get("--device")); try { @@ -494,14 +520,14 @@ int main(int argc, char* argv[]) { // that are not supposed to be used in chat app setting int prompt_len = args.get("--eval-prompt-len"); int gen_len = args.get("--eval-gen-len"); - ModelPaths model = ModelPaths::Find(device_name, local_id); + ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path); tvm::runtime::Module chat_mod = mlc::llm::CreateChatModule(GetDevice(device_name, device_id)); std::string model_path = model.config.parent_path().string(); tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model.lib.string()); chat_mod.GetFunction("reload")(lib, tvm::String(model_path)); chat_mod.GetFunction("evaluate")(prompt_len, gen_len); } else { - Chat(&chat, device_name, local_id); + Chat(&chat, device_name, local_id, lib_path); } } catch (const std::runtime_error& err) { std::cerr << err.what() << std::endl; diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 69f84b2421..ae91bf2070 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -473,6 +473,25 @@ Conversation VanillaLM() { return conv; } +Conversation StableLM3B() { + Conversation conv; + conv.name = "stablelm-3b"; + conv.system = ""; + conv.roles = {"Prompt", "LM"}; + conv.messages = {}; + conv.separator_style = SeparatorStyle::kLM; + conv.offset = 0; + conv.seps = {""}; + conv.role_msg_sep = ""; + conv.role_empty_sep = ""; + // TODO(mlc-team): add eos to mlc-chat-config + // and remove eos from stop token setting. 
+ // so the same template works for more tokenizers + conv.stop_tokens = {0}; + conv.add_bos = true; + return conv; +} + Conversation GPTBigCode() { Conversation conv; conv.name = "gpt_bigcode"; @@ -580,6 +599,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"minigpt", MiniGPT}, {"moss", MOSS}, {"LM", VanillaLM}, + {"stablelm-3b", StableLM3B}, {"gpt_bigcode", GPTBigCode}, {"wizardlm_7b", WizardLM7B}, {"wizard_coder_or_math", WizardCoderOrMATH}, diff --git a/cpp/conversation.h b/cpp/conversation.h index 82332aede6..6211c24c25 100644 --- a/cpp/conversation.h +++ b/cpp/conversation.h @@ -283,7 +283,8 @@ class Conversation { /* place_in_prompt= */ place_in_prompt); } else { ICHECK(this->separator_style == SeparatorStyle::kLM || - this->separator_style == SeparatorStyle::kCodeCompletion) << "Unsupported separator_style"; + this->separator_style == SeparatorStyle::kCodeCompletion) + << "Unsupported separator_style"; // special handle LM, LM mode have no memory // and only returns last one if (this->messages.size() >= 2) { diff --git a/cpp/image_embed.h b/cpp/image_embed.h index 87b862242c..e0e21da686 100644 --- a/cpp/image_embed.h +++ b/cpp/image_embed.h @@ -6,17 +6,7 @@ #include #include -#ifndef MLC_LLM_DLL -#ifdef _WIN32 -#ifdef MLC_LLM_EXPORTS -#define MLC_LLM_DLL __declspec(dllexport) -#else -#define MLC_LLM_DLL __declspec(dllimport) -#endif -#else -#define MLC_LLM_DLL __attribute__((visibility("default"))) -#endif -#endif +#include "base.h" namespace mlc { namespace llm { diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index f3fdcc3a36..35a8d1f41e 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -32,6 +32,9 @@ #include #include "conversation.h" +#include "random.h" +#include "support.h" +#include "tokenizers.h" namespace mlc { namespace llm { @@ -39,62 +42,6 @@ namespace llm { using tvm::Device; using namespace tvm::runtime; namespace { -//---------------------------- -// Tokenizers -//---------------------------- -using tokenizers::Tokenizer; - -std::string LoadBytesFromFile(const std::string& path) { - std::ifstream fs(path, std::ios::in | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << path; - std::string data; - fs.seekg(0, std::ios::end); - size_t size = static_cast(fs.tellg()); - fs.seekg(0, std::ios::beg); - data.resize(size); - fs.read(data.data(), size); - return data; -} - -std::unique_ptr TokenizerFromPath(const std::string& _path) { - std::filesystem::path path(_path); - std::filesystem::path sentencepiece; - std::filesystem::path huggingface; - std::filesystem::path rwkvworld; - CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path; - if (std::filesystem::is_directory(path)) { - sentencepiece = path / "tokenizer.model"; - huggingface = path / "tokenizer.json"; - rwkvworld = path / "tokenizer_model"; - // Check ByteLevelBPE - { - std::filesystem::path merges_path = path / "merges.txt"; - std::filesystem::path vocab_path = path / "vocab.json"; - std::filesystem::path added_tokens_path = path / "added_tokens.json"; - if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) && - std::filesystem::exists(added_tokens_path)) { - std::string vocab = LoadBytesFromFile(vocab_path.string()); - std::string merges = LoadBytesFromFile(merges_path.string()); - std::string added_tokens = LoadBytesFromFile(added_tokens_path.string()); - return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens); - } - } - } else { - sentencepiece = path.parent_path() / "tokenizer.model"; - huggingface = 
path.parent_path() / "tokenizer.json"; - rwkvworld = path.parent_path() / "tokenizer_model"; - } - if (std::filesystem::exists(sentencepiece)) { - return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string())); - } - if (std::filesystem::exists(huggingface)) { - return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string())); - } - if (std::filesystem::exists(rwkvworld)) { - return Tokenizer::FromBlobRWKVWorld(rwkvworld.string()); - } - LOG(FATAL) << "Cannot find any tokenizer under: " << _path; -} //------------------------------ // support functions @@ -315,23 +262,6 @@ struct FunctionTable { PackedFunc fkvcache_array_popn_; }; -class RandomGenerator { - private: - std::mt19937 gen; - std::uniform_real_distribution<> dis; - - RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {} - - public: - static RandomGenerator& GetInstance(int seed = std::random_device{}()) { - static RandomGenerator instance(seed); - return instance; - } - - double GetRandomNumber() { return dis(gen); } - - void SetSeed(int seed) { gen.seed(seed); } -}; } // namespace //------------------------------ @@ -507,7 +437,7 @@ class LLMChat { /*! * \brief Reload model, tokenizers and configurations from the specified model path. - * \param executable The module to reload. + * \param reload_lib The module to reload, it can either be a path to the library or a tvm Module. * \param model_path The path to search for models. * \param app_config_json The JSON string used to partially override the configuration loaded from * disk, default to empty string. @@ -599,7 +529,20 @@ class LLMChat { * \brief Get input tokens based on history * \param place_in_prompt The place of the input message in the prompt. */ - std::vector GetInputTokens(PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + std::vector GetInputTokens(PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since is only used for this generation + int64_t gen_mean_gen_len; + if (generation_config.count("mean_gen_len")) { + CHECK(generation_config["mean_gen_len"].is()); + gen_mean_gen_len = generation_config["mean_gen_len"].get(); + } else { + gen_mean_gen_len = this->mean_gen_len_; + } + + // work on input tokens std::vector tokens; std::vector prompts; @@ -619,7 +562,7 @@ class LLMChat { std::string all_prompt = GetConcatPrompt(prompts, 0, 0); std::vector encoded = this->tokenizer_->Encode(all_prompt); tokens.insert(tokens.end(), encoded.begin(), encoded.end()); - if (this->total_seq_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) { + if (this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) { return tokens; } // need shift window and re-encode @@ -656,11 +599,11 @@ class LLMChat { if (tokens.size() >= this->max_window_size_) { LOG(WARNING) << "The prompt tokens are more than `max_window_size`, the input will be truncated."; - ICHECK_GT(this->max_window_size_, this->mean_gen_len_); + ICHECK_GT(this->max_window_size_, gen_mean_gen_len); std::vector truncated_tokens( - tokens.end() - (this->max_window_size_ - this->mean_gen_len_), tokens.end()); + tokens.end() - (this->max_window_size_ - gen_mean_gen_len), tokens.end()); return truncated_tokens; - } else if (tokens.size() + this->mean_gen_len_ >= this->max_window_size_) { + } else if (tokens.size() + gen_mean_gen_len >= this->max_window_size_) { LOG(WARNING) << "The prompt tokens 
are too long and the generated text may be incomplete, due to " "limited `max_window_size`. "; @@ -695,8 +638,10 @@ class LLMChat { return view; } - std::vector PrepareBeforeEmbedding(std::string inp, bool append_conversation = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + std::vector PrepareBeforeEmbedding( + std::string inp, bool append_conversation = true, + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + picojson::object generation_config = picojson::object()) { if (conversation_.separator_style == SeparatorStyle::kLM || conversation_.separator_style == SeparatorStyle::kCodeCompletion) { this->ResetChat(); @@ -705,7 +650,7 @@ class LLMChat { this->ResetRuntimeStats(); } output_ids_.clear(); - appeared_token_ids_.clear(); + appeared_token_freq_.clear(); output_message_.clear(); stop_triggered_ = false; if (append_conversation) { @@ -713,7 +658,7 @@ class LLMChat { conversation_.AppendReplyHeader(conversation_.roles[1]); } - return this->GetInputTokens(place_in_prompt); + return this->GetInputTokens(place_in_prompt, generation_config); } /*! @@ -724,9 +669,14 @@ class LLMChat { * \return the embedding of the tokenized prompt. */ ObjectRef EmbedStep(std::string inp, bool append_conversation = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + String generation_config_str = "") { + // process generation settings + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); + std::vector prompt_tokens = - PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt); + PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config); int64_t token_len = static_cast(prompt_tokens.size()); if (token_len == 0) { return NDArray::Empty({}, DataType::Float(32), device_); @@ -755,7 +705,8 @@ class LLMChat { * \param embedding The embedding to prefill with. * \param decode_next_token Whether to decode next token. */ - void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true) { + void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true, + String generation_config_str = "") { if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; throw; @@ -774,13 +725,16 @@ class LLMChat { return; } - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); + + int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); this->prefill_total_time += static_cast((tend - tstart).count()) / 1e9; this->prefill_total_tokens += token_len; - this->ProcessNextToken(next_token); + this->ProcessNextToken(next_token, generation_config); } /*! @@ -791,20 +745,25 @@ class LLMChat { * \param place_in_prompt The place of the input message in the prompt. */ void PrefillStep(std::string inp, bool append_conversation = true, bool decode_next_token = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + String generation_config_str = "") { if (ft_.embed_func_.defined() && ft_.prefill_with_embed_func_.defined()) { // Temporarily placed inside `PrefillStep` for compatibility in transition. // Will be separated out in the future. 
if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; } - NDArray embedding = Downcast(EmbedStep(inp, append_conversation, place_in_prompt)); - PrefillWithEmbedStep(embedding, decode_next_token); + NDArray embedding = Downcast( + EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str)); + PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str); return; } + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); + std::vector prompt_tokens = - this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt); + this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config); int64_t token_len = static_cast(prompt_tokens.size()); if (token_len == 0) return; if (ft_.use_disco) { @@ -824,16 +783,19 @@ class LLMChat { return; } - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); + int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); this->prefill_total_time += static_cast((tend - tstart).count()) / 1e9; this->prefill_total_tokens += token_len; - this->ProcessNextToken(next_token); + this->ProcessNextToken(next_token, generation_config); } - void DecodeStep() { + void DecodeStep(String generation_config_str = "") { + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); + ICHECK(!output_ids_.empty()); int32_t last_token = output_ids_.back(); tvm::runtime::NDArray input_data = GetInputTokenNDArray({last_token}); @@ -843,13 +805,13 @@ class LLMChat { NDArray logits_on_device = this->ForwardTokens({last_token}, total_seq_len_ + 1); total_seq_len_ += 1; - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); + int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; this->decode_total_tokens += 1; - this->ProcessNextToken(next_token); + this->ProcessNextToken(next_token, generation_config); } bool Stopped() { return stop_triggered_; } @@ -921,7 +883,7 @@ class LLMChat { { auto tstart = std::chrono::high_resolution_clock::now(); logits_on_device = this->ForwardTokens(tokens, tokens.size()); - tokens.push_back(this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_)); + tokens.push_back(this->SampleTokenFromLogits(logits_on_device)); auto tend = std::chrono::high_resolution_clock::now(); this->prefill_total_time = static_cast((tend - tstart).count()) / 1e9; @@ -933,7 +895,7 @@ class LLMChat { auto tstart = std::chrono::high_resolution_clock::now(); for (int64_t len = 1; len < generate_len; ++len) { logits_on_device = this->ForwardTokens({tokens.back()}, tokens.size()); - tokens.push_back(this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_)); + tokens.push_back(this->SampleTokenFromLogits(logits_on_device)); } auto tend = std::chrono::high_resolution_clock::now(); @@ -950,6 +912,8 @@ class LLMChat { picojson::object config; config["temperature"] = picojson::value(this->temperature_); config["repetition_penalty"] = picojson::value(this->repetition_penalty_); + config["presence_penalty"] = picojson::value(this->presence_penalty_); + config["frequency_penalty"] = picojson::value(this->frequency_penalty_); config["top_p"] = 
picojson::value(this->top_p_); config["mean_gen_len"] = picojson::value(this->mean_gen_len_); config["max_gen_len"] = picojson::value(this->max_gen_len_); @@ -957,29 +921,109 @@ class LLMChat { config["conv_config"] = this->conversation_.SerializeToJSON(); return picojson::value(config); } + + picojson::object LoadGenerationConfigFromString(const std::string& generation_config_str) { + picojson::object generation_config = picojson::object(); + if (!generation_config_str.empty()) { + picojson::value generation_config_json; + picojson::parse(generation_config_json, generation_config_str); + generation_config = generation_config_json.get(); + } + return generation_config; + } + + void ReadGenerationConfig(picojson::object generation_config, double* gen_temperature, + NDArray* gen_temperature_arr, double* gen_repetition_penalty, + double* gen_presence_penalty, double* gen_frequency_penalty, + double* gen_top_p) { + if (generation_config.count("temperature")) { + CHECK(generation_config["temperature"].is()); + *gen_temperature = generation_config["temperature"].get(); + + *gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_); + float temperature_cast = static_cast(*gen_temperature); + gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float)); + } else { + *gen_temperature = this->temperature_; + *gen_temperature_arr = this->temperature_arr_; + } + if (generation_config.count("repetition_penalty")) { + CHECK(generation_config["repetition_penalty"].is()); + CHECK(generation_config["repetition_penalty"].get() > 0) + << "Repetition penalty must be a positive number!"; + *gen_repetition_penalty = generation_config["repetition_penalty"].get(); + } else { + *gen_repetition_penalty = this->repetition_penalty_; + } + if (generation_config.count("presence_penalty")) { + CHECK(generation_config["presence_penalty"].is()); + CHECK(abs(generation_config["presence_penalty"].get()) <= 2) + << "Presence penalty must be in the range -2 to 2!"; + *gen_presence_penalty = generation_config["presence_penalty"].get(); + } else { + *gen_presence_penalty = this->presence_penalty_; + } + if (generation_config.count("frequency_penalty")) { + CHECK(generation_config["frequency_penalty"].is()); + CHECK(abs(generation_config["frequency_penalty"].get()) <= 2) + << "Frequency penalty must be in the range -2 to 2!"; + *gen_frequency_penalty = generation_config["frequency_penalty"].get(); + } else { + *gen_frequency_penalty = this->frequency_penalty_; + } + if (generation_config.count("top_p")) { + CHECK(generation_config["top_p"].is()); + *gen_top_p = generation_config["top_p"].get(); + } else { + *gen_top_p = this->top_p_; + } + } + /*! 
* \brief Sample output token from logits on device */ - int32_t SampleTokenFromLogits(NDArray logits_on_device, float temperature, float top_p) { - if (repetition_penalty_ == 1.0f) { - if (temperature_ < 1e-6f) { - this->UpdateLogitsOrProbOnCPUSync(logits_on_device); - } else { - this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, temperature_)); + int32_t SampleTokenFromLogits(NDArray logits_on_device, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since it is only used for this generation + double gen_temperature; + double gen_repetition_penalty; + double gen_presence_penalty; + double gen_frequency_penalty; + double gen_top_p; + this->ReadGenerationConfig(generation_config, &gen_temperature, &this->temperature_arr_, + &gen_repetition_penalty, &gen_presence_penalty, + &gen_frequency_penalty, &gen_top_p); + + // update logits + if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) { + this->UpdateLogitsOrProbOnCPUSync(logits_on_device); + this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty); + if (gen_temperature >= 1e-6f) { + this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); } - } else { + } else if (gen_repetition_penalty != 1.0f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); - this->ApplyRepetitionPenaltyOnCPU(); - if (temperature_ >= 1e-6f) { - this->ApplySoftmaxWithTemperatureOnCPU(); + this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty); + if (gen_temperature >= 1e-6f) { + this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); + } + } else { + if (gen_temperature < 1e-6f) { + this->UpdateLogitsOrProbOnCPUSync(logits_on_device); + } else { + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); } } + + // perform sampling auto tstart = std::chrono::high_resolution_clock::now(); int next_token; - if (temperature_ < 1e-6f) { - next_token = this->SampleFromLogitsOnCPU(); + if (gen_temperature < 1e-6f) { + next_token = this->SampleFromLogitsOnCPU(gen_temperature, gen_top_p); } else { - next_token = this->SampleFromProbOnCPU(); + next_token = this->SampleFromProbOnCPU(gen_top_p); } auto tend = std::chrono::high_resolution_clock::now(); this->sample_total_time += static_cast((tend - tstart).count()) / 1e9; @@ -990,7 +1034,38 @@ class LLMChat { * \brief Add a generated token and check for stop condition. * \param next_token The next token. 
*/ - void ProcessNextToken(int32_t next_token) { + void ProcessNextToken(int32_t next_token, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since is only used for this generation + int64_t gen_max_gen_len; + if (generation_config.count("max_gen_len")) { + CHECK(generation_config["max_gen_len"].is()); + gen_max_gen_len = generation_config["max_gen_len"].get(); + } else { + gen_max_gen_len = this->max_gen_len_; + } + + std::vector gen_stop_strs; + gen_stop_strs.push_back(conversation_.stop_str); + + if (generation_config.count("stop")) { + if (!generation_config["stop"].is()) { + CHECK(generation_config["stop"].is() || + generation_config["stop"].is()); + if (generation_config["stop"].is()) { + gen_stop_strs.push_back(generation_config["stop"].get()); + } else { + picojson::array gen_stop_strs_arr = generation_config["stop"].get(); + for (const picojson::value& v : gen_stop_strs_arr) { + CHECK(v.is()); + gen_stop_strs.push_back(v.get()); + } + } + } + } + ICHECK(!stop_triggered_) << "Cannot call process when it is stopped"; stop_triggered_ = @@ -999,32 +1074,39 @@ class LLMChat { if (!stop_triggered_) { output_ids_.push_back(next_token); - appeared_token_ids_.insert(next_token); + if (appeared_token_freq_.find(next_token) != appeared_token_freq_.end()) { + appeared_token_freq_[next_token] += 1; + } else { + appeared_token_freq_[next_token] = 1; + } } output_message_ = tokenizer_->Decode(output_ids_); - if (!conversation_.stop_str.empty()) { - size_t stop_pos = output_message_.rfind(conversation_.stop_str); - if (stop_pos != std::string::npos) { - stop_triggered_ = true; - if (ft_.support_backtracking_kv_) { - // back tracking, find the first set of token that is smaller - // than the length - size_t backoff = 0; - for (; backoff < output_ids_.size(); ++backoff) { - output_ids_.pop_back(); - output_message_ = tokenizer_->Decode(output_ids_); - if (output_message_.length() <= stop_pos) break; - } - // resize kv to remove the context - ft_.fkvcache_array_popn_(kv_cache_, backoff); - total_seq_len_ -= backoff; + size_t stop_pos = std::string::npos; + for (const std::string& stop_str : gen_stop_strs) { + if (!stop_str.empty()) { + stop_pos = std::min(stop_pos, output_message_.rfind(stop_str)); + } + } + + if (stop_pos != std::string::npos) { + stop_triggered_ = true; + if (ft_.support_backtracking_kv_) { + // back tracking, find the first set of token that is smaller + // than the length + size_t backoff = 0; + for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) { + output_ids_.pop_back(); + output_message_ = tokenizer_->Decode(output_ids_); } + // resize kv to remove the context + ft_.fkvcache_array_popn_(kv_cache_, backoff); + total_seq_len_ -= backoff; } } - if (static_cast(output_ids_.size()) >= max_gen_len_) { + if (static_cast(output_ids_.size()) >= gen_max_gen_len) { stop_triggered_ = true; } else if (total_seq_len_ >= max_window_size_) { stop_triggered_ = true; @@ -1077,32 +1159,42 @@ class LLMChat { return Downcast(ret[0]); } - NDArray Softmax(NDArray input, float temperature) { + NDArray Softmax(NDArray input, NDArray temperature_arr) { NDArray ret; - ret = ft_.softmax_func_(input, temperature_arr_); + ret = ft_.softmax_func_(input, temperature_arr); return ret; } - void ApplyRepetitionPenaltyOnCPU() { + void ApplyRepetitionPenaltyOnCPU(float repetition_penalty) { CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; 
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; float* logits_raw_data = static_cast(logits_on_cpu_->data); - for (const int32_t& token_id : this->appeared_token_ids_) { - if (logits_raw_data[token_id] <= 0) { - logits_raw_data[token_id] *= this->repetition_penalty_; + for (const auto& token_freq : this->appeared_token_freq_) { + if (logits_raw_data[token_freq.first] <= 0) { + logits_raw_data[token_freq.first] *= repetition_penalty; } else { // logits > 0 - logits_raw_data[token_id] /= this->repetition_penalty_; + logits_raw_data[token_freq.first] /= repetition_penalty; } } } - void ApplySoftmaxWithTemperatureOnCPU() { + void ApplyPresenceAndFrequencyPenaltyOnCPU(float presence_penalty, float frequency_penalty) { + CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; + CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + float* logits_raw_data = static_cast(logits_on_cpu_->data); + for (const auto& token_freq : this->appeared_token_freq_) { + logits_raw_data[token_freq.first] -= + (token_freq.second * frequency_penalty + presence_penalty); + } + } + + void ApplySoftmaxWithTemperatureOnCPU(float temperature) { CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; int vocab_size = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1]; float* logits_raw_data = static_cast(logits_on_cpu_->data); float m = std::numeric_limits::min(); - float inv_temp = 1.0f / this->temperature_; + float inv_temp = 1.0f / temperature; double d = 0.0f; for (int i = 0; i < vocab_size; ++i) { float x = logits_raw_data[i] * inv_temp; @@ -1137,18 +1229,18 @@ class LLMChat { // Utils static double GetRandomNumber() { return RandomGenerator::GetInstance().GetRandomNumber(); } - int32_t SampleFromLogitsOnCPU() { + int32_t SampleFromLogitsOnCPU(float temperature, float top_p) { ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined"; ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; - return fsample_topp_from_logits_(logits_on_cpu_, temperature_, top_p_, GetRandomNumber()); + return fsample_topp_from_logits_(logits_on_cpu_, temperature, top_p, GetRandomNumber()); } - int32_t SampleFromProbOnCPU() { + int32_t SampleFromProbOnCPU(float top_p) { ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined"; ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; - return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber()); + return fsample_topp_from_prob_(logits_on_cpu_, top_p, GetRandomNumber()); } //---------------------------- @@ -1185,12 +1277,16 @@ class LLMChat { NDArray temperature_arr_; // repetition penalty double repetition_penalty_{1.0}; + // presence penalty + double presence_penalty_{0.0}; + // frequency penalty + double frequency_penalty_{0.0}; // top_p double top_p_{0.95}; // output ids till now (refresh after encoding step) std::vector output_ids_; - // appeared token ids till now (refresh after encoding step) - std::unordered_set appeared_token_ids_; + // frequency of appeared token ids till now (refresh after encoding step) + std::unordered_map appeared_token_freq_; // output message till now (refresh after encoding step) std::string output_message_; // Whether encounter stop str @@ -1279,7 
+1375,7 @@ class LLMChatModule : public ModuleNode { }); } else if (name == "prefill") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK(1 <= args.size() && args.size() <= 3); + ICHECK(1 <= args.size() && args.size() <= 4); if (args.size() == 1) { // args: inp (with decode_next_token = true, place_in_prompt = kAll) GetChat()->PrefillStep(args[0]); @@ -1290,11 +1386,15 @@ class LLMChatModule : public ModuleNode { // args: inp, decode_next_token, place_in_prompt PlaceInPrompt place_in_prompt = static_cast(static_cast(args[2])); GetChat()->PrefillStep(args[0], true, args[1], place_in_prompt); + } else if (args.size() == 4) { + // args: inp, decode_next_token, place_in_prompt, generation_config_str + PlaceInPrompt place_in_prompt = static_cast(static_cast(args[2])); + GetChat()->PrefillStep(args[0], true, args[1], place_in_prompt, args[3]); } }); } else if (name == "embed") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK(1 <= args.size() && args.size() <= 2); + ICHECK(1 <= args.size() && args.size() <= 3); if (args.size() == 1) { // args: inp (with place_in_prompt = kAll) *rv = GetChat()->EmbedStep(args[0]); @@ -1302,22 +1402,36 @@ class LLMChatModule : public ModuleNode { // args: inp, place_in_prompt PlaceInPrompt place_in_prompt = static_cast(static_cast(args[1])); *rv = GetChat()->EmbedStep(args[0], true, place_in_prompt); + } else if (args.size() == 3) { + // args: inp, place_in_prompt, generation_config_str + PlaceInPrompt place_in_prompt = static_cast(static_cast(args[1])); + *rv = GetChat()->EmbedStep(args[0], true, place_in_prompt, args[2]); } }); } else if (name == "prefill_with_embed") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK(1 <= args.size() && args.size() <= 2); + ICHECK(1 <= args.size() && args.size() <= 3); if (args.size() == 1) { // args: embedding (with decode_next_token = true) GetChat()->PrefillWithEmbedStep(args[0]); } else if (args.size() == 2) { // args: embedding, decode_next_token GetChat()->PrefillWithEmbedStep(args[0], args[1]); + } else if (args.size() == 3) { + // args: embedding, decode_next_token, generation_config_str + GetChat()->PrefillWithEmbedStep(args[0], args[1], args[2]); } }); } else if (name == "decode") { - return PackedFunc( - [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->DecodeStep(); }); + return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + ICHECK(0 <= args.size() && args.size() <= 1); + if (args.size() == 0) { + GetChat()->DecodeStep(); + } else if (args.size() == 1) { + // args: generation_config_str + GetChat()->DecodeStep(args[0]); + } + }); } else if (name == "reset_chat") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size(), 0); diff --git a/cpp/llm_chat.h b/cpp/llm_chat.h index 1839e8cb4d..39408d1685 100644 --- a/cpp/llm_chat.h +++ b/cpp/llm_chat.h @@ -6,17 +6,7 @@ #include #include -#ifndef MLC_LLM_DLL -#ifdef _WIN32 -#ifdef MLC_LLM_EXPORTS -#define MLC_LLM_DLL __declspec(dllexport) -#else -#define MLC_LLM_DLL __declspec(dllimport) -#endif -#else -#define MLC_LLM_DLL __attribute__((visibility("default"))) -#endif -#endif +#include "base.h" namespace mlc { namespace llm { diff --git a/cpp/random.h b/cpp/random.h new file mode 100644 index 0000000000..e6331a9699 --- /dev/null +++ b/cpp/random.h @@ -0,0 +1,37 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file random.h + * \brief Header of random number generator. 
+ */ + +#ifndef MLC_LLM_RANDOM_H_ +#define MLC_LLM_RANDOM_H_ + +#include + +namespace mlc { +namespace llm { + +// Random number generator +class RandomGenerator { + private: + std::mt19937 gen; + std::uniform_real_distribution<> dis; + + RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {} + + public: + static RandomGenerator& GetInstance(int seed = std::random_device{}()) { + static RandomGenerator instance(seed); + return instance; + } + + double GetRandomNumber() { return dis(gen); } + + void SetSeed(int seed) { gen.seed(seed); } +}; + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_RANDOM_H_ diff --git a/cpp/support.h b/cpp/support.h new file mode 100644 index 0000000000..20eadbbd0a --- /dev/null +++ b/cpp/support.h @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file support.h + * \brief Header of utilities. + */ + +#ifndef MLC_LLM_COMMON_H_ +#define MLC_LLM_COMMON_H_ + +#include +#include + +namespace mlc { +namespace llm { + +inline std::string LoadBytesFromFile(const std::string& path) { + std::ifstream fs(path, std::ios::in | std::ios::binary); + ICHECK(!fs.fail()) << "Cannot open " << path; + std::string data; + fs.seekg(0, std::ios::end); + size_t size = static_cast(fs.tellg()); + fs.seekg(0, std::ios::beg); + data.resize(size); + fs.read(data.data(), size); + return data; +} + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_COMMON_H_ diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc new file mode 100644 index 0000000000..8d38dd9572 --- /dev/null +++ b/cpp/tokenizers.cc @@ -0,0 +1,61 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file tokenizer.cc + */ + +#include "tokenizers.h" + +#include +#include + +#include +#include +#include + +#include "support.h" + +namespace mlc { +namespace llm { + +std::unique_ptr TokenizerFromPath(const std::string& _path) { + std::filesystem::path path(_path); + std::filesystem::path sentencepiece; + std::filesystem::path huggingface; + std::filesystem::path rwkvworld; + CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path; + if (std::filesystem::is_directory(path)) { + sentencepiece = path / "tokenizer.model"; + huggingface = path / "tokenizer.json"; + rwkvworld = path / "tokenizer_model"; + // Check ByteLevelBPE + { + std::filesystem::path merges_path = path / "merges.txt"; + std::filesystem::path vocab_path = path / "vocab.json"; + std::filesystem::path added_tokens_path = path / "added_tokens.json"; + if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) && + std::filesystem::exists(added_tokens_path)) { + std::string vocab = LoadBytesFromFile(vocab_path.string()); + std::string merges = LoadBytesFromFile(merges_path.string()); + std::string added_tokens = LoadBytesFromFile(added_tokens_path.string()); + return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens); + } + } + } else { + sentencepiece = path.parent_path() / "tokenizer.model"; + huggingface = path.parent_path() / "tokenizer.json"; + rwkvworld = path.parent_path() / "tokenizer_model"; + } + if (std::filesystem::exists(sentencepiece)) { + return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string())); + } + if (std::filesystem::exists(huggingface)) { + return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string())); + } + if (std::filesystem::exists(rwkvworld)) { + return Tokenizer::FromBlobRWKVWorld(rwkvworld.string()); + } + LOG(FATAL) << "Cannot find any tokenizer under: " << _path; +} + +} // namespace llm +} // namespace mlc diff 
--git a/cpp/tokenizers.h b/cpp/tokenizers.h new file mode 100644 index 0000000000..f44f828e97 --- /dev/null +++ b/cpp/tokenizers.h @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file tokenizers.h + * \brief Header of tokenizer related functions. + */ + +#ifndef MLC_LLM_TOKENIZER_H_ +#define MLC_LLM_TOKENIZER_H_ + +#include + +#include "base.h" + +namespace mlc { +namespace llm { + +using tokenizers::Tokenizer; + +MLC_LLM_DLL std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path); + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_TOKENIZER_H_ diff --git a/docs/community/faq.rst b/docs/community/faq.rst index 45d73b4904..3913dd9639 100644 --- a/docs/community/faq.rst +++ b/docs/community/faq.rst @@ -5,7 +5,7 @@ Frequently Asked Questions This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free to suggest new entries! -... How can I customize the temperature, repetition penalty of models? +... How can I customize the temperature and repetition penalty of models? Please check our :doc:`/get_started/mlc_chat_config` tutorial. ... What's the quantization algorithm MLC-LLM using? @@ -13,5 +13,4 @@ This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free ... Why do I encounter an error ``free(): invalid pointer, Aborted (core dumped)`` at the end of model compilation? This happens if you compiled TVM-Unity from source and didn't hide LLVM symbols in cmake configurations. - Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols, - or use our pre-builc MLC-AI pip wheels from `MLC Packages `__. + Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols, or use our pre-built MLC-LLM :doc:`pip wheels <../install/mlc_llm>`. diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst index 7f671f614b..38a03e463e 100644 --- a/docs/community/guideline.rst +++ b/docs/community/guideline.rst @@ -42,11 +42,11 @@ Ready to contribute to MLC-LLM? Awesome! We are excited to see you are ready to The standard way to make changes to MLC-LLM code base is through creating a `pull-request `__, and we will review your code and merge it to the code base when it is ready. -The first step to become a developer is to `fork `__ the repository to your own +The first step to becoming a developer is to `fork `__ the repository to your own github account, you will notice a repository under ``https://github.com/username/mlc-llm`` where ``username`` is your github user name. You can clone your fork to your local machine and commit changes, or edit the contents of your fork (in the case you are just fixing typos) -on github directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository. +on GitHub directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository. .. _contribute-new-models: @@ -55,7 +55,7 @@ Contribute New Models to MLC-LLM * If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following :ref:`distribute-compiled-models` tutorial. Once you have done that, you can create a pull request to add an entry in the :doc:`/prebuilt_models` page.
Additionally, you have the option to `create a speed report issue `__ to track the speed and memory consumption of your model. You don't need to test it on all devices; let the community collaborate on building it together! -* If you add a new model variant to MLC-LLM by following our :doc:`/tutorials/bring-your-own-models` tutorial. +* If you add a new model variant to MLC-LLM by following our :doc:`/tutorials/customize/define_new_models` tutorial. Please create a pull request to add your model architecture (currently model architectures are placed under `relax_models `__ folder). @@ -86,14 +86,14 @@ Fo your convenience, you can use `clang-format `__ to acknowledge contributors, -please let us know if you contribute to the project and your name is not included in the list. +please let us know if you contribute to the project and if your name is not included in the list. diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index e559c8fc27..98c7f2d156 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -4,14 +4,14 @@ Compile Models via MLC ====================== This page describes how to compile a model with MLC LLM. Model compilation takes model inputs, produces quantized model weights, -and optimized model lib for a given platform. It enables users to bring their own new model weights, try different quantization modes, +and optimizes model lib for a given platform. It enables users to bring their own new model weights, try different quantization modes, and customize the overall model optimization flow. .. note:: Before you proceed, please make sure that you have :ref:`install-tvm-unity` correctly installed on your machine. TVM-Unity is the necessary foundation for us to compile models with MLC LLM. If you want to build webgpu, please also complete :ref:`install-web-build`. - Please also follow the instruction in :ref:`deploy-cli` to obtain the CLI app that can be used to chat with the compiled model. + Please also follow the instructions in :ref:`deploy-cli` to obtain the CLI app that can be used to chat with the compiled model. Finally, we strongly recommend you read :ref:`project-overview` first to get familiarized with the high-level terminologies. @@ -25,12 +25,12 @@ Install MLC-LLM Package Work with Source Code ^^^^^^^^^^^^^^^^^^^^^ -The easiest way is to use MLC-LLM is to clone the repository, and compile models under the root directory of the repository. +The easiest way to use MLC-LLM is to clone the repository, and compile models under the root directory of the repository. .. code:: bash # clone the repository - git clone git@github.com:mlc-ai/mlc-llm.git --recursive + git clone https://github.com/mlc-ai/mlc-llm.git --recursive # enter to root directory of the repo cd mlc-llm # install mlc-llm @@ -106,7 +106,7 @@ your personal computer. xcrun: error: unable to find utility "metallib", not a developer tool or in PATH , please check and make sure you have Command Line Tools for Xcode installed correctly. - You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed the model compiling. + You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed with the model compiling. .. 
group-tab:: Android @@ -172,7 +172,7 @@ We can check the output with the commands below: tokenizer_config.json We now chat with the model using the command line interface (CLI) app. - Follow the build from source instruction + Follow the build from the source instruction .. code:: shell @@ -271,7 +271,7 @@ We can check the output with the commands below: tokenizer_config.json The model lib ``dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm`` - can be uploaded to internet. You can pass a ``model_lib_map`` field to WebLLM app config to use this library. + can be uploaded to the internet. You can pass a ``model_lib_map`` field to WebLLM app config to use this library. Each compilation target produces a specific model library for the given platform. The model weight is shared across @@ -311,7 +311,7 @@ In other cases you need to specify the model via ``--model``. - ``dist/models/MODEL_NAME_OR_PATH`` (e.g., ``--model Llama-2-7b-chat-hf``), - ``MODEL_NAME_OR_PATH`` (e.g., ``--model /my-model/Llama-2-7b-chat-hf``). - When running the compile command using ``--model``, please make sure you have placed the model to compile under ``dist/models/`` or other location on the disk. + When running the compile command using ``--model``, please make sure you have placed the model to compile under ``dist/models/`` or another location on the disk. --hf-path HUGGINGFACE_NAME The name of the model's Hugging Face repository. We will download the model to ``dist/models/HUGGINGFACE_NAME`` and load the model from this directory. @@ -336,11 +336,11 @@ The following arguments are optional: we will use the maximum sequence length from the ``config.json`` in the model directory. --reuse-lib LIB_NAME Specifies the previously generated library to reuse. This is useful when building the same model architecture with different weights. - You can refer to the :ref:`model distribution ` page for detail of this argument. + You can refer to the :ref:`model distribution ` page for details of this argument. --use-cache When ``--use-cache=0`` is specified, the model compilation will not use cached file from previous builds, and will compile the model from the very start. - Using cache can help reduce the time needed to compile. + Using a cache can help reduce the time needed to compile. --debug-dump Specifies whether to dump debugging files during compilation. --use-safetensors Specifies whether to use ``.safetensors`` instead of the default ``.bin`` when loading in model weights. @@ -354,7 +354,7 @@ This section lists compile commands for more models that you can try out. .. tab:: Model: Llama-2-7B Please `request for access `_ to the Llama-2 weights from Meta first. - After granted the access, please create directory ``dist/models`` and download the model to the directory. + After granted access, please create directory ``dist/models`` and download the model to the directory. For example, you can run the following code: .. code:: shell diff --git a/docs/compilation/distribute_compiled_models.rst b/docs/compilation/distribute_compiled_models.rst index b6f31f9386..69dc0e847d 100644 --- a/docs/compilation/distribute_compiled_models.rst +++ b/docs/compilation/distribute_compiled_models.rst @@ -67,7 +67,7 @@ You can **optionally** customize the chat config file ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1/params/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). You can also simply use the default configuration and skip this step. 
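If you prefer to script this customization rather than edit the file by hand, here is a minimal Python sketch; the path and the field values simply mirror the walkthrough in the next paragraph (``mean_gen_len``, ``max_gen_len``, and ``conv_template`` are fields of ``mlc-chat-config.json``), so adjust them to your own model as needed.

.. code:: python

   import json

   cfg_path = "dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1/params/mlc-chat-config.json"

   # Read the generated chat config, tweak a few fields, and write it back.
   with open(cfg_path, "r", encoding="utf-8") as f:
       cfg = json.load(f)

   cfg["mean_gen_len"] = 32      # value used in this walkthrough
   cfg["max_gen_len"] = 64       # value used in this walkthrough
   cfg["conv_template"] = "LM"   # the model is instruction-tuned

   with open(cfg_path, "w", encoding="utf-8") as f:
       json.dump(cfg, f, indent=2)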
-For demonstration purpose, we update ``mean_gen_len`` to 32 and ``max_gen_len`` to 64. +For demonstration purposes, we update ``mean_gen_len`` to 32 and ``max_gen_len`` to 64. We also update ``conv_template`` to ``"LM"`` because the model is instruction-tuned. @@ -160,8 +160,8 @@ Download the Distributed Models and Run in iOS App -------------------------------------------------- For iOS app, model libraries are statically packed into the app at the time of app building. -Therefore, the iOS app supports running any models whose model libraries are integrated into the app. -You can check the :ref:`list of supported model libraries `. +Therefore, the iOS app supports running any model whose model libraries are integrated into the app. +You can check the :ref:`list of supported model libraries `. To download and run the compiled RedPajama-3B instruct model on iPhone, we need to reuse the integrated ``RedPajama-INCITE-Chat-3B-v1-q4f16_1`` model library. Please revisit :ref:`distribute-model-step3-specify-model-lib` and make sure the ``model_lib`` field of `mlc-chat-config.json` is set to ``RedPajama-INCITE-Chat-3B-v1-q4f16_1``. @@ -198,7 +198,7 @@ Now we can download the model weights in iOS app and run the model by following .. tab:: Step 4 - When the download is finished, click into the model and enjoy. + When the download is finished, click on the model and enjoy. .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-distribute-4.jpeg :align: center diff --git a/docs/compilation/get-vicuna-weight.rst b/docs/compilation/get-vicuna-weight.rst index 0cc42380e9..2ea4ba5d97 100644 --- a/docs/compilation/get-vicuna-weight.rst +++ b/docs/compilation/get-vicuna-weight.rst @@ -5,7 +5,7 @@ Getting Vicuna Weights :local: :depth: 2 -`Vicuna `_ is a open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data. +`Vicuna `_ is an open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data. Please note that the official Vicuna weights are delta weights applied to the LLaMA weights in order to comply with the LLaMA license. Users are responsible for applying these delta weights themselves. @@ -14,7 +14,7 @@ In this tutorial, we will show how to apply the delta weights to LLaMA weights t Install FastChat ---------------- -FastChat offers convenient utility functions for applying delta to LLaMA weights. You can easily install it using pip. +FastChat offers convenient utility functions for applying the delta to LLaMA weights. You can easily install it using pip. .. code-block:: bash @@ -38,14 +38,14 @@ Then download the weights (both the LLaMA weight and Vicuna delta weight): git clone https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 -There is a name mis-alignment issue in the LLaMA weights and Vicuna delta weights. +There is a name misalignment issue in the LLaMA weights and Vicuna delta weights. Please follow these steps to modify the content of the "config.json" file: .. code-block:: bash sed -i 's/LLaMAForCausalLM/LlamaForCausalLM/g' llama-7b-hf/config.json -Then use ``fschat`` to apply delta to LLaMA weights +Then use ``fschat`` to apply the delta to LLaMA weights .. 
code-block:: bash diff --git a/docs/compilation/python.rst b/docs/compilation/python.rst index 99486a751b..98e4f934e7 100644 --- a/docs/compilation/python.rst +++ b/docs/compilation/python.rst @@ -5,8 +5,8 @@ Python API for Model Compilation :local: :depth: 2 -We expose Python API for compiling/building model in the package :py:mod:`mlc_llm`, so -that users may build model in any directory in their program (i.e. not just +We expose Python API for compiling/building models in the package :py:mod:`mlc_llm`, so +that users may build a model in any directory in their program (i.e. not just within the mlc-llm repo). Install MLC-LLM as a Package @@ -44,7 +44,7 @@ After installing the package, you can build the model using :meth:`mlc_llm.build which takes in an instance of :class:`BuildArgs` (a dataclass that represents the arguments for building a model). -For detailed instruction with code, please refer to `the python notebook +For detailed instructions with code, please refer to `the Python notebook `_ (executable in Colab), where we walk you through compiling Llama-2 with :py:mod:`mlc_llm` in Python. @@ -56,7 +56,7 @@ API Reference In order to use the python API :meth:`mlc_llm.build_model`, users need to create an instance of the dataclass :class:`BuildArgs`. The corresponding arguments in -command line shown in :ref:`compile-command-specification` are automatically +the command line shown in :ref:`compile-command-specification` are automatically converted from the definition of :class:`BuildArgs` and are equivalent. Then with an instantiated :class:`BuildArgs`, users can call the build API diff --git a/docs/conf.py b/docs/conf.py index ee42500d51..0f7ed19014 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,6 +7,8 @@ # -- General configuration ------------------------------------------------ sys.path.insert(0, os.path.abspath("../python")) +sys.path.insert(0, os.path.abspath("../")) +autodoc_mock_imports = ["torch"] # do not load mlc-llm.so in docs os.environ["SKIP_LOADING_MLCLLM_SO"] = "1" @@ -29,9 +31,7 @@ "sphinx_reredirects", ] -redirects = { - "get_started/try_out": "../index.html#getting-started" -} +redirects = {"get_started/try_out": "../index.html#getting-started"} source_suffix = [".rst"] diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index c26e9f3445..0c2ed8535f 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -98,11 +98,11 @@ To deploy models on Android with reasonable performance, one has to cross-compil --model ./dist/models/$MODEL_NAME \ --quantization $QUANTIZATION -This generates directory ``./dist/$MODEL_NAME-$QUANTIZATION`` which contains the necessary components to run the model, as explained below. +This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION`` which contains the necessary components to run the model, as explained below. **Expected output format**. By default models are placed under ``./dist/${MODEL_NAME}-${QUANTIZATION}``, and the result consists of 3 major components: -- Runtime configuration: It configures conversation templates including system prompts, repetition repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` under ``params/`` along side with tokenizer configurations. +- Runtime configuration: It configures conversation templates including system prompts, repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc.
It is usually named as ``mlc-chat-config.json`` under ``params/`` alongside the tokenizer configurations. - Model lib: The compiled library that uses mobile GPU. It is usually named as ``${MODEL_NAME}-${QUANTIZATION}-android.tar``, for example, ``Llama-2-7b-chat-hf-q4f16_0-android.tar``. - Model weights: the model weights are sharded as ``params_shard_*.bin`` under ``params/`` and the metadata is stored in ``ndarray-cache.json``. @@ -144,16 +144,16 @@ The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime **Build the Android app**. Open folder ``./android/MLCChat`` as an Android Studio Project. Connect your Android device to your machine. In the menu bar of Android Studio, click "Build → Make Project". Once the build is finished, click "Run → Run 'app'" and you will see the app launched on your phone. .. note:: - ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at accelerated speed. + ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed. Incorporate Model Weights ------------------------- Instructions have been provided to build an Android App with MLC LLM in previous sections, but it requires run-time weight downloading from HuggingFace, as configured in `app-config.json` in previous steps under `model_url`. However, it could be desirable to bundle weights together into the app to avoid downloading over the network. In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development. -**Generating APK**. Enter Android Studio, click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/MLCChat/app/release/app-release.apk``. +**Generating APK**. Enter Android Studio, and click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/MLCChat/app/release/app-release.apk``. -**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: +**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the platform-tools path to the environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: .. code-block:: bash diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index 79501b113d..2f0951686d 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -3,7 +3,7 @@ CLI and C++ API =============== -MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. You may install it from the prebuilt package we provide, or compile it from source. +MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. You may install it from the prebuilt package we provide, or compile it from the source. ..
contents:: Table of Contents :local: @@ -25,16 +25,16 @@ To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source After installation, activating ``mlc-chat-venv`` environment in Conda will give the ``mlc_chat_cli`` command available. .. note:: - The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from source. + The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from the source. .. _mlcchat_build_from_source: Option 2. Build MLC Runtime from Source ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -We also provid options to build mlc runtime libraries and ``mlc_chat_cli`` from source. +We also provide options to build mlc runtime libraries and ``mlc_chat_cli`` from source. This step is useful when you want to directly obtain a version of mlc runtime library -and the cli. Please click the details below to see the instruction. +and the cli. Please click the details below to see the instructions. .. collapse:: Details @@ -63,7 +63,7 @@ and the cli. Please click the details below to see the instruction. conda activate mlc-chat-venv .. note:: - :doc:`TVM Unity ` compiler is not a dependency to MLCChat CLI. Only its runtime is required, which is automatically included in `3rdparty/tvm `_. + :doc:`TVM Unity ` compiler is not a dependency on MLCChat CLI. Only its runtime is required, which is automatically included in `3rdparty/tvm `_. **Step 2. Configure and build.** A standard git-based workflow is recommended to download MLC LLM, after which you can specify build requirements with our lightweight config generation tool: @@ -96,7 +96,7 @@ and the cli. Please click the details below to see the instruction. Run Models through MLCChat CLI ------------------------------ -Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model on command line. +Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model on the command line. **Ensure Model Exists.** As the input to ``mlc_chat_cli``, it is always good to double check if the compiled model exists. @@ -111,6 +111,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o - Model lib should be placed at ``./dist/prebuilt/lib/$(local_id)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(local_id)/``. + .. note:: + Please make sure that you have the same directory structure as above, because the CLI tool + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can pass in a ``--model-lib-path`` argument + to the CLI + .. collapse:: Example .. code:: shell @@ -134,6 +140,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o - Model libraries should be placed at ``./dist/$(local_id)/$(local_id)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/$(local_id)/params/``. + .. note:: + Please make sure that you have the same directory structure as above, because the CLI tool + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can pass in a ``--model-lib-path`` argument + to the CLI + .. collapse:: Example .. 
code:: shell diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 597f594bfb..b6e8e7b55a 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -7,9 +7,9 @@ iOS App and Swift API :local: :depth: 2 -The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from source. +The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from the source. If you are an iOS user looking to try out the models, the pre-built package is recommended. If you are a -developer seeking to integrate new features into the package, building the iOS package from source is required. +developer seeking to integrate new features into the package, building the iOS package from the source is required. Use Pre-built iOS App --------------------- @@ -23,7 +23,7 @@ The MLC Chat app is now available in App Store at no cost. You can download and Build iOS App from Source ------------------------- -This section shows how we can build the app from source. +This section shows how we can build the app from the source. Step 1. Install Build Dependencies ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -134,7 +134,7 @@ Ensure that all the necessary dependencies and configurations are correctly set up in the Xcode project. Once you have made the necessary changes, build the iOS app using Xcode. -If you have an Apple Silicon Mac, you can select target "My Mac (designed for ipad)" +If you have an Apple Silicon Mac, you can select target "My Mac (designed for iPad)" to run on your Mac. You can also directly run it on your iPad or iPhone. .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/xcode-build.jpg @@ -163,7 +163,7 @@ controls the list of model URLs and model libs to be packaged into the app. Additionally, the app prepackages the models under ``./ios/dist``. This built-in list can be controlled by editing ``prepare_params.sh``. -You can package new prebuilt models or compiled models by changing the above fields and then repeat the steps above. +You can package new prebuilt models or compiled models by changing the above fields and then repeating the steps above. Build Apps with MLC Swift API @@ -193,8 +193,8 @@ your own app. The package is located under `ios/MLCSwift`. -ltokenizers_c -You can then can import the `MLCSwift` package in your app. -The following code shows an illustrative example about how to use the chat module. +You can then import the `MLCSwift` package into your app. +The following code shows an illustrative example of how to use the chat module. .. code:: swift @@ -221,7 +221,7 @@ The following code shows an illustrative example about how to use the chat modul Because the chat module makes heavy use of GPU and thread-local resources, it needs to run on a dedicated background thread. Therefore, **avoid using** `DispatchQueue`, which can cause context switching to - different threads and segfaults due to thread-safety issue. + different threads and segfaults due to thread-safety issues. Use the `ThreadWorker` class to launch all the jobs related to the chat module. You can check out the source code of the MLCChat app for a complete example. diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst index cdbd4cc79e..08dd2cde26 100644 --- a/docs/deploy/javascript.rst +++ b/docs/deploy/javascript.rst @@ -6,9 +6,9 @@ WebLLM and Javascript API :depth: 2 WebLLM is a MLC chat web runtime (`WebLLM `_) -that allows you to build chat applications directly in browser. 
+that allows you to build chat applications directly in the browser. -Try out Prebuilt Webpage +Try out the Prebuilt Webpage ------------------------ To get started, you can try out `WebLLM prebuilt webpage `__. diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index ccdfec743d..3df5a08241 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -11,8 +11,7 @@ We also provide a web demo based on `gradio `_ as an exampl Python API ---------- -The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels and you can install it by -following the instructions in ``_. +The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels via the :doc:`installation page <../install/mlc_llm>`. Verify Installation ^^^^^^^^^^^^^^^^^^^ @@ -29,7 +28,7 @@ that supports other GPU runtime than the prebuilt version. Please refer our :ref Get Started ^^^^^^^^^^^ After confirming that the package ``mlc_chat`` is installed, we can follow the steps -below to chat with a MLC-compiled model in Python. +below to chat with an MLC-compiled model in Python. First, let us make sure that the MLC-compiled ``model`` we want to chat with already exists. @@ -52,6 +51,11 @@ If you do not have the MLC-compiled ``model`` ready: - Model lib should be placed at ``./dist/prebuilt/lib/$(model)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(model)/``. + .. note:: + Please make sure that you have the same directory structure as above, because Python API + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can specify ``ChatModule.model_lib_path`` + .. collapse:: Example .. code:: shell @@ -75,6 +79,11 @@ If you do not have the MLC-compiled ``model`` ready: - Model libraries should be placed at ``./dist/$(model)/$(model)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/$(model)/params/``. + .. note:: + Please make sure that you have the same directory structure as above, because Python API + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can specify ``ChatModule.model_lib_path`` + .. collapse:: Example .. code:: shell @@ -90,7 +99,7 @@ If you do not have the MLC-compiled ``model`` ready: params_shard_*.bin ... -After making sure that the files exist, using the conda environment you used +After making sure that the files exist, use the conda environment you used to install ``mlc_chat``, from the ``mlc-llm`` directory, you can create a Python file ``sample_mlc_chat.py`` and paste the following lines: @@ -158,7 +167,7 @@ You can also checkout the :doc:`/prebuilt_models` page to run other models. | .. note:: - You could also specify the address of ``model`` and ``lib_path`` explicitly. If + You could also specify the address of ``model`` and ``model_lib_path`` explicitly. If you only specify ``model`` as ``model_name`` and ``quantize_mode``, we will do a search for you. See more in the documentation of :meth:`mlc_chat.ChatModule.__init__`. @@ -244,7 +253,7 @@ We provide an example below. fields specified. It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template - specified by the field ``conv_template`` in chat configuration. Learn more about it in + specified by the field ``conv_template`` in the chat configuration. 
Learn more about it in :ref:`Configure MLCChat in JSON`. Raw Text Generation in Python @@ -263,7 +272,7 @@ We provide an example below. # Use a `ConvConfig` to define the generation settings # Since the "LM" template only supports raw text generation, - # system prompts will not be executed even if provided + # System prompts will not be executed even if provided conv_config = ConvConfig(stop_tokens=[2,], add_bos=True, stop_str="[INST]") # Note that `conv_config` is an optional subfield of `chat_config` @@ -297,6 +306,38 @@ We provide an example below. Additionally, system prompts will not be run when instantiating a `mlc_chat.ChatModule`, unless explicitly given inside the prompt. +Stream Iterator in Python +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Stream Iterator gives users an option to stream generated text to the function that the API is called from, +instead of streaming to stdout, which could be a necessity when building services on top of MLC Chat. + +We provide an example below. + +.. code:: python + + from mlc_chat import ChatModule + from mlc_chat.callback import StreamIterator + + # Create a ChatModule instance + cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + + # Stream to an Iterator + from threading import Thread + + stream = StreamIterator(callback_interval=2) + generation_thread = Thread( + target=cm.generate, + kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, + ) + generation_thread.start() + + output = "" + for delta_message in stream: + output += delta_message + + generation_thread.join() + API Reference ------------- @@ -314,10 +355,19 @@ The :class:`mlc_chat.ChatModule` class provides the following methods: .. automethod:: __init__ +.. autoclass:: ChatConfig + :members: + +.. autoclass:: ConvConfig + :members: + +.. autoclass:: GenerationConfig + :members: + Gradio Frontend --------------- -The gradio frontend provides a web interface for the MLC-Chat model, which allows user to interact with the model in a more user-friendly way and switch between different models to compare performance. +The gradio frontend provides a web interface for the MLC-Chat model, which allows users to interact with the model in a more user-friendly way and switch between different models to compare performance. To use gradio frontend, you need to install gradio first: .. code-block:: bash @@ -335,7 +385,7 @@ Then you can run the following code to start the interface: --port The port number to run gradio. The default value is ``7860``. --share Whether to create a publicly shareable link for the interface. -After setting up properly, you are expected to see the following interface in your browser: +After setting it up properly, you are expected to see the following interface in your browser: .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/gradio-interface.png :width: 100% diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index d5955190e9..8451624fdb 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -6,13 +6,12 @@ Rest API :depth: 2 We provide `REST API `_ -for user to interact with MLC-Chat in their own programs. +for a user to interact with MLC-Chat in their own programs. Install MLC-Chat Package ------------------------ -The REST API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels and you can install it by -following the instructions in ``_. +The REST API is a part of the MLC-Chat package, which we have prepared pre-built :doc:`pip wheels <../install/mlc_llm>`. 
Verify Installation ^^^^^^^^^^^^^^^^^^^ @@ -34,7 +33,7 @@ of mlc chat runtime. You only need to do this if you choose not to use the prebu First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-chat-nightly`. -Then please follow the instruction in :ref:`mlcchat_build_from_source` to build the necessary libraries. +Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. You can now use ``mlc_chat`` package by including the `python` directory to ``PYTHONPATH`` environment variable. @@ -89,6 +88,9 @@ The REST API provides the following endpoints: Get the latest runtime stats (encode/decode speed). +.. http:get:: /verbose_stats + + Get the verbose runtime stats (encode/decode speed, total runtime). Use REST API in your own program -------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 89be4d4161..345b5d9603 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,13 +11,13 @@ Getting Started --------------- To begin with, try out MLC LLM support for int4-quantized Llama2 7B. -It is recommended to have at least 4.5GB of free VRAM to run it. +It is recommended to have at least 6GB free VRAM to run it. .. tabs:: .. tab:: Python - **Install MLC Chat**. `MLC Chat `_ is available via pip. + **Install MLC Chat Python**. :doc:`MLC LLM ` is available via pip. It is always recommended to install it in an isolated conda virtual environment. **Download pre-quantized weights**. The comamnds below download the int4-quantized Llama2-7B from HuggingFace: @@ -47,6 +47,8 @@ It is recommended to have at least 4.5GB of free VRAM to run it. **Colab walkthrough.** A Jupyter notebook on `Colab `_ is provided with detailed walkthrough of the Python API. + **Documentation and tutorial.** Python API reference and its tutorials are `available online `_. + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-api.jpg :width: 600 :align: center @@ -209,6 +211,7 @@ It is recommended to have at least 4.5GB of free VRAM to run it. :hidden: install/tvm.rst + install/mlc_llm.rst install/conda.rst install/gpu.rst install/emcc.rst diff --git a/docs/install/gpu.rst b/docs/install/gpu.rst index 48ac7a5e1f..608c238265 100644 --- a/docs/install/gpu.rst +++ b/docs/install/gpu.rst @@ -105,7 +105,7 @@ After installation, you can run ``vulkaninfo`` in command line and see if you ca Vulkan SDK ---------- -Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our `pre-built wheels `__ already ships with Vulkan SDK. +Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our :doc:`pre-built wheels <../install/mlc_llm>` already ships with Vulkan SDK. Check Vulkan SDK installation guide according to your platform: diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst new file mode 100644 index 0000000000..13fc373dbf --- /dev/null +++ b/docs/install/mlc_llm.rst @@ -0,0 +1,133 @@ +.. _install-mlc-packages: + +Install MLC LLM Python Package +============================== + +.. contents:: Table of Contents + :local: + :depth: 2 + +MLC LLM Python Package can be installed directly from a prebuilt developer package, or built from source. + +Option 1. 
Prebuilt Package +-------------------------- + +We provide nightly built pip wheels for MLC-LLM via pip. +Select your operating system/compute platform and run the command in your terminal: + +.. note:: + ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. + +.. tabs:: + + .. tab:: Linux + + .. tabs:: + + .. tab:: CPU + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. tab:: CUDA 11.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117 + + .. tab:: CUDA 11.8 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118 + + .. tab:: CUDA 12.1 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121 + + .. tab:: CUDA 12.2 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122 + + .. tab:: ROCm 5.6 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56 + + .. tab:: ROCm 5.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57 + + .. tab:: Vulkan + + Supported in all Linux packages. + + .. note:: + + If encountering issues with GLIBC not found, please install the latest glibc in conda: + + .. code-block:: bash + + conda install -c conda-forge libgcc-ng + + .. tab:: macOS + + .. tabs:: + + .. tab:: CPU + Metal + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. note:: + + Always check if conda is installed properly in macOS using the command below: + + .. code-block:: bash + + conda info | grep platform + + It should return "osx-64" for Mac with Intel chip, and "osx-arm64" for Mac with Apple chip. + + .. tab:: Windows + + .. tabs:: + + .. tab:: CPU + Vulkan + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. note:: + If encountering the error below: + + .. code-block:: bash + + FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. + + It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the 64 bit version of precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. + + +Option 2. Build from Source +--------------------------- + +Upcoming. diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 97ec1c9e40..c2b7998ada 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -37,49 +37,49 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. 
code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. tab:: CUDA 11.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu117 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu117 .. tab:: CUDA 11.8 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu118 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu118 .. tab:: CUDA 12.1 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu121 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121 .. tab:: CUDA 12.2 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu122 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122 .. tab:: ROCm 5.6 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 .. tab:: ROCm 5.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 .. tab:: Vulkan @@ -97,19 +97,12 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. tabs:: - .. tab:: CPU - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly - - .. tab:: Metal + .. tab:: CPU + Metal .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: @@ -125,16 +118,12 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. tabs:: - .. tab:: CPU + .. tab:: CPU + Vulkan .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly - - .. tab:: Vulkan - - Supported in all Windows packages. + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: If encountering the error below: @@ -143,7 +132,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. + It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. Hint - To locate the "tvm.dll" file in Conda, navigate to your user home directory (e.g., "/users/xxxx"). Search for "tvm.dll" and find the folder whose path contains the name of the current environment, such as "mlc-chat-venv." 
Once located, copy "zstd.dll" to that specific folder. .. _tvm-unity-build-from-source: diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index 99a86780db..8a9f74253b 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -7,168 +7,164 @@ Model Prebuilts :depth: 3 :local: -MLC-LLM is a universal solution for deploying different language models. Any language models that can be described in `TVM Relax `__ (a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the help of :doc:`TVM Unity `. +.. _model-prebuilts-overview: -The community has already supported several LLM architectures (LLaMA, GPT-NeoX, etc.) and have prebuilt some models (Vicuna, RedPajama, etc.) which you can use off the shelf. -With the goal of democratizing the deployment of LLMs, we eagerly anticipate further contributions from the community to expand the range of supported model architectures. +Overview +-------- -This page contains the list of prebuilt models for our CLI (command line interface) app, iOS and Android apps. -The models have undergone extensive testing on various devices, and their performance has been optimized by developers with the help of TVM. +MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ +(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the +help of :doc:`TVM Unity `. -.. _prebuilt-models-cli: +There are two ways to run a model on MLC-LLM: -Prebuilt Models for CLI ------------------------ +1. Compile your own models following :doc:`the model compilation page `. +2. Use off-the-shelf prebuilts models following this current page. -.. 
list-table:: - :widths: 15 15 15 15 - :header-rows: 1 +This page focuses on the second option: - * - Model code - - Original Model - - Quantization Mode - - Hugging Face repo - * - `Llama-2-{7, 13, 70}b-chat-hf-q4f16_1` - - `Llama-2 `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - * `7B link `__ - * `13B link `__ - * `70B link `__ - * - `vicuna-v1-7b-q3f16_0` - - `Vicuna `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `rwkv-raven-{1b5, 3b, 7b}-q8f16_0` - - `RWKV `__ - - * Weight storage data type: uint8 - * Running data type: float16 - * Symmetric quantization - - * `1b5 link `__ - * `3b link `__ - * `7b link `__ - * - `WizardLM-13B-V1.2-{q4f16_1, q4f32_1}` - - `WizardLM `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `WizardCoder-15B-V1.0-{q4f16_1, q4f32_1}` - - `WizardCoder `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `WizardMath-{7, 13, 70}B-V1.0-q4f16_1` - - `WizardMath `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - * `7B link `__ - * `13B link `__ - * `70B link `__ - * - `llama2-7b-chat-uncensored-{q4f16_1, q4f32_1}` - - `georgesung `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `Llama2-Chinese-7b-Chat-{q4f16_1, q4f32_1}` - - `FlagAlpha `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `GOAT-7B-Community-{q4f16_1, q4f32_1}` - - `GOAT-AI `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `OpenOrca-Platypus2-13B-q4f16_1` - - `Llama-2 `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - -To download and run one model with CLI, follow the instructions below: - -.. code:: shell - - # Create conda environment and install CLI if you have not installed. - conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly - conda activate mlc-chat-venv - conda install git git-lfs - git lfs install - - # Download prebuilt model binary libraries from GitHub if you have not downloaded. - mkdir -p dist/prebuilt - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib - - # Download prebuilt model weights and run CLI. - cd dist/prebuilt - git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code] - cd ../.. - mlc_chat_cli --model [model-code] - - # e.g., - # cd dist/prebuilt - # git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0 - # cd ../.. - # mlc_chat_cli --model rwkv-raven-7b-q8f16_0 - - -.. _prebuilt-models-ios: - -Prebuilt Models for iOS ------------------------ - -.. list-table:: Prebuilt models for iOS - :widths: 15 15 15 15 - :header-rows: 1 +- Documenting :ref:`how to use prebuilts ` for various platforms, and +- Tracking what current :ref:`prebuilt models we provide `. 
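Assuming the prebuilt weights and model library have already been downloaded as described in the platform sections below, running one of these prebuilts from Python is a short sketch like the following; the model name is just an example taken from the tables on this page.

.. code:: python

   from mlc_chat import ChatModule

   # ChatModule searches dist/prebuilt/ for the weights and model library
   # matching the given model name.
   cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
   output = cm.generate(prompt="What is the meaning of life?")
   print(output)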
+ +Prerequisite: Model Libraries and Compiled Weights +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to run a specific model on MLC-LLM, you need: + +**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). See the full list of all precompiled model libraries `here `__. + +**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model (e.g. https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1). See the full list of all precompiled weights `here `__. + +.. _using-model-prebuilts: + +Using Prebuilt Models for Different Platforms +--------------------------------------------- + +We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. + +.. _using-prebuilt-models-cli: + + +Prebuilt Models on CLI / Python +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. + +.. collapse:: Click to show details + + First create the conda environment if you have not done so. + + .. code:: shell + + conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly + conda activate mlc-chat-venv + conda install git git-lfs + git lfs install + + Download the prebuilt model libraries from github. + + .. code:: shell + + mkdir -p dist/prebuilt + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib + + Download the prebuilt model weights from hugging face for the model variant you want. + + .. code:: shell + + # Say we want to run rwkv-raven-7b-q8f16_0 + cd dist/prebuilt + git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0 + cd ../.. + + # The format being: + # cd dist/prebuilt + # git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code] + # cd ../.. + # mlc_chat_cli --model [model-code] + + Run the model with CLI: - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `Llama-2-7b-q3f16_1` - - `Llama `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `vicuna-v1-7b-q3f16_0` - - `Vicuna `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - -The `downloadable iOS app `_ has builtin RedPajama-3B model support. -To add a model to the iOS app, follow the steps below: - -.. collapse:: Click to show instructions + .. code:: shell + + # For CLI + mlc_chat_cli --model rwkv-raven-7b-q8f16_0 + + To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). + + +.. for a blank line + +| + +.. _using-prebuilt-models-ios: + +Prebuilt Models on iOS +^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the iOS page `. + +.. collapse:: Click to show details + + The `iOS app `_ has builtin RedPajama-3B and Llama-2-7b support. + + All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: + + .. 
list-table:: Prebuilt model libraries integrated in the iOS app + :widths: 15 15 15 + :header-rows: 1 + + * - Model library name + - Model Family + - Quantization Mode + * - `Llama-2-7b-chat-hf-q3f16_1` + - LLaMA + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + * - `vicuna-v1-7b-q3f16_0` + - LLaMA + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - GPT-NeoX + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + + As for prebuilt model weights, the ones we have integrated into app are listed below: + + .. list-table:: Tested prebuilt model weights for iOS + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q3f16_1` + - `Llama `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `vicuna-v1-7b-q3f16_0` + - `Vicuna `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + + To run a model variant you compiled on your own, you can directly reuse the above integrated prebuilt model libraries, as long as the model shares the architecture and is compiled with the same quantization mode. For example, if you compile `OpenLLaMA-7B `_ with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. Then you can upload the compiled weights to hugging face so that you can download the weights in the app as shown below (for more on uploading to hugging face, please check the :doc:`model distribution page `). + + To add a model to the iOS app, follow the steps below: .. tabs:: @@ -210,126 +206,623 @@ To add a model to the iOS app, follow the steps below: | -The iOS app has integrated with the following model libraries, which can be directly reused when you want to run a model you compiled in iOS, as long as the model is in the supported model family and is compiled with supported quantization mode. -For example, if you compile `OpenLLaMA-7B `_ with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. Please check the :doc:`model distribution page ` for detailed instructions. - -.. list-table:: Prebuilt model libraries which are integrated in the iOS app - :widths: 15 15 15 - :header-rows: 1 - - * - Model library name - - Model Family - - Quantization Mode - * - `Llama-2-7b-chat-hf-q3f16_1` - - LLaMA - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - * - `vicuna-v1-7b-q3f16_0` - - LLaMA - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - GPT-NeoX - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - .. _prebuilt-models-android: -Prebuilt Models for Android ---------------------------- - -.. 
list-table:: Prebuilt models for Android - :widths: 15 15 15 15 - :header-rows: 1 +Prebuilt Models on Android +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the Android page `. + +.. collapse:: Click to show details + + The apk for demo Android app includes the following models. To add more, check out the Android page. + + .. list-table:: Prebuilt Models for Android + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q4f16_1` + - `Llama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ +.. for a blank line - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `vicuna-v1-7b-q4f16_1` - - `Vicuna `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_0` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ +| ------------------- +.. _supported-model-architectures: -You can check `MLC-LLM pull requests `__ to track the ongoing efforts of new models. We encourage users to upload their compiled models to Hugging Face and share with the community. +Level 1: Supported Model Architectures (The All-In-One Table) +------------------------------------------------------------- -.. _supported-model-architectures: +For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. -Supported Model Architectures ------------------------------ +Each entry below hyperlinks to the corresponding level 2 and level 3 tables. MLC-LLM supports the following model architectures: .. 
list-table:: Supported Model Architectures - :widths: 15 15 15 15 + :widths: 10 10 15 15 :header-rows: 1 - * - Category Code - - Series - - Model Definition - - Variants - * - ``llama`` - - `LLaMa `__ - - `Relax Code `__ - - * `Llama-2 `__ - * `Alpaca `__ - * `Vicuna `__ + * - Model Architecture + - Support + - Available MLC Prebuilts + - Unavailable in MLC Prebuilts + * - `LLaMA `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`Llama-2 ` + * :ref:`Code Llama ` + * :ref:`Vicuna ` + * :ref:`WizardLM ` + * :ref:`WizardMath ` + * :ref:`OpenOrca Platypus2 ` + * :ref:`FlagAlpha Llama-2 Chinese ` + * :ref:`georgesung Llama-2 Uncensored ` + - * `Alpaca `__ * `Guanaco `__ * `OpenLLaMA `__ * `Gorilla `__ - * `WizardLM `__ * `YuLan-Chat `__ - * `WizardMath `__ - * `FlagAlpha Llama-2 Chinese `__ - * - ``gpt-neox`` - - `GPT-NeoX `__ - - `Relax Code `__ - - * `RedPajama `__ - * `Dolly `__ + * `WizardCoder (new) `__ + * - `GPT-NeoX `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`RedPajama ` + - * `Dolly `__ * `Pythia `__ * `StableCode `__ - * - ``gptj`` - - `GPT-J `__ - - `Relax Code `__ + * - `GPT-J `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - - * `MOSS `__ - * - ``rwkv`` - - `RWKV `__ - - `Relax Code `__ - - * `RWKV-raven `__ - * - ``minigpt`` - - `MiniGPT `__ - - `Relax Code `__ - - - * - ``gpt_bigcode`` - - `GPTBigCode `__ - - `Relax Code `__ + * - `RWKV `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`RWKV-raven ` + - + * - `MiniGPT `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `MiniGPT-4 `__ + * - `GPTBigCode `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ + - * :ref:`WizardCoder (old) ` - * `StarCoder `__ - * `WizardCoder `__ * `SantaCoder `__ - * - ``chatglm`` - - `ChatGLM `__ - - `Relax Code `__ + * - `ChatGLM `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - - * `ChatGLM2 `__ * `CodeGeeX2 `__ + * - `StableLM `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `StableLM `__ + +If the model variant you are interested in uses one of these model architectures we support (but we have not provided the prebuilt weights yet), you can check out :doc:`/compilation/compile_models` on how to compile your own models. Afterwards, you may follow :doc:`/compilation/distribute_compiled_models` to upload your prebuilt weights to hugging face, and submit a PR that adds an entry to this page, contributing to the community. + +For models structured in an architecture we have not supported yet, you could: + +- Either `create a [Model Request] issue `__ which automatically shows up on our `Model Request Tracking Board `__. + +- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. + + +.. _model-library-tables: + +Level 2: Model Library Tables (Precompiled Binary Files) +-------------------------------------------------------- + +As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). + +Each table below demonstrates the pre-compiled model library files for each model architecture. 
This is categorized by: + +- **Size**: each size of model has its own distinct model library file (e.g. 7B or 13B number of parameters) + +- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) + +- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`model compilation page ` (e.g. ``q3f16_1`` vs. ``q4f16_1``) + +Each entry links to the specific model library file found in `this github repo `__. + +.. _llama_library_table: + +Llama +^^^^^ +.. list-table:: Llama + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 7B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q3f16_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + * - 13B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + * - 34B + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + - + - + - + * - 70B + - + - + - + - + - `q3f16_1 `__ + + `q4f16_1 `__ + - + - + - `q4f16_1 `__ + - + +.. _gpt_neox_library_table: + +GPT-NeoX (RedPajama-INCITE) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. list-table:: GPT-NeoX (RedPajama-INCITE) + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 3B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + + `q4f32_0 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + +.. _rwkv_library_table: + +RWKV +^^^^ +.. list-table:: RWKV + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 1B5 + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + * - 3B + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + * - 7B + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - +.. _gpt_big_code_library_table: + +GPTBigCode +^^^^^^^^^^ +Note that these all links to model libraries for WizardCoder (the older version released in Jun. 2023). +However, any GPTBigCode model variants should be able to reuse these (e.g. StarCoder, SantaCoder). + +.. list-table:: GPTBigCode + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 15B + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + - + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + +.. _model-variant-tables: + +Level 3: Model Variant Tables (Precompiled Weights) +--------------------------------------------------- + +Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. 
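+
+Before diving into the per-variant tables, here is a minimal sketch of how a Level 2 model
+library and a Level 3 weight folder come together at runtime. It assumes you have installed the
+``mlc_chat`` Python package and downloaded the prebuilt libraries and the
+``Llama-2-7b-chat-hf-q4f16_1`` weights into ``dist/prebuilt`` (following the same steps as in the
+CLI section above); the model code alone should then be enough for the runtime to locate both the
+weight folder and the matching library for your platform.
+
+.. code:: python
+
+   from mlc_chat import ChatModule
+   from mlc_chat.callback import StreamToStdout
+
+   # The model code selects the compiled weights (Level 3); the runtime then looks up the
+   # matching prebuilt model library (Level 2) for the current platform.
+   cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
+   cm.generate(
+       prompt="What is the meaning of life?",
+       progress_callback=StreamToStdout(callback_interval=2),
+   )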
+ +Each precompiled weight is categorized by its model size (e.g. 7B vs. 13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. + +Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. Note that multiple model variants can share a common conversation template. + +Some of these files are uploaded by our community contributors--thank you! + +.. _llama2_variant_table: + +`Llama-2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: Llama-2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_1 `__ + * `q4f16_1 `__ + * `q4f32_1 `__ + + * - 13B + - * `q4f16_1 `__ + * `q4f32_1 `__ + + * - 70B + - * `q3f16_1 `__ + * `q4f16_1 `__ + +.. _code_llama_variant_table: + +`Code Llama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``codellama_completion`` + +.. list-table:: Code Llama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + * - 13B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + * - 34B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + +.. _vicuna_variant_table: + +`Vicuna `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``vicuna_v1.1`` + +.. list-table:: Vicuna + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_0 `__ + * `q4f32_0 `__ + * `int3 (demo) `__ + * `int4 (demo) `__ + + +.. _WizardLM_variant_table: + +`WizardLM `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``vicuna_v1.1`` + +.. list-table:: WizardLM + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 13B + - * `q4f16_1 (V1.2) `__ + * `q4f32_1 (V1.2) `__ + + * - 70B + - * `q3f16_1 (V1.0) `__ + * `q4f16_1 (V1.0) `__ + + +.. _wizard_math_variant_table: + +`WizardMath `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardMath + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + * - 13B + - `q4f16_1 `__ + * - 70B + - `q4f16_1 `__ + + +.. _open_orca_variant_table: + +`OpenOrca Platypus2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: OpenOrca Platypus2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 13B + - `q4f16_1 `__ + + +.. _flag_alpha_llama2_variant_table: + +`FlagAlpha Llama-2 Chinese `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: FlagAlpha Llama-2 Chinese + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + + +.. _llama2_uncensored_variant_table: + +`Llama2 uncensored (georgesung) `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-default`` + +.. 
list-table:: Llama2 uncensored + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + +.. _red_pajama_variant_table: + +`RedPajama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``LM`` + +.. list-table:: Red Pajama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 3B + - * `q4f16_0 (Instruct) `__ + * `q4f16_0 (Chat) `__ + * `q4f16_1 (Chat) `__ + * `q4f32_0 (Chat) `__ + + +.. _rwkv_raven_variant_table: + +`RWKV-raven `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``rwkv`` + +.. list-table:: RWKV-raven + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 1B5 + - `q8f16_0 `__ + + * - 3B + - `q8f16_0 `__ + + * - 7B + - `q8f16_0 `__ + + +.. _wizard_coder_variant_table: + +`WizardCoder `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardCoder + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 15B + - `q4f16_1 `__ + +------------------ -For models structured in these model architectures, you can check the :doc:`model compilation page ` on how to compile models. -Please `create a new issue `_ if you want to request a new model architecture. -Our tutorial :doc:`Define New Models ` introduces how to bring a new model architecture to MLC-LLM. .. _contribute-models-to-mlc-llm: diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py new file mode 100644 index 0000000000..a290eb892c --- /dev/null +++ b/examples/python/run_llama_batched_vllm.py @@ -0,0 +1,448 @@ +import argparse +import math +import os +import json +from collections import defaultdict +from typing import List +from dataclasses import dataclass + +import numpy as np + +import tvm +from tvm import relax +from tvm.runtime import disco as di + +import torch +from transformers import AutoTokenizer + +from mlc_llm.relax_model.llama import LlamaConfig +from mlc_llm import utils + + +class KVCache: + def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): + if disco_session: + init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + else: + init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + + self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks) + + self.block_tables = defaultdict(list) + self.slot_mappings = defaultdict(list) + self.block_size = block_size + + +class CacheManager: + block_size: int = 16 + + def __init__( + self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None + ): + self.num_blocks = num_blocks + self.free_blocks = list(range(num_blocks)) + self.kv_cache = KVCache( + num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session + ) + + if sliding_window: + assert sliding_window % self.kv_cache.block_size == 0 + self.block_sliding_window = sliding_window // self.kv_cache.block_size + else: + self.block_sliding_window = None + + def set_size(self, request_ids: List[int], target_sizes: List[int]): + for id, size in zip(request_ids, target_sizes): + num_needed_block = math.ceil(size / self.block_size) + + if self.block_sliding_window: + num_needed_block = min(num_needed_block, self.block_sliding_window) + + if id in self.kv_cache.block_tables and size == 0: + 
self.free_blocks.extend(self.kv_cache.block_tables[id]) + del self.kv_cache.block_tables[id] + del self.kv_cache.slot_mappings[id] + + elif id in self.kv_cache.block_tables: + # Decoding + if len(self.kv_cache.block_tables[id]) < num_needed_block: + # Need to allocate a new block for this request + assert len(self.kv_cache.block_tables[id]) + 1 == num_needed_block + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + pos = size - 1 + block_number = self.kv_cache.block_tables[id][-1] + + if self.block_sliding_window: + block_number = self.kv_cache.block_tables[id][ + (pos // self.block_size) % self.block_sliding_window + ] + else: + block_number = self.kv_cache.block_tables[id][-1] + + block_offset = pos % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + elif id not in self.kv_cache.block_tables: + assert len(self.free_blocks) >= num_needed_block, "Not enough free blocks." + + for _ in range(num_needed_block): + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + for i in range(size): + block_idx = i // self.block_size + + if self.block_sliding_window: + block_idx %= self.block_sliding_window + + block_number = self.kv_cache.block_tables[id][block_idx] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + def get(self): + return self.kv_cache + + +@dataclass +class SequenceGenerationRequest: + request_id: int + token_ids: List[int] + + +@dataclass +class SequenceGenerationResponse: + request_id: int + token_id: int + + +def sample(logits): + logits = torch.from_dlpack(logits) + return torch.argmax(logits, -1).cpu().numpy() + + +def load_params_disco(artifact_path, lib_path, num_shards): + sess = di.ProcessSession(num_workers=num_shards) + devices = range(num_shards) + sess.init_ccl("nccl", *devices) + module = sess.load_vm_module(lib_path) + + loader_create = sess.get_global_func("runtime.disco.ShardLoader") + metadata_path = os.path.join(artifact_path, "params", "ndarray-cache.json") + with open(metadata_path, "r", encoding="utf-8") as f: + ndarray_cache_metadata = f.read() + + loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) + loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll") + params = loader_load(loader) + + return module, params, sess + + +def copy_to_worker_0(sess: di.Session, host_array): + x_array = sess.empty(host_array.shape, host_array.dtype) + sess.copy_to_worker_0(host_array, x_array) + return x_array + + +def get_tvm_model(artifact_path, model, quantization, num_shards, dev): + lib_path = os.path.join(artifact_path, f"{model}-{quantization}-cuda.so") + + if num_shards == 1: + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, dev) + params = utils.load_params(artifact_path, dev) + return vm.module, params, None + + return load_params_disco(artifact_path, lib_path, num_shards) + + +def _prepare_inputs( + requests, + all_slot_mappings, + all_block_tables, + sliding_window, + dev, + is_prefill, +): + block_tables = [] + seq_lens = [] + input_ids = [] + slot_mapping = [] + positions = [] + max_num_blocks_per_seq = 0 + indices_within_window = [] + start_idx = 0 + + for request in requests: + request_id = request.request_id + token_ids = request.token_ids + + if is_prefill: + input_ids += token_ids + prompt_len = len(token_ids) + seq_lens.append(prompt_len) + positions += range(prompt_len) + slot_mapping += 
all_slot_mappings[request_id] + + if sliding_window: + indices_within_window += range( + start_idx + max(0, prompt_len - sliding_window), + start_idx + prompt_len, + ) + start_idx += prompt_len + + else: + input_ids.append(token_ids[-1]) + pos = len(token_ids) - 1 + positions.append(pos) + block_table = all_block_tables[request_id] + max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) + block_tables.append(block_table) + slot_mapping.append(all_slot_mappings[request_id][-1]) + + if sliding_window: + seq_lens.append(min(len(token_ids), sliding_window)) + else: + seq_lens.append(len(token_ids)) + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + if is_prefill and sliding_window: + indices_within_window = tvm.nd.array(np.array(indices_within_window, dtype="int32"), dev) + else: + indices_within_window = None + + if not is_prefill: + + def _pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) + + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in block_tables + ] + + block_tables_np = np.vstack(padded_block_tables).astype("int32") + block_tables = tvm.nd.array(np.array(block_tables_np, dtype="int32"), dev) + else: + block_tables = None + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) + + +class Model: + def __init__( + self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window + ): + self.mod, self.params, self.disco_session = get_tvm_model( + artifact_path, model_name, quant, num_shards, dev + ) + self.dev = dev + self.vocab_size = vocab_size + self.sliding_window = sliding_window + + if sliding_window: + self.block_sliding_window = sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + def generate( + self, requests: List[SequenceGenerationRequest], cache: KVCache, is_prefill: bool + ) -> List[SequenceGenerationResponse]: + ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = _prepare_inputs( + requests, + cache.slot_mappings, + cache.block_tables, + self.sliding_window, + self.dev, + is_prefill, + ) + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + + kv_cache = cache.cache + + if is_prefill: + if self.sliding_window: + if self.disco_session: + indices_within_window = copy_to_worker_0( + self.disco_session, indices_within_window + ) + + out = self.mod["prefill"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + indices_within_window, + self.params, + ) + else: + out = self.mod["prefill"]( + input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] # Ignore returned KV cache since it is updated in-place anyway. 
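+                # For prefill, the batched model emits logits only for the last token of each
+                # sequence (one row per request), so `sample` below yields one next token per request.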
+ else: + if self.disco_session: + block_tables = copy_to_worker_0(self.disco_session, block_tables) + + out = self.mod["decode"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + block_tables, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] + + next_tokens = sample(logits) + + return [ + SequenceGenerationResponse(request.request_id, new_token) + for request, new_token in zip(requests, next_tokens) + ] + + +def parse_args(): + # Example + # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention + # python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q4f16_ft + # + # For Disco: + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --build-model-only --num-shards 2 + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --convert-weight-only + # CUDA_VISIBLE_DEVICES=0,1 python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q0f16 --num-shards 2 + + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, required=True) + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--num-shards", type=int, default=1) + args.add_argument("--num-decode-steps", type=int, default=20) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def run(args): + quantization = args.quantization.name + artifact_path = args.artifact_path + model_name = args.model + model_path = f"dist/models/{model_name}" + + dev = tvm.device("cuda", 0) + + with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: + config = LlamaConfig(**json.load(i_f)) + + model = Model( + artifact_path, + model_name, + quantization, + config.vocab_size, + args.num_shards, + dev, + config.sliding_window, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + + num_kv_heads = config.get_num_key_value_heads() // args.num_shards + head_size = config.hidden_size // config.num_attention_heads + num_blocks = 500 + + cache_manager = CacheManager( + num_blocks, + config.num_hidden_layers, + num_kv_heads, + head_size, + model.disco_session, + sliding_window=config.sliding_window, + ) + cache = cache_manager.get() + + model.block_sliding_window = cache_manager.block_sliding_window + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + batched_token_ids = [tokenizer.encode(p) for p in prompts] + prompts_len = [len(ids) for ids in batched_token_ids] + request_ids = list(range(len(prompts))) + target_sizes = [] + requests = [] + + for token_ids, request_id in zip(batched_token_ids, request_ids): + request_ids.append(request_id) + target_sizes.append(len(token_ids)) + requests.append(SequenceGenerationRequest(request_id, token_ids)) + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, True) + + for _ in range(args.num_decode_steps): + for i, response in enumerate(out): + new_token_id = response.token_id + requests[i].token_ids.append(new_token_id) + 
target_sizes[i] += 1 + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, False) + + output_tokens = [ + tokenizer.convert_ids_to_tokens( + requests[i].token_ids[prompts_len[i] :], skip_special_tokens=True + ) + for i in range(len(requests)) + ] + + generated = [tokenizer.convert_tokens_to_string(tokens) for tokens in output_tokens] + + for p, g in zip(prompts, generated): + print("Prompt = '{}', generated text = '{}'".format(p, g)) + + +if __name__ == "__main__": + run(parse_args()) diff --git a/examples/python/sample_chat_stream.py b/examples/python/sample_chat_stream.py new file mode 100644 index 0000000000..980e833d20 --- /dev/null +++ b/examples/python/sample_chat_stream.py @@ -0,0 +1,30 @@ +from mlc_chat import ChatModule +from mlc_chat.callback import StreamToStdout, StreamIterator + +# From the mlc-llm directory, run +# $ python examples/python/sample_chat_stream.py + +# Create a ChatModule instance +cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + +# Stream to Stdout +output = cm.generate( + prompt="What is the meaning of life?", + progress_callback=StreamToStdout(callback_interval=2), +) + +# Stream to an Iterator +from threading import Thread + +stream = StreamIterator(callback_interval=2) +generation_thread = Thread( + target=cm.generate, + kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, +) +generation_thread.start() + +output = "" +for delta_message in stream: + output += delta_message + +generation_thread.join() diff --git a/ios/MLCSwift/Sources/ObjC/LLMChat.mm b/ios/MLCSwift/Sources/ObjC/LLMChat.mm index da5edc177e..dcf57c5db2 100644 --- a/ios/MLCSwift/Sources/ObjC/LLMChat.mm +++ b/ios/MLCSwift/Sources/ObjC/LLMChat.mm @@ -23,7 +23,8 @@ // The input message is only the beginning part of a prompt, no role name and separator should be // appended after the message since there will be future messages appended after the message. kBegin, - // The input message is in the middle of a prompt, nothing should be appended before or after the message. + // The input message is in the middle of a prompt, nothing should be appended before or after the + // message. kMiddle, // The input message is the ending part of a prompt, no role name and separator should be appended // prior to it since the message is concatenated to some prior messages. 
@@ -118,7 +119,9 @@ - (void)unload { unload_func_(); } -- (void)reload:(NSString*)modelLib modelPath:(NSString*)modelPath appConfigJson:(NSString*)appConfigJson { +- (void)reload:(NSString*)modelLib + modelPath:(NSString*)modelPath + appConfigJson:(NSString*)appConfigJson { std::string lib_prefix = modelLib.UTF8String; std::string model_path = modelPath.UTF8String; std::string app_config_json = appConfigJson.UTF8String; @@ -194,7 +197,9 @@ - (void)resetImageModule { first_input_after_image = false; } -- (void)prefillImage:(UIImage*)image prevPlaceholder:(NSString*)prevPlaceholder postPlaceholder:(NSString*)postPlaceholder { +- (void)prefillImage:(UIImage*)image + prevPlaceholder:(NSString*)prevPlaceholder + postPlaceholder:(NSString*)postPlaceholder { // prefill the previous placeholder string std::string prev_placeholder = prevPlaceholder.UTF8String; prefill_func_(prev_placeholder, false, (int)PlaceInPrompt::kBegin); @@ -206,9 +211,9 @@ - (void)prefillImage:(UIImage*)image prevPlaceholder:(NSString*)prevPlaceholder NSUInteger bytesPerPixel = 4; NSUInteger bytesPerRow = bytesPerPixel * image_width; NSUInteger bitsPerComponent = 8; - CGContextRef context = CGBitmapContextCreate(image_data.data(), image_width, image_height, - bitsPerComponent, bytesPerRow, colorSpace, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGContextRef context = CGBitmapContextCreate( + image_data.data(), image_width, image_height, bitsPerComponent, bytesPerRow, colorSpace, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); CGColorSpaceRelease(colorSpace); CGContextDrawImage(context, CGRectMake(0, 0, image_width, image_height), imageRef); CGContextRelease(context); diff --git a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h index a996eaa55f..0aab17adb1 100644 --- a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h +++ b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h @@ -40,9 +40,12 @@ * * @param modelLib The name of the modelLib * @param modelPath The path to the model artifacts. - * @param appConfigJson The partial config that is used to partially override the model configuration. + * @param appConfigJson The partial config that is used to partially override the model + * configuration. */ -- (void)reload:(NSString*)modelLib modelPath:(NSString*)modelPath appConfigJson:(NSString*)appConfigJson; +- (void)reload:(NSString*)modelLib + modelPath:(NSString*)modelPath + appConfigJson:(NSString*)appConfigJson; /** * Reset the current chat session. @@ -118,5 +121,7 @@ * @param prevPlaceholder The previous placeholder in the prompt, i.e. . * @param postPlaceholder The post placeholder in the prompt, i.e. . 
*/ -- (void)prefillImage:(UIImage*)image prevPlaceholder:(NSString*)prevPlaceholder postPlaceholder:(NSString*)postPlaceholder; +- (void)prefillImage:(UIImage*)image + prevPlaceholder:(NSString*)prevPlaceholder + postPlaceholder:(NSString*)postPlaceholder; @end diff --git a/mlc_llm/build.py b/mlc_llm/build.py index 703856c336..b7619aa963 100644 --- a/mlc_llm/build.py +++ b/mlc_llm/build.py @@ -1,13 +1,47 @@ """Script for building/compiling models.""" +import contextlib +import sys + from mlc_llm import core + +@contextlib.contextmanager +def debug_on_except(): + try: + yield + finally: + raised_exception = sys.exc_info()[1] + if not isinstance(raised_exception, Exception): + return + + import traceback + + try: + import ipdb as pdb + except ImportError: + import pdb + + traceback.print_exc() + pdb.post_mortem() + + def main(): """Main method for building model from command line.""" empty_args = core.convert_build_args_to_argparser() # Create new ArgumentParser parsed_args = empty_args.parse_args() # Parse through command line - # Post processing of arguments - parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access - core.build_model_from_args(parsed_args) + + with contextlib.ExitStack() as stack: + # Enter an exception-catching context before post-processing + # the arguments, in case the post-processing itself raises an + # exception. + if parsed_args.pdb: + stack.enter_context(debug_on_except()) + + # Post processing of arguments + parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access + + core.build_model_from_args(parsed_args) + if __name__ == "__main__": main() diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 6adb3f4fe1..6b993c07b5 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -26,6 +26,7 @@ minigpt, param_manager, rwkv, + stablelm_3b, ) from mlc_llm.relax_model.commons import create_shard_info_func from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention @@ -79,6 +80,37 @@ class BuildArgs: Build with separated embedding layer, only applicable to LlaMa. This feature is in testing stage, and will be formally replaced after massive overhaul of embedding feature for all models and use cases. + cc_path: str + ``/path/to/cross_compiler_path``; currently only used for cross-compile + for nvidia/jetson device. + use_safetensors: bool + Specifies whether to use ``.safetensors`` instead of the default ``.bin`` + when loading in model weights. + enable_batching: bool + Build the model for batched inference. + This is a temporary flag used to control the model execution flow in single- + sequence and batching settings for now. We will eventually merge two flows + in the future and remove this flag then. + no_cutlass_attn: bool + Disable offloading attention operations to CUTLASS. + no_cutlass_norm: bool + Disable offloading layer and RMS norm operations to CUTLASS. + no_cublas: bool + Disable the step that offloads matmul to cuBLAS. Without this flag, + matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or + ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. + use_cuda_graph: bool + Specifies whether to enable CUDA Graph for the decoder. MLP and QKV + projection between two attention layers are put into a graph. + num_shards: int + Number of shards to split the model into in tensor parallelism multi-gpu + inference. Only useful when ``build_model_only`` is set. + use_flash_attn_mqa: bool + Offload multi-query attention workload to Flash Attention. 
+ pdb: bool + If set, drop into a pdb debugger on error. + use_vllm_attention: bool + Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True. """ model: str = field( default="auto", @@ -180,21 +212,29 @@ class BuildArgs: "action": "store_true", }, ) - no_cutlass_attn: bool = field( + enable_batching: bool = field( default=False, metadata={ "help": ( - "Disable offloading attention operations to CUTLASS." + "Build the model for batched inference." + "This is a temporary flag used to control the model execution flow in single-" + "sequence and batching settings for now. We will eventually merge two flows" + "in the future and remove this flag then." ), "action": "store_true", }, ) + no_cutlass_attn: bool = field( + default=False, + metadata={ + "help": ("Disable offloading attention operations to CUTLASS."), + "action": "store_true", + }, + ) no_cutlass_norm: bool = field( default=False, metadata={ - "help": ( - "Disable offloading layer and RMS norm operations to CUTLASS." - ), + "help": ("Disable offloading layer and RMS norm operations to CUTLASS."), "action": "store_true", }, ) @@ -204,7 +244,7 @@ class BuildArgs: "help": ( "Disable the step that offloads matmul to cuBLAS. Without this flag, " "matmul will be offloaded to cuBLAS if quantization mode is q0f16 or q0f32, " - "target is CUDA and TVM has been built with cuBLAS enbaled." + "target is CUDA and TVM has been built with cuBLAS enabled." ), "action": "store_true", }, @@ -231,15 +271,23 @@ class BuildArgs: use_flash_attn_mqa: bool = field( default=False, metadata={ - "help": ( - "Offload multi-query attention workload to Flash Attention." - ), + "help": ("Offload multi-query attention workload to Flash Attention."), + "action": "store_true", }, ) - batched: bool = field( + pdb: bool = field( default=False, metadata={ - "help": ("Build the model with batched inference support."), + "help": ("If set, drop into a pdb debugger on error"), + "action": "store_true", + }, + ) + use_vllm_attention: bool = field( + default=False, + metadata={ + "help": ( + "Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True." + ), "action": "store_true", }, ) @@ -279,13 +327,15 @@ def _parse_args(parsed) -> argparse.Namespace: utils.parse_target(parsed) utils.argparse_postproc_common(parsed) + if parsed.use_vllm_attention: + assert parsed.enable_batching, "--enable_batching is required for using vLLM attention." + assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA." + assert tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True), "TVM needs to be built with -DUSE_VLLM=ON." + parsed.artifact_path = os.path.join( parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" ) - if parsed.batched: - parsed.artifact_path += "-batched" - return parsed @@ -378,29 +428,36 @@ def mod_transform_before_build( "decode", ] - if not args.batched: + if not args.use_vllm_attention: model_names += [ "create_kv_cache", "softmax_with_temperature", "get_metadata", ] - if args.batched: + else: + # This is equivalent to prefill but without KV cache. It is used for + # determining the number of paged cache blocks that can be allocated. 
model_names.append("evaluate") + if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] + if args.enable_batching: + model_names[2] = "decode_with_embed" if args.model.lower().startswith("rwkv-"): model_names += ["reset_kv_cache"] - mod = param_manager.transform_dequantize(mod) + mod = param_manager.transform_dequantize()(mod) + mod = relax.transform.BundleModelParams()(mod) use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) if ( - not args.batched + not args.enable_batching and hasattr(config, "num_attention_heads") and hasattr(config, "hidden_size") and hasattr(config, "position_embedding_base") + and getattr(config, "dtype", "float16") == "float16" ): max_seq_len = None if args.max_seq_len > 0: @@ -411,12 +468,11 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() mod = fuse_split_rotary_embedding( - mod, - config.num_attention_heads // args.num_shards, - num_key_value_heads // args.num_shards, - config.hidden_size // args.num_shards, - config.position_embedding_base, - ) + config.num_attention_heads // args.num_shards, + num_key_value_heads // args.num_shards, + config.hidden_size // args.num_shards, + config.position_embedding_base, + )(mod) if args.target_kind == "cuda": patterns = [] @@ -425,15 +481,8 @@ def mod_transform_before_build( if has_cutlass and not args.no_cutlass_attn: if args.use_flash_attn_mqa: - mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True) - mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True) - - mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False) - mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False) - - if args.batched: - mod["evaluate"] = rewrite_attention(mod["evaluate"], use_flash_mqa=False) - + mod = rewrite_attention(use_flash_mqa=True)(mod) + mod = rewrite_attention(use_flash_mqa=False)(mod) patterns += get_patterns_with_prefix("cutlass.attention") if has_cutlass and not args.no_cutlass_norm: @@ -470,7 +519,7 @@ def mod_transform_before_build( ), annotate_workspace, relax.transform.AllocateWorkspace(), - relax.transform.RunCodegen(options, entry_functions=model_names) + relax.transform.RunCodegen(options, entry_functions=model_names), ] )(mod) @@ -570,7 +619,9 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": use_cuda_graph}): # The num_input attribute is needed to capture transformed weights passed as input # into a cuda graph. - mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) + # NOTE: CUDA graph for batching is not enabled and is left as a TODO item. 
+ if not args.enable_batching: + mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib) output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}" @@ -593,6 +644,9 @@ def build_model_from_args(args: argparse.Namespace): "`num_shards` should be used together with " "`--build-model-only` and `--convert-weight-only`" ) + use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] + if use_ft_quant: + raise ValueError("Multi-GPU deployments are not available for ft quantization.") os.makedirs(args.artifact_path, exist_ok=True) if args.debug_dump: os.makedirs(os.path.join(args.artifact_path, "debug"), exist_ok=True) @@ -601,36 +655,47 @@ def build_model_from_args(args: argparse.Namespace): use_cache = args.use_cache and os.path.isfile(cache_path) if args.sep_embed and args.model_category != "llama": raise ValueError(f"separate embedding not supported on {args.model}") - if args.model_category != "minigpt": + + if args.model_category == "minigpt": + # Special case for minigpt, which neither provides nor requires a configuration. + config = {} + else: with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) + if not use_cache or args.convert_weight_only: - if args.model_category in ["llama", "mistral"] and args.batched: - mod, param_manager, params, model_config = llama_batched_vllm.get_model(args, config) - elif args.model_category == "llama": - mod, param_manager, params, model_config = llama.get_model(args, config) - elif args.model_category == "mistral": - mod, param_manager, params, model_config = llama.get_model(args, config) - elif args.model_category == "gpt_neox": - mod, param_manager, params, model_config = gpt_neox.get_model(args, config) - elif args.model_category == "gpt_bigcode": - mod, param_manager, params, model_config = gpt_bigcode.get_model(args, config) - elif args.model_category == "minigpt": - mod, param_manager, params, model_config = minigpt.get_model(args) - elif args.model_category == "gptj": - mod, param_manager, params, model_config = gptj.get_model(args, config) - elif args.model_category == "rwkv" or args.model_category == "rwkv_world": - mod, param_manager, params, model_config = rwkv.get_model(args, config) - elif args.model_category == "chatglm": - mod, param_manager, params, model_config = chatglm.get_model(args, config) - else: - raise ValueError(f"Model {args.model} not supported") + model_generators = { + "llama": llama, + "mistral": llama, + "stablelm_epoch": stablelm_3b, + "gpt_neox": gpt_neox, + "gpt_bigcode": gpt_bigcode, + "minigpt": minigpt, + "gptj": gptj, + "rwkv": rwkv, + "rwkv_world": rwkv, + "chatglm": chatglm, + } + + if args.use_vllm_attention: + model_generators["llama"] = llama_batched_vllm + model_generators["mistral"] = llama_batched_vllm + + assert args.model_category in model_generators, f"Model {args.model} not supported" + + mod, param_manager, params, model_config = model_generators[args.model_category].get_model( + args, config + ) for qspec_updater_class in param_manager.qspec_updater_classes: qspec_updater = qspec_updater_class(param_manager) qspec_updater.visit_module(mod) if not args.build_model_only: + # Run pre-quantization if provided. 
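+            # (No-op when no pre-quantize hook has been registered on the param manager.)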
+ args.model_path = param_manager.run_pre_quantize(args.model_path) + param_manager.init_torch_pname_to_bin_name(args.use_safetensors) + new_params = utils.convert_weights(param_manager, params, args) utils.save_params(new_params, args.artifact_path) if args.model_category != "minigpt": diff --git a/mlc_llm/models/__init__.py b/mlc_llm/models/__init__.py deleted file mode 100644 index 380ea83505..0000000000 --- a/mlc_llm/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Model definition using PyTorch-like nn.Module API""" -from . import llama, llama_param_map -from .model_config_base import ModelConfig diff --git a/mlc_llm/param_loader/__init__.py b/mlc_llm/param_loader/__init__.py deleted file mode 100644 index dfc748d3f6..0000000000 --- a/mlc_llm/param_loader/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Utilities for loading parameters from specific formats, for example, HuggingFace PyTorch, -HuggingFace SafeTensor, GGML, AutoGPTQ. -""" -from .hf_torch_loader import HFTorchLoader -from .param_mapping import ParameterMapping diff --git a/mlc_llm/param_loader/hf_torch_loader.py b/mlc_llm/param_loader/hf_torch_loader.py deleted file mode 100644 index 6c12af9181..0000000000 --- a/mlc_llm/param_loader/hf_torch_loader.py +++ /dev/null @@ -1,191 +0,0 @@ -"""A weight loader for HuggingFace's PyTorch format""" -import gc -import json -import logging -import time -from collections import defaultdict -from pathlib import Path -from typing import Dict, List - -import numpy as np - -from .param_mapping import ParameterMapping - -logger = logging.getLogger(__name__) - - -class HFTorchLoader: - """A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters. - - Attributes - ---------- - param_map : ParameterMapping - The parameter mapping from MLC to HuggingFace PyTorch. - - torch_to_path : Dict[str, Path] - A mapping from PyTorch parameter name to the path of the file containing it. - - cached_files : Dict[Path, Dict[str, np.ndarray]] - A cache of the loaded files. The key is the path of the file, and the value is a mapping - from parameter name to the parameter value. - - stats_load_time_sec : float - The time spent on loading the files in seconds. - - stats_load_data_gb : float - The amount of data loaded in GB. - """ - - param_map: ParameterMapping - torch_to_path: Dict[str, Path] - cached_files: Dict[Path, Dict[str, np.ndarray]] - stats_load_time_sec: float - stats_load_data_gb: float - - def __init__(self, config_path: Path, param_map: ParameterMapping) -> None: - """Create a parameter loader from HuggingFace PyTorch format. - - Parameters - ---------- - config_path : pathlib.Path - Path to the torch indexing file, usually `pytorch_model.bin.index.json` in the repo. - param_map : ParameterMapping - The parameter mapping from MLC to HuggingFace PyTorch. - """ - with config_path.open("r", encoding="utf-8") as in_file: - torch_weight_map = json.load(in_file)["weight_map"] - self.param_map = param_map - self.torch_to_path = {} - for torch_name, path_str in torch_weight_map.items(): - path = config_path.parent / path_str - self.torch_to_path[torch_name] = path - self.cached_files = {} - self.stats_load_time_sec = 0.0 - self.stats_load_data_gb = 0.0 - - used_torch_names = sum(param_map.name_map.values(), ()) - # Check 1. 
All PyTorch parameters in the weight files are used unless explicitly specified - unused_torch_names = set(torch_weight_map) - set(used_torch_names) - param_map.unused_params - if unused_torch_names: - logger.warning( - "Unused torch parameters: %s", - ", ".join(sorted(unused_torch_names)), - ) - # Check 2. All PyTorch parameters required are stored in the weight files - nonexistent_torch_names = set(used_torch_names) - set(torch_weight_map) - if nonexistent_torch_names: - raise ValueError( - "The following torch parameters do not exist in the weight files:\n " - + "\n ".join(sorted(nonexistent_torch_names)), - ) - - def suggest_loading_order(self) -> List[str]: - """Suggest a loading order for MLC parameters. - - Returns - ------- - order : List[str] - A list of MLC parameters in the order that ensures file locality. - """ - # Step 1. Build a map from path to torch parameters - path_to_torch: Dict[Path, List[str]] = defaultdict(list) - for torch_name, path in self.torch_to_path.items(): - path_to_torch[path].append(torch_name) - # Step 2. Build a map from torch parameters to MLC parameters - torch_to_mlc = defaultdict(list) - for mlc_name, torch_names in self.param_map.name_map.items(): - for torch_name in torch_names: - torch_to_mlc[torch_name].append(mlc_name) - # Step 3. Construct the ordering that ensures file locality - order = [] - for _, torch_names in path_to_torch.items(): - for torch_name in torch_names: - for mlc_name in torch_to_mlc[torch_name]: - order.append(mlc_name) - return order - - def load_param(self, name: str) -> np.ndarray: - """Load a MLC parameter according to its name. - - Parameters - ---------- - name : str - The name of the MLC parameter. - - Returns - ------- - param : np.ndarray - The parameter value as a numpy array. Note that if the parameter is stored in bfloat16, - it will be converted to float32. - """ - mlc_name = name - torch_names = self.param_map.name_map[mlc_name] - files_required = {self.torch_to_path[p] for p in torch_names} - files_existing = set(self.cached_files.keys()) - files_to_load = files_required - files_existing - files_to_unload = files_existing - files_required - - # Step 1. When there is some file to unloaded: - # - If no pending file load: unloading is deferred as there is no gain in peak memory usage; - # - Need to load files: unload immediately to save memory and make space for the new files. - if files_to_load: - for path in files_to_unload: - self._unload_file(path) - # Step 2. Load all the files needed - for path in files_to_load: - self._load_file(path) - # Step 3. Collect all torch parameters in order - torch_names = [self._retrieve_torch_param_from_cache(name) for name in torch_names] - # Step 4. Apply the mapping function - map_func = self.param_map.map_func[mlc_name] - return map_func(*torch_names) - - def __enter__(self) -> "HFTorchLoader": - self.stats_load_time_sec = 0.0 - self.stats_load_data_gb = 0.0 - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - cached_files = list(self.cached_files.keys()) - for path in cached_files: - self._unload_file(path) - logger.info( - "Time used in PyTorch loading: %.3f sec. 
Total %.3f GB loaded", - self.stats_load_time_sec, - self.stats_load_data_gb, - ) - - def _load_file(self, path: Path) -> None: - import torch # pylint: disable=import-outside-toplevel - - logging.info("Loading PyTorch parameters from: %s", path) - - start_time = time.time() - result = {} - for name, param in torch.load(path, map_location=torch.device("cpu")).items(): - param = param.detach().cpu() - dtype = str(param.dtype) - if dtype == "torch.bfloat16": - param = param.float() - param = param.numpy() - self.stats_load_data_gb += param.nbytes / (1024**3) - result[name] = param - logging.debug(' Parameter: "%s", shape: %s, dtype: %s', name, param.shape, dtype) - self.cached_files[path] = result - self.stats_load_time_sec += time.time() - start_time - - def _unload_file(self, path: Path) -> None: - logging.debug("Unloading PyTorch weight file: %s", path) - - start_time = time.time() - del self.cached_files[path] - gc.collect() - self.stats_load_time_sec += time.time() - start_time - - def _retrieve_torch_param_from_cache(self, name: str) -> np.ndarray: - assert name in self.torch_to_path - path = self.torch_to_path[name] - assert path in self.cached_files - cache = self.cached_files[path] - assert name in cache - return cache[name] diff --git a/mlc_llm/param_loader/param_mapping.py b/mlc_llm/param_loader/param_mapping.py deleted file mode 100644 index c378b30268..0000000000 --- a/mlc_llm/param_loader/param_mapping.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Parameter mapping for converting different LLM implementations to MLC LLM.""" -import dataclasses -from typing import Callable, Dict, Set, Tuple - -import numpy as np - - -@dataclasses.dataclass -class ParameterMapping: - """Mapping from a parameter name in MLC LLM's model definition to its potential source, - for example, from MLC parameter "model.layers.2.post_attention_layernorm.weight" to PyTorch's - parameter correspondingly. - - Parameters - ---------- - name_map : Dict[str, Tuple[str, ...]] - A dictionary that maps the name of a parameter to its source. For example, - in Llama2, the source of MLC parameter "model.layers.0.self_attn.qkv_proj.weight" from - huggingface torch are: - - - "model.layers.0.self_attn.q_proj.weight" - - "model.layers.0.self_attn.k_proj.weight" - - "model.layers.0.self_attn.v_proj.weight" - - map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]] - A dictionary that maps the name of a parameter to a function that combines the source - parameters into the MLC parameter. For example, for the above example, the function - would be: `lambda q, k, v: np.concatenate([q, k, v], axis=0)`. - - unused_params : Set[str] - Parameter names in the source weights that are not used in the MLC LLM model definition. 
- """ - - name_map: Dict[str, Tuple[str, ...]] - map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]] - unused_params: Set[str] = dataclasses.field(default_factory=dict) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 49b9c24b43..2309bdd92e 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -26,10 +26,10 @@ ) -def apply_rotary_pos_emb(q, k, positions, position_embedding_base, offset: int = 0): - def f_rotary_embedding(tensor, pos_tensor, offset): +def apply_rotary_pos_emb(q, k, positions, position_embedding_base): + def f_rotary_embedding(tensor, pos_tensor): def rotary_compute(*idx): - pos = (offset + pos_tensor[idx[0]]).astype("float32") + pos = pos_tensor[idx[0]].astype("float32") return rotary_modulate_by_freq( tensor, idx, @@ -39,19 +39,15 @@ def rotary_compute(*idx): return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") - q_embed = nn.emit_te( - f_rotary_embedding, q, positions, offset, primfunc_name_hint="rotary_embedding" - ) - k_embed = nn.emit_te( - f_rotary_embedding, k, positions, offset, primfunc_name_hint="rotary_embedding" - ) + q_embed = nn.emit_te(f_rotary_embedding, q, positions, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, positions, primfunc_name_hint="rotary_embedding") return q_embed, k_embed class LlamaAttentionBatched(LlamaAttentionBase): def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): super().__init__(config) - self.head_mapping = head_mapping + self.head_mapping = head_mapping # (num_heads,), used by vLLM for multi-query attention self.sliding_window = None if config.sliding_window: @@ -59,15 +55,17 @@ def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): def forward( self, - hidden_states: relax.Expr, - positions: relax.Expr, - seq_lens: relax.Expr, + hidden_states: relax.Expr, # (num_token, hidden_size) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], - slot_mapping: Optional[relax.Expr], - max_seqlen: Optional[relax.Expr], # Must be on CPU - seqstart: Optional[relax.Expr], # For prefill - block_tables: Optional[relax.Expr], # For decode - indices_within_window: Optional[relax.Expr], # For prefill with sliding-window attention + slot_mapping: Optional[relax.Expr], # (num_token,) + max_seqlen: Optional[relax.Expr], # (), must be on CPU + seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention ): num_tokens, _ = hidden_states.struct_info.shape @@ -77,9 +75,7 @@ def forward( (num_tokens, self.num_key_value_heads, self.head_dim), ) - queries, keys = apply_rotary_pos_emb( - queries, keys, positions, self.position_embedding_base, offset=0 - ) + queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) if kv_cache: # Paged KV cache update @@ -113,6 +109,7 @@ def forward( k_cache = v_cache = None if seqstart: + # Prefill, batched attention over variable sequence lengths attn_output = nn.emit( attention_var_len( nn.emit(expand_dims(queries, axis=0)), @@ -125,6 +122,7 @@ def forward( ) ) else: + # Decode, using vLLM kernel attn_output = nn.emit( relax.op.call_dps_packed( "tvm.contrib.vllm.single_query_cached_kv_attention", @@ 
-294,14 +292,35 @@ def __init__( def forward( self, - input_ids: relax.Expr, - positions: relax.Expr, - seq_lens: relax.Expr, + input_ids: relax.Expr, # (num_token,) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate - slot_mapping: Optional[relax.Expr], # For prefill and decode, not needed for evaluate - block_tables: Optional[relax.Expr], # For decode - indices_within_window: Optional[relax.Expr], # For prefill with sliding-window attention + slot_mapping: Optional[ + relax.Expr + ], # (num_token,), for prefill and decode, not needed for evaluate + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention ): + """ + In vLLM, the paged KV cache is simply a pair of tensors, one for keys and the other + for values. The tensor has shape (num_blocks, num_kv_heads, head_size, block_size). + (In practice, the key cache has a slightly different shape for an efficiency reason, + but that's not important.) + + The mapping between sequences / tokens to blocks is specified by two inputs. + - block_tables: A list of block IDs allocated for the sequence. + - slot_mapping: A linear index into the 2D grid (num_blocks, block_size), for each token. + + Support for sliding-window attention is realized by making a block table a circular buffer. + So the length of a block table for each sequence is at most ceil(window_size / block_size). + + With sliding window, not all past K / V values need to be cached during prefill. + The last input, indices_within_window, tells which tokens among (num_token,) need to have + their K / V values cached. + """ if self.num_shards > 1: input_ids = nn.emit(ccl.broadcast_from_worker0(input_ids)) positions = nn.emit(ccl.broadcast_from_worker0(positions)) @@ -313,9 +332,13 @@ def forward( if block_tables: block_tables = nn.emit(ccl.broadcast_from_worker0(block_tables)) + if indices_within_window: + indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) + is_prompt = block_tables is None if is_prompt: # prefill and evaluate + # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust cumsum = nn.emit( relax.op.call_dps_packed( "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info @@ -337,6 +360,7 @@ def forward( ) if is_prompt: + # Extract logits for the last token in each sequence def get_logits_last_tokens(x, seq_len_tensor, seqstart): return te.compute( @@ -424,10 +448,11 @@ def create_evaluate_func( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, - cpu_dev, + cpu_dev: VDevice, quant_scheme: QuantizationScheme, sep_embed: bool = False, ) -> None: + """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" func_name = "evaluate" num_token = tvm.tir.Var("num_token", "int64") @@ -468,10 +493,15 @@ def create_encoding_func( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, - cpu_dev, + cpu_dev: VDevice, quant_scheme: QuantizationScheme, sep_embed: bool = False, ) -> None: + """Batched prefill with vLLM paged KV cache. + + The batched attention op is intended to be offloaded to CUTLASS or Flash Attention + via BYOC. 
+ """ func_name = "prefill_with_embed" if sep_embed else "prefill" num_token = tvm.tir.Var("num_token", "int64") @@ -533,9 +563,10 @@ def create_decoding_func( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, - cpu_dev, + cpu_dev: VDevice, quant_scheme: QuantizationScheme, ) -> None: + """Batched decoding with vLLM paged KV cache.""" func_name = "decode" num_seq = tvm.tir.Var("num_seq", "int64") @@ -620,10 +651,10 @@ def get_model(args, hf_config): create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) - bb.get().update_global_info("vdevice", [cpu_dev]) - mod = bb.get() + mod.update_global_info("vdevice", [cpu_dev]) + if args.build_model_only: return mod, param_manager, None, config diff --git a/mlc_llm/relax_model/minigpt.py b/mlc_llm/relax_model/minigpt.py index 7bd30e70ed..96126bbf5b 100644 --- a/mlc_llm/relax_model/minigpt.py +++ b/mlc_llm/relax_model/minigpt.py @@ -502,7 +502,7 @@ def create_embed_func( bb.update_func(gv, mod[gv].with_attr("num_input", 1)) -def get_model(args): +def get_model(args, _config): model_name = args.model model_path = args.model_path diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 138f04f769..7f0751b2a0 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -13,6 +13,7 @@ from .. import quantization from .modules import named_parameters +from ..transform import ReorderTransformFunc def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any: @@ -268,12 +269,37 @@ def register_params( relax_param, getattr(quantization_scheme, quant_kind.name), func_name, - getattr(relax_param, "shard_dim", None), - getattr(relax_param, "shard_strategy", None), + relax_param.__dict__.get("shard_dim", None), + relax_param.__dict__.get("shard_strategy", None), ) self.params_in_func[func_name].append(param) + def run_pre_quantize(self, model_path: str): + if self.f_run_prequantize is not None: + model_path = self.f_run_prequantize(model_path) + + self.model_path = model_path + return model_path + + def init_torch_pname_to_bin_name(self, use_safetensors: bool): + assert hasattr(self, "model_path"), ( + "Must call either set_param_loading_func or run_pre_quantize " + "before init_torch_pname_to_bin_name" + ) + + if self.pidx2pname: + mapping = load_torch_pname2binname_map( + self.model_path, + use_safetensors, + set(self.pidx2pname.values()), + self.f_convert_pname_fwd, + ) + else: + mapping = {} + + self.torch_pname2binname = mapping + def set_param_loading_func( self, model_path: str, @@ -343,7 +369,7 @@ def set_param_loading_func( else: self.pidx2pname = dict() - def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: + def transform_dequantize(self) -> tvm.ir.transform.Pass: """Apply dequantization to the input IRModule. Parameters @@ -360,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: The IRModule updated with the dequantization computation. """ - # For each Relax function in the input IRModule (e.g., "prefill"), - # we create its input relax.Var of all the quantized data, and - # store the mapping from function name to the var. 
- func2param_var: Dict[str, relax.Var] = {} - for gv, func in mod.functions.items(): - if not isinstance(func, relax.Function): - continue - if func.attrs is None or not "num_input" in func.attrs: - continue - func2param_var[gv.name_hint] = relax.Var( - "params", self.get_quantized_param_info(gv.name_hint) - ) + @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") + def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: + # For each Relax function in the input IRModule (e.g., "prefill"), + # we create its input relax.Var of all the quantized data, and + # store the mapping from function name to the var. + func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} - # Cache mapping to avoid duplicate dequantization. - dequantized_cache: Dict[relax.Var, relax.Var] = {} + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: + quantized_param_info = self.get_quantized_param_info(gv.name_hint) + param_vars = [ + relax.Var(f"param_{i}", info) + for i, info in enumerate(quantized_param_info.fields) + ] + func_name_to_quantized_params[gv.name_hint] = param_vars - # Define a var replacement function for applying dequantization. - def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var: - if var in dequantized_cache: - return dequantized_cache[var] - assert var in self.func_raw_param_map - func_name, param = self.func_raw_param_map[var] - dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name) - dequantized_cache[var] = dequantized - return dequantized + # Cache mapping to avoid duplicate dequantization. + dequantized_cache: Dict[relax.Var, relax.Var] = {} - # Create the function mutator for applying dequantization. - replacer = ParamReplacer(mod, func2param_var, f_replace) - # Update the input IRModule with dequantization. - mod = replacer.transform() + # Define a var replacement function for applying dequantization. + def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: + if var in dequantized_cache: + return dequantized_cache[var] + assert var in self.func_raw_param_map - return mod + func_name, param = self.func_raw_param_map[var] + quantized_params = func_name_to_quantized_params[func_name] + relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] + + dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) + + dequantized_cache[var] = dequantized + return dequantized + + # Create the function mutator for applying dequantization. + replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) + # Update the input IRModule with dequantization. + mod = replacer.transform() + + return mod + + return transform_func def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]: bb = relax.BlockBuilder() @@ -671,10 +707,9 @@ def _register_param( def _dequantize( self, param: Parameter, - quantized_tuple: relax.Var, + qparams: List[relax.Var], bb: relax.BlockBuilder, func_name: str, - qparams: List[relax.Var] = None, ) -> relax.Var: """Applying dequantization to the input parameter. This method is called by `transform_module` below, and is not @@ -685,30 +720,13 @@ def _dequantize( param : Parameter The parameter whose quantized tensors are to be dequantized. - quantized_tuple : relax.Var - The relax.Var of the quantized tensors of all parameters in the model. 
- - bb : relax.BlockBuilder - The Relax BlockBuilder used for inserting the dequantization computations. - - func_name : str - The name of the function which dequantization is applied to. - qparams : List[relax.Var] - The quantized parts of the parameter. - By default it is `None`, in which case we will get the quantized parts - from `quantized_tuple`. + The relax.Var of the quantized tensors of all parameters in the model. Returns ------- The dequantized parameter, in the form of a relax.Var. """ - if not qparams: - # Get the corresponding Relax vars of the quantized tensors of this parameter. - qparams: List[relax.Var] = [] - for qparam_idx in self.param2qrange[param]: - qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx))) - # Get the dequantization function of this parameter. f_dequantize = param.quant_spec.get_dequantize_func( param_info=param.param_info_dict[func_name], @@ -726,6 +744,33 @@ def _dequantize( # Apply the dequantization function. return bb.emit(f_dequantize(bb, qparams)) + def create_parameter_transformation(self, optimize_parameter_order: bool = True): + """Produce an IRModule that can transform the parameters + + Parameters + ---------- + optimize_parameter_order: bool + + If true, reorder the parameter transformations to + prioritize operations that use a currently-open file. If + false, transform the parameters in their default order. + + Returns + ------- + tvm.IRModule + The transformation module + + """ + mod = _create_quantize_func(self) + if optimize_parameter_order: + reorder_pass = ReorderTransformFunc( + self.pidx2pname, + self.torch_pname2binname, + self.f_convert_pname_fwd, + ) + mod = reorder_pass(mod) + return mod + @mutator class ParamReplacer(PyExprMutator): @@ -736,7 +781,7 @@ class ParamReplacer(PyExprMutator): mod : tvm.IRModule The IRModule of the model to be updated. - func2param_var : Dict[str, relax.Var] + func_name_to_quantized_params : Dict[str, List[relax.Var]] The mapping from each function name to its input var of quantized data tuple. 
f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] @@ -748,7 +793,7 @@ class ParamReplacer(PyExprMutator): """ mod: tvm.IRModule - func2param_var: Dict[str, relax.Var] + func_name_to_quantized_params: Dict[str, List[relax.Var]] f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] param_set: Set[relax.Var] @@ -757,12 +802,12 @@ class ParamReplacer(PyExprMutator): def __init__( self, mod: tvm.IRModule, - func2param_var: Dict[str, relax.Var], + func_name_to_quantized_params: Dict[str, relax.Var], f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], ): super().__init__(mod) self.mod = mod - self.func2param_var = func2param_var + self.func_name_to_quantized_params = func_name_to_quantized_params self.f_replace = f_replace self.cur_func_name = "" @@ -774,31 +819,31 @@ def transform(self) -> tvm.IRModule: continue assert ( - gv.name_hint in self.func2param_var - ), f"{gv.name_hint} not in {self.func2param_var}" - self.cur_func_name = gv.name_hint - updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint]) + gv.name_hint in self.func_name_to_quantized_params + ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" + updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) updated_func = remove_all_unused(updated_func) self.builder_.update_func(gv, updated_func) return self.builder_.get() - def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function: + def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: num_input = int(func.attrs["num_input"]) self.param_set = set(func.params[num_input:]) body = self.visit_expr(func.body) return relax.Function( - params=func.params[:num_input] + [param_var], + params=func.params[:num_input] + quantized_params, body=body, ret_struct_info=func.ret_struct_info, is_pure=func.is_pure, attrs=func.attrs, - ).without_attr("num_input") + ) def visit_var_(self, var: Var) -> Expr: - if var not in self.param_set: + if var in self.param_set: + return self.f_replace(var, self.builder_) + else: return super().visit_var_(var) - return self.f_replace(var, self.builder_, self.cur_func_name) ################################################################## @@ -868,7 +913,7 @@ def load_torch_pname2binname_map( return torch_pname2binname -def create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: +def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: """Construct the Relax function which computes quantization. This method is called by `transform_module` below, and is not directly invoked outside the class. 
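
Note on the `ParamManager` changes above: `transform_dequantize` now returns a `tvm.ir.transform.Pass` instead of rewriting an `IRModule` in place, and the weight-conversion module is produced on demand via `create_parameter_transformation`. The sketch below is illustrative only, not part of this patch; it assumes `param_manager`, `mod`, and `args` come from the surrounding build pipeline (e.g. a model's `get_model` having already called `set_param_loading_func`).

    # Optional pre-quantization hook; also records the (possibly updated) model path.
    param_manager.run_pre_quantize(args.model_path)
    # Map torch parameter names to the checkpoint shard files that contain them.
    param_manager.init_torch_pname_to_bin_name(args.use_safetensors)

    # The dequantization rewrite is now a composable module pass, applied like any other pass.
    mod = param_manager.transform_dequantize()(mod)

    # Build the parameter-transformation IRModule; with optimize_parameter_order=True the
    # transformations are reordered so that at most one torch checkpoint shard needs to be
    # open at a time during weight conversion.
    transform_mod = param_manager.create_parameter_transformation(optimize_parameter_order=True)
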
diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py new file mode 100644 index 0000000000..89c15a7955 --- /dev/null +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -0,0 +1,899 @@ +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl +from tvm.relax.op.nn import layer_norm +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import ModuleList, RotaryEmbedding +from .param_manager import ParamManager +from .llama import Embedding, Linear + + +@dataclass +class StableLM3bConfig: + def __init__( + self, + dtype="float32", + max_sequence_length=4096, + vocab_size=50304, + hidden_size=2560, + intermediate_size=6912, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + initializer_range=0.02, + norm_eps=1e-5, + pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + position_embedding_base=10000, + combine_matmul=True, + num_shards=1, + build_model_only=False, + convert_weight_only=False, + **kwargs, + ): + self.dtype = dtype + self.max_sequence_length = max_sequence_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.position_embedding_base = position_embedding_base + self.combine_matmul = combine_matmul + if build_model_only and num_shards > 1: + self.num_shards = num_shards + else: + self.num_shards = 1 + self.kwargs = kwargs + + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + return self.num_key_value_heads + + +class LayerNorm(nn.Module): + def __init__( + self, + hidden_size, + dtype, + eps=1e-5, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter((hidden_size,), dtype="float16", name="weight") + self.bias = nn.Parameter((hidden_size,), dtype="float16", name="bias") + + def forward(self, x: relax.Expr) -> relax.Var: + x = nn.emit( + layer_norm( + x, + gamma=self.weight, + beta=self.bias, + axes=-1, + epsilon=self.eps, + ) + ) + return x + + +class StableLM3bMLP(nn.Module): + def __init__(self, config: StableLM3bConfig): + self.combine_matmul = config.combine_matmul + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + dtype = config.dtype + if self.combine_matmul: + self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.gate_up_proj.weight.shard_dim = 0 + self.down_proj.weight.shard_dim = 1 + else: + self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.gate_proj.weight.shard_dim 
= 0 + self.up_proj.weight.shard_dim = 0 + self.down_proj.weight.shard_dim = 1 + + def forward(self, x): + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_proj(x), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(x) + up_result = self.up_proj(x) + + result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) + return result + + +class StableLM3bAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): + dtype = config.dtype + self.num_shards = config.num_shards + self.hidden_size = config.hidden_size + self.num_key_value_heads = ( + config.num_key_value_heads is None + and config.num_attention_heads + or config.num_key_value_heads + ) // config.num_shards + self.num_query_heads = config.num_attention_heads // self.num_shards + self.head_dim = self.hidden_size // config.num_attention_heads + self.position_embedding_base = config.position_embedding_base + self.rotary_embedding = rotary_embedding + + self.combine_matmul = config.combine_matmul + if self.combine_matmul: + self.query_key_value_proj = Linear( + self.hidden_size, + (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, + dtype=dtype, + bias=False, + ) + self.query_key_value_proj.weight.shard_dim = 0 + else: + self.q_proj = Linear( + self.hidden_size, + self.num_query_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.q_proj.weight.shard_dim = 0 + self.k_proj.weight.shard_dim = 0 + self.v_proj.weight.shard_dim = 0 + + self.o_proj = Linear( + self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False + ) + self.o_proj.weight.shard_dim = 1 + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: + from tvm.relax.op import ( + astype, + matmul, + maximum, + permute_dims, + reshape, + split, + squeeze, + ) + from tvm.relax.op.nn import softmax + + bsz, q_len, _ = hidden_states.struct_info.shape + assert bsz == 1, "Only support batch size 1 at this moment." 
+ + if self.combine_matmul: + qkv_states = nn.emit( + split( + self.query_key_value_proj(hidden_states), + indices_or_sections=[ + self.num_query_heads * self.head_dim, + (self.num_query_heads + self.num_key_value_heads) * self.head_dim, + ], + axis=-1, + ) + ) + query_states = relax.TupleGetItem(qkv_states, 0) + key_states = relax.TupleGetItem(qkv_states, 1) + value_states = relax.TupleGetItem(qkv_states, 2) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = nn.emit( + reshape( + query_states, + (bsz, q_len, self.num_query_heads, self.head_dim), + ), + ) + key_states = nn.emit( + reshape( + key_states, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + value_states = nn.emit( + reshape( + value_states, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + + kv_seq_len = all_seq_len_shape.struct_info.values[0] + offset = kv_seq_len - q_len + query_states, key_states = self.rotary_embedding(query_states, key_states, offset) + # [bsz, t, nh, hd] + + kv_states_shape = key_states.struct_info.shape + kv_states_dtype = key_states.struct_info.dtype + assert kv_states_shape[0] == 1 # bsz + kv_states_shape = R.shape( + [kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]] + ) + kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], kv_states_shape[3]]) + + squeezed_key = nn.emit(squeeze(key_states, axis=0)) + squeezed_value = nn.emit(squeeze(value_states, axis=0)) + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[k_cache, squeezed_key], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[v_cache, squeezed_value], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + key_states = nn.emit(reshape(k_cache, kv_states_shape)) + value_states = nn.emit(reshape(v_cache, kv_states_shape)) + if self.num_key_value_heads != self.num_query_heads: + n_rep = self.num_query_heads // self.num_key_value_heads + key_states = nn.emit(relax.op.repeat(key_states, n_rep, axis=2)) + value_states = nn.emit(relax.op.repeat(value_states, n_rep, axis=2)) + + query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3])) + key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3])) + value_states = nn.emit(permute_dims(value_states, [0, 2, 1, 3])) + + attn_weights = nn.emit( + matmul(query_states, permute_dims(key_states, [0, 1, 3, 2])) + / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) + ) + + tvm.ir.assert_structural_equal( + attention_mask.struct_info.shape.values, + (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + ) + + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) + + # upcast attention to fp32 + if attn_weights.struct_info.dtype != "float32": + 
attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != query_states.struct_info.dtype: + attn_weights = astype(attn_weights, query_states.struct_info.dtype) + attn_output = nn.emit(matmul(attn_weights, value_states)) + + attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + + attn_output = self.o_proj(attn_output) + return attn_output, ((None, None) if past_key_value is None else past_key_value) + + +class StableLM3bDecoderLayer(nn.Module): + def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): + self.hidden_size = config.hidden_size + self.self_attn = StableLM3bAttention(config, rotary_embedding) + self.mlp = StableLM3bMLP(config) + self.input_layernorm = LayerNorm( + config.hidden_size, dtype=config.dtype, eps=config.norm_eps + ) + self.post_attention_layernorm = LayerNorm( + config.hidden_size, dtype=config.dtype, eps=config.norm_eps + ) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + all_seq_len_shape=all_seq_len_shape, + ) + if self.self_attn.num_shards > 1: + residual = nn.emit(residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype)) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit(residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype)) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + return hidden_states, present_key_value + + +def _make_causal_mask(input_ids_shape, dtype, src_len): + from tvm.relax.op import broadcast_to + + bsz, tgt_len = input_ids_shape + + def min_max_triu_te(): + return te.compute( + (tgt_len, tgt_len), + lambda i, j: tvm.tir.Select(j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, 1, tgt_len, src_len), + lambda b, _, i, j: te.if_then_else( + j < src_len - tgt_len, + tvm.tir.max_value(dtype), + x[b, _, i, j - (src_len - tgt_len)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + + +class StableLM3bEmbedTokens(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var): + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.embed_tokens(input_ids) + return inputs_embeds + + +class StableLM3bEmbedTokensWrapper(nn.Module): + def 
__init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var): + # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent + self.model = StableLM3bEmbedTokens(config, vocab_size_var) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.model(input_ids) + return inputs_embeds + + +class StableLM3bModell(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + rotary_embedding = RotaryEmbedding( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + position_embedding_base=config.position_embedding_base, + max_sequence_length=config.max_sequence_length, + rotary_pct=0.25, + dtype=config.dtype, + ) + self.num_shards = config.num_shards + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [StableLM3bDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + ) + self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps) + + def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if isinstance(input_shape[-1], tvm.tir.Var) or input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, dtype, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + bsz, tgt_len = input_shape + combined_attention_mask = nn.emit( + relax.op.full( + (bsz, 1, tgt_len, src_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + ) + return combined_attention_mask + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + # retrieve input_ids + batch_size, seq_length, _ = inputs_embeds.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + # embed positions + attention_mask = self._prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + inputs_embeds.struct_info.dtype, + ) + + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) + + hidden_states, key_value_cache = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + ) + next_decoder_cache += key_value_cache + + hidden_states = self.norm(hidden_states) + + assert len(next_decoder_cache) == len(self.layers) * 2 + return hidden_states, next_decoder_cache + + +class StableLM3bForCausalLM(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + self.model = StableLM3bModell(config, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + assert config.hidden_size % config.num_attention_heads == 0 + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + 
hidden_states, key_value_cache = self.model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_tokens" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + with bb.function(func_name): + model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + inputs, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + all_seq_len = tvm.tir.Var("n", "int64") + + with bb.function(func_name): + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, 
get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: + num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) // config.num_shards + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + num_key_value_heads, + config.hidden_size // config.num_attention_heads, # head_dim + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.Call( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def emit_shard3d(bb: relax.BlockBuilder) -> None: + from tvm.script import tir as T + + def _emit(dtype: str, global_symbol: str): + @T.prim_func + def shard_3d(a: T.handle, num_shards: T.int64, b: T.handle): + T.func_attr( + { + "tir.noalias": T.bool(True), + "global_symbol": global_symbol, + } + ) + s_0, s_1, s_2 = T.int64(), T.int64(), T.int64() + # pylint: disable=invalid-name + A = T.match_buffer(a, (s_0, s_1, s_2), dtype) + B = T.match_buffer(b, (num_shards, s_0, s_1 // num_shards, s_2), dtype) + # pylint: enable=invalid-name + for j_o, i, j_i, k in T.grid(num_shards, s_0, s_1 // num_shards, s_2): + with T.block("B"): + v_j_o = T.axis.spatial(num_shards, j_o) + v_i = T.axis.spatial(s_0, i) + v_j_i = T.axis.spatial(s_1 // num_shards, j_i) + v_k = T.axis.spatial(s_2, k) + B[v_j_o, v_i, v_j_i, v_k] = A[v_i, v_j_o * (s_1 // num_shards) + v_j_i, v_k] + + bb.add_func(shard_3d, global_symbol) + + _emit("float32", "shard3d_fp32") + _emit("float16", "shard3d_fp16") + _emit("uint32", "shard3d_uint32") + + +def get_model(args, hf_config): + model_name = args.model + dtype = args.quantization.model_dtype + max_seq_len = args.max_seq_len + sep_embed = args.sep_embed + + position_embedding_base = 10000 + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + config = StableLM3bConfig( + **hf_config, + dtype=dtype, + 
position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + convert_weight_only=args.convert_weight_only, + ) + if max_seq_len != -1: + config.max_sequence_length = max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + emit_shard3d(bb) + + if sep_embed: + create_embed_func(bb, param_manager, config, args.quantization) + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=[2], + add_prefix_space=False, + ) + + mod = bb.get() + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr( + "tir_var_upper_bound", + { + "n": config.max_sequence_length, + "m": config.max_sequence_length, + }, + ) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + num_shards = args.num_shards + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.num_key_value_heads + if kv_heads is None: + kv_heads = q_heads + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + q = q.reshape((num_shards, q_heads // num_shards, head_dim, hidden_size)) + k = k.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) + v = v.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) + qkv = np.concatenate([q, k, v], axis=1) + qkv = qkv.reshape((-1, hidden_size)).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + intermediate_size = config.intermediate_size + gate, up = torch_params + gate = gate.reshape((num_shards, intermediate_size // num_shards, hidden_size)) + up = up.reshape((num_shards, intermediate_size // num_shards, hidden_size)) + gate_up = np.concatenate([gate, up], axis=1) + gate_up = gate_up.reshape((-1, hidden_size)).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + param_list = [None] * param_manager.nparam_to_load + + return mod, param_manager, param_list, config diff --git a/mlc_llm/transform/decode_take.py b/mlc_llm/transform/decode_take.py index ece1c7ab23..cd09771126 100644 --- a/mlc_llm/transform/decode_take.py +++ b/mlc_llm/transform/decode_take.py @@ -17,7 +17,7 @@ def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool: return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint -def decode_take_pattern(n_aux_tensor: int): +def decode_take_pattern(n_aux_tensor: int, match_tir_vars: bool): aux_tensors = [wildcard(), wildcard(), wildcard()] decode = is_op("relax.call_tir")( GlobalVarPattern(), @@ -26,9 +26,10 @@ def decode_take_pattern(n_aux_tensor: int): ) indices = ~is_const() take_args = [decode, indices] - take = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern(take_args), add_constraint=False - ) + call_tir_args_take = [GlobalVarPattern(), TuplePattern(take_args)] + if match_tir_vars: + call_tir_args_take.append(wildcard()) + take = is_op("relax.call_tir")(*call_tir_args_take, add_constraint=False) annotations = { "take": take, @@ -41,18 +42,17 @@ def decode_take_pattern(n_aux_tensor: int): @tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") class FuseDecodeTake: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext - ) -> IRModule: + def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: for n_aux_tensor in [2, 3]: - mod = relax.transform.FuseOpsByPattern( - [ - ( - "decode_take", - *decode_take_pattern(n_aux_tensor), - ) - ] - )(mod) + for match_tir_vars in [False, True]: + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_take", + *decode_take_pattern(n_aux_tensor, match_tir_vars), + ) + ] + )(mod) mod = relax.transform.FuseTIR()(mod) for gv, func in mod.functions.items(): @@ -61,9 +61,9 @@ def transform_module( if "fused_decode" not in gv.name_hint or "take" not in gv.name_hint: continue - 
downcasted_mod = tir.transform.ForceNarrowIndexToInt32()( - tvm.IRModule({"main": func}) - )["main"] + downcasted_mod = tir.transform.ForceNarrowIndexToInt32()(tvm.IRModule({"main": func}))[ + "main" + ] sch = tir.Schedule(downcasted_mod) sch.compute_inline("decode") mod[gv] = sch.mod["main"] diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index 4ecc843f4a..ed19a7095c 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -1,5 +1,5 @@ +import tvm from tvm import relax -from tvm.script import tir as T from tvm.relax.dpl import ( PatternContext, is_op, @@ -10,234 +10,275 @@ TuplePattern, is_shape, ) -from tvm.script import relax as R +from tvm.script import relax as R, tir as T -def get_split_rotary(num_attention_heads, head_dim, position_embedding_base): - hidden_size = num_attention_heads * head_dim +def get_dynamic_split_rotary(): + """Implementation of R.split(rotary_embedding(fused_qkv)) - @T.prim_func + Implementation is generic over the number of query heads, + key/value heads, sequence length, head dimension, and position + embedding base. These parameters can be replaced with static + values using `PrimFunc.specialize`. + """ + + @T.prim_func(private=True) def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, + fused_qkv_handle: T.handle, + embedded_query_handle: T.handle, + embedded_key_handle: T.handle, + value_handle: T.handle, + rotary_offset: T.int64, + batch_size: T.int64, + seq_len: T.int64, + num_query_heads: T.int64, + num_kv_heads: T.int64, + head_dim: T.int64, + position_embedding_base: T.float32, ): - A = T.match_buffer(qkv, [1, 1, hidden_size * 3], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, hidden_size], dtype="float16") + Fused_QKV = T.match_buffer( + fused_qkv_handle, + [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], + dtype="float16", + ) + EmbeddedQuery = T.match_buffer( + embedded_query_handle, + [batch_size, seq_len, num_query_heads, head_dim], + dtype="float16", + ) + EmbeddedKey = T.match_buffer( + embedded_key_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) + Value = T.match_buffer( + value_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)], - ) - T.writes( - T_split[v_ax0, v_ax1, v_ax2], - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) + + for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim): + with T.block("FusedRotaryEmbeddingAndSplitQKV"): + batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", iters) + pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) + inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), + position_embedding_base, + 
T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), ) freq: T.float32 = pos * inv_freq cos_value: T.float16 = T.Cast("float16", T.cos(freq)) sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) + T.int64(head_dim // 2)] + + input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] + embedded_value = cos_value * input_value + sin_value * T.Select( + head_i < T.int64(head_dim // 2), + Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] * T.float16(-1), + Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)] + if head_num < num_query_heads: + EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value + elif head_num < num_query_heads + num_kv_heads: + EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value + else: + Value[ + batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i + ] = input_value + + param_sinfo = [] + for param in split_rotary.params: + if param in split_rotary.buffer_map: + buf = split_rotary.buffer_map[param] + sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) + else: + sinfo = relax.PrimStructInfo(param.dtype) + param_sinfo.append(sinfo) + + relax.expr._update_struct_info( + split_rotary, + tvm.relax.FuncStructInfo( + params=param_sinfo, + ret=relax.TupleStructInfo([]), + purity=False, + ), + ) return split_rotary -def get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base +def fuse_split_rotary_embedding( + num_query_heads, num_kv_heads, hidden_size, position_embedding_base ): - query_hidden_size = num_query_heads * head_dim - kv_hidden_size = num_kv_heads * head_dim - total_size = query_hidden_size + kv_hidden_size * 2 + @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") + def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + head_dim = hidden_size // num_query_heads + split_rotary = get_dynamic_split_rotary() - @T.prim_func - def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, - ): - A = T.match_buffer(qkv, [1, 1, total_size], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, query_hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, kv_hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, kv_hidden_size], dtype="float16") + ( + dyn_batch_size, + dyn_seq_len, + dyn_num_query_heads, + dyn_num_kv_heads, + dyn_head_dim, + dyn_position_embedding_base, + ) = split_rotary.params[-6:] - T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(query_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - ) - T.writes(T_split[v_ax0, v_ax1, v_ax2]) - pos: T.float32 = T.Cast("float32", n - 
T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(kv_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size)], - ) - T.writes( - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) + T.int64(head_dim // 2)] - * T.float16(-1), - ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size) - ] + split_rotary = split_rotary.specialize( + { + # Static model parameters + dyn_batch_size: T.int64(1), + dyn_num_query_heads: T.int64(num_query_heads), + dyn_num_kv_heads: T.int64(num_kv_heads), + dyn_head_dim: T.int64(head_dim), + dyn_position_embedding_base: T.float32(position_embedding_base), + # Dynamic parameters, to be inferred from TIR Buffer shapes + dyn_seq_len: tvm.tir.Var("query_sequence_length", "int64"), + } + ) - return split_rotary + mod["split_rotary"] = split_rotary + split_rotary_gvar = mod.get_global_var("split_rotary") + relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) -def fuse_split_rotary_embedding( - mod, num_query_heads, num_kv_heads, hidden_size, position_embedding_base -): - head_dim = hidden_size // num_query_heads - mod["split_rotary"] = ( - get_split_rotary(num_query_heads, head_dim, position_embedding_base) - if num_query_heads == num_kv_heads - else get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base - ) - ) + with PatternContext() as ctx: + # flat_qkv_tuple: R.Tuple( + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2) + # + # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] + # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_query, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] + # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = 
R.reshape( + # flat_key, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] + # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_value, R.shape([batch_size, seq_len, 32, 128]) + # ) + # embedded_query = R.call_tir( + # cls.rotary_embedding1, + # [query], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) + # embedded_key = R.call_tir( + # cls.rotary_embedding1, + # [key], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) - gvar = mod.get_global_var("split_rotary") - relax.expr._update_struct_info(gvar, mod.get_global_var("rotary_embedding1").struct_info) + pat_rotary_embedding_gvar = GlobalVarPattern() - with PatternContext() as ctx: - # lv3: R.Tuple(R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")) = R.split(lv2, indices_or_sections=[4096, 8192], axis=2) + pat_flat_fused_qkv = wildcard() + pat_offset = wildcard() - # lv1521: R.Tensor((1, 1, 4096), dtype="float16") = lv3[0] - # lv1522: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1521, R.shape([1, 1, 32, 128])) - # lv1524: R.Tensor((1, 1, 4096), dtype="float16") = lv3[1] - # lv1525: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1524, R.shape([1, 1, 32, 128])) - # lv1527: R.Tensor((1, 1, 4096), dtype="float16") = lv3[2] - # lv1528: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1527, R.shape([1, 1, 32, 128])) - # lv1530 = R.call_tir(cls.rotary_embedding1, (lv1525, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape([n])) - # lv_1 = R.call_tir(cls.rotary_embedding1, (lv1522, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape( + # query_shape = is_shape([1, seq_len, num_query_heads, head_dim]) + pat_query_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_key_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_value_shape = wildcard() - inp_pat = wildcard() - offset = wildcard() + pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) + pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) + pat_query = is_op("relax.reshape")( + pat_flat_query, pat_query_shape, add_constraint=False + ) + pat_flat_query.used_by(pat_query) + pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) + pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) + pat_flat_key.used_by(pat_key) + pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) + pat_value = is_op("relax.reshape")( + pat_flat_value, pat_value_shape, add_constraint=False + ) + pat_flat_value.used_by(pat_value) - lv3 = is_op("relax.split")(inp_pat) - lv1521 = is_tuple_get_item(lv3, 0) - lv1522 = is_op("relax.reshape")( - lv1521, is_shape([1, 1, num_query_heads, head_dim]), add_constraint=False - ) - lv1521.used_by(lv1522) - lv1524 = is_tuple_get_item(lv3, 1) - lv1525 = is_op("relax.reshape")( - lv1524, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1524.used_by(lv1525) - lv1527 = is_tuple_get_item(lv3, 2) - V = is_op("relax.reshape")( - lv1527, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1527.used_by(V) + pat_embedded_query = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_query]), 
+ pat_offset, + add_constraint=False, + ) + pat_embedded_key = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_key]), + pat_offset, + add_constraint=False, + ) - Q = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1522]), offset, add_constraint=False - ) - K = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1525]), offset, add_constraint=False - ) + pat_flat_qkv_tuple.used_by(pat_flat_query) + pat_flat_qkv_tuple.used_by(pat_flat_key) + pat_flat_qkv_tuple.used_by(pat_flat_value) + pat_query.used_by(pat_embedded_query) + pat_key.used_by(pat_embedded_key) - lv3.used_by(lv1521) - lv3.used_by(lv1524) - lv3.used_by(lv1527) - lv1522.used_by(Q) - lv1525.used_by(K) - - def rewriter(matchings, bindings): - inp = matchings[inp_pat] - call_tir = matchings[Q] - n = bindings[call_tir].args[-1] - out_sinfo = [ - R.Tensor((1, 1, num_query_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - ] - lv3_new = R.call_tir( - mod.get_global_var("split_rotary"), (inp,), out_sinfo=out_sinfo, tir_vars=n - ) - lv1521_new = lv3_new[0] - lv1522_new = R.reshape(lv1521_new, R.shape([1, 1, num_query_heads, head_dim])) - lv1524_new = lv3_new[1] - lv1525_new = R.reshape(lv1524_new, R.shape([1, 1, num_kv_heads, head_dim])) - lv1527_new = lv3_new[2] - lv1528_new = R.reshape(lv1527_new, R.shape([1, 1, num_kv_heads, head_dim])) - - return { - matchings[lv3]: lv3_new, - matchings[lv1521]: lv1521_new, - matchings[lv1522]: lv1522_new, - matchings[lv1524]: lv1524_new, - matchings[lv1525]: lv1525_new, - matchings[lv1527]: lv1527_new, - matchings[V]: lv1528_new, - matchings[Q]: lv1522_new, - matchings[K]: lv1525_new, - } - - mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"]) - return mod + def rewriter(matchings, bindings): + # Extracting all the relax and TIR variables that we'll need + flat_fused_qkv = matchings[pat_flat_fused_qkv] + flat_qkv_tuple = matchings[pat_flat_qkv_tuple] + + flat_query = matchings[pat_flat_query] + flat_key = matchings[pat_flat_key] + flat_value = matchings[pat_flat_value] + + query = matchings[pat_query] + key = matchings[pat_key] + value = matchings[pat_value] + + embedded_query = matchings[pat_embedded_query] + embedded_key = matchings[pat_embedded_key] + + # rotary_embedding_offset = bindings[query].args[-1][1] + rotary_embedding_offset = bindings[embedded_query].args[-1][0] + + batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape + _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape + + # Rewriting along the new path + + fused_qkv = relax.op.reshape( + flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] + ) + + split_rotary_sinfo = [ + R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + ] + qkv_tuple_new = R.call_tir( + split_rotary_gvar, + (fused_qkv,), + out_sinfo=split_rotary_sinfo, + tir_vars=[rotary_embedding_offset], + ) + + embedded_query_new = qkv_tuple_new[0] + embedded_key_new = qkv_tuple_new[1] + value_new = qkv_tuple_new[2] + + return { + value: value_new, + embedded_query: embedded_query_new, + embedded_key: embedded_key_new, + } + + new_mod = {} + for gvar, func in mod.functions.items(): + if isinstance(func, relax.Function): + func = rewrite_bindings(ctx, rewriter, 
func) + new_mod[gvar] = func + + new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) + return new_mod + + return ir_module_pass diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py index b6d2a493ab..d6d5693762 100644 --- a/mlc_llm/transform/rewrite_attention.py +++ b/mlc_llm/transform/rewrite_attention.py @@ -1,35 +1,46 @@ +import tvm from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard from tvm.script import relax as R -def rewrite_attention(f, use_flash_mqa=False): - Q = wildcard() - K = wildcard() - V = wildcard() +def rewrite_attention(use_flash_mqa=False): + @tvm.ir.transform.module_pass(opt_level=0, name="mlc_llm.transform.rewrite_attention") + def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule: + Q = wildcard() + K = wildcard() + V = wildcard() - Q_BNSH = is_op("relax.permute_dims")(Q) + Q_BNSH = is_op("relax.permute_dims")(Q) - if use_flash_mqa: - K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K)) - V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V)) - else: - K_BNSH = is_op("relax.permute_dims")(K) - V_BNSH = is_op("relax.permute_dims")(V) + if use_flash_mqa: + K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K)) + V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V)) + else: + K_BNSH = is_op("relax.permute_dims")(K) + V_BNSH = is_op("relax.permute_dims")(V) - K_BNSH_T = is_op("relax.permute_dims")(K_BNSH) + K_BNSH_T = is_op("relax.permute_dims")(K_BNSH) - matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T) - divide = is_op("relax.divide")(matmul1, is_const()) - max = is_op("relax.maximum")(divide, is_const()) - min = is_op("relax.minimum")(max, wildcard()) - softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min)) - matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH) + matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T) + divide = is_op("relax.divide")(matmul1, is_const()) + max = is_op("relax.maximum")(divide, is_const()) + min = is_op("relax.minimum")(max, wildcard()) + softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min)) + matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH) - pattern = is_op("relax.permute_dims")(matmul2) + pattern = is_op("relax.permute_dims")(matmul2) - def callback(_, matchings): - return R.nn.attention( - matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight" - ) + def callback(_, matchings): + return R.nn.attention( + matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight" + ) - return rewrite_call(pattern, callback, f) + new_module = {} + for gvar, func in mod.functions.items(): + if isinstance(func, tvm.relax.Function): + func = rewrite_call(pattern, callback, func) + new_module[gvar] = func + + return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos) + + return ir_module_transform diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 17329c19d4..1bcf1e8816 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -1,22 +1,44 @@ # pylint: disable=missing-docstring,invalid-name import argparse +import functools import json +import math import os import shutil from typing import Any, Dict, List, Optional, Set +import numpy as np + import tvm from tvm import relax from .quantization import quantization_schemes from .relax_model import param_manager -from .transform import ReorderTransformFunc + supported_model_types = set( - ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", 
"chatglm", "mistral"] + ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral", "stablelm_epoch"] ) +def wrap_tqdm_counter(func, **tqdm_kwargs): + # tqdm isn't a hard requirement, so return the original function + # if it isn't available. + try: + from tqdm import tqdm + except ImportError: + return func + + pbar = tqdm(**tqdm_kwargs) + + @functools.wraps(func) + def inner(*args, **kwargs): + pbar.update(1) + return func(*args, **kwargs) + + return inner + + def argparse_postproc_common(args: argparse.Namespace) -> None: if hasattr(args, "device_name"): if args.device_name == "auto": @@ -64,6 +86,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "codellama": "codellama_completion", "vicuna-": "vicuna_v1.1", "dolly-": "dolly", + "stablelm-3b-": "stablelm-3b", "stablelm-": "stablelm", "redpajama-": "redpajama_chat", "minigpt": "minigpt", @@ -191,31 +214,18 @@ def convert_weights( model_params: List[Optional[tvm.nd.NDArray]], args: argparse.Namespace, ): - # Run pre-quantization if provided. - if param_mgr.f_run_prequantize is not None: - args.model_path = param_mgr.f_run_prequantize(args.model_path) - param_mgr.model_path = args.model_path - param_mgr.torch_pname2binname = ( - param_manager.load_torch_pname2binname_map( - args.model_path, - args.use_safetensors, - set(param_mgr.pidx2pname.values()), - param_mgr.f_convert_pname_fwd, - ) - if len(param_mgr.pidx2pname) != 0 - else dict() - ) - # Create the quantization function. # We first create an initial one, then reorder it according to each # weight's location in the binary files, in the purpose of reducing # memory usage when loading torch weights as well as acceleration. - mod_transform = param_manager.create_quantize_func(param_mgr) - mod_transform = ReorderTransformFunc( - param_mgr.pidx2pname, - param_mgr.torch_pname2binname, - param_mgr.f_convert_pname_fwd, - )(mod_transform) + mod_transform = param_mgr.create_parameter_transformation() + + # Save the number of parameters before we lower mod_transform, so + # we can use them in the progress bar. + transform_func = mod_transform["transform_params"] + num_original_params = len(transform_func.params[0].struct_info.fields) + num_transformed_params = len(transform_func.struct_info.ret.fields) + # Remove the dataflow block inside the param transform function, # so that the LazyTransformParams pass can be applied. 
mod_transform = relax.transform.ToNonDataflow()(mod_transform) @@ -245,6 +255,14 @@ def convert_weights( device, device_cpu, ) + + get_item = wrap_tqdm_counter( + get_item, desc="Get old param", position=0, unit="tensors", total=num_original_params + ) + set_item = wrap_tqdm_counter( + set_item, desc="Set new param", position=1, unit="tensors", total=num_transformed_params + ) + tvm.register_func(func_name="get_item", f=get_item, override=True) tvm.register_func(func_name="set_item", f=set_item, override=True) @@ -268,11 +286,12 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str) -> None: meta_data["ParamSize"] = len(params) total_size = 0.0 for i, nd in enumerate(params): + assert nd is not None, f"Missing parameter at index {i}" param_dict[f"param_{i}"] = nd - np_nd = nd.numpy() - total_size += np_nd.size * np_nd.dtype.itemsize - total_size = total_size / 1024.0 / 1024.0 / 1024.0 - print(f"Total param size: {total_size} GB") + + total_size_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params) + total_size_gb = total_size_bytes / (1024 ** 3) + print(f"Total param size: {total_size_gb} GB") tvmjs.dump_ndarray_cache( param_dict, f"{artifact_path}/params", meta_data=meta_data, encode_format="raw" ) diff --git a/pyproject.toml b/pyproject.toml index 2310e9aa60..ccf754554f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,4 +19,18 @@ profile = "black" [tool.black] line-length = 100 -target-version = ['py310'] + +[tool.mypy] +ignore_missing_imports = true +show_column_numbers = true +show_error_context = true +follow_imports = "skip" +ignore_errors = false +strict_optional = false +install_types = true + +[tool.pylint.messages_control] +max-line-length = 100 +disable = """ +duplicate-code, +""" diff --git a/python/mlc_chat/__init__.py b/python/mlc_chat/__init__.py index 5d55de875f..756c785bf6 100644 --- a/python/mlc_chat/__init__.py +++ b/python/mlc_chat/__init__.py @@ -2,7 +2,5 @@ MLC Chat is the app runtime of MLC LLM. """ +from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ -from .chat_module import ChatModule -from .chat_module import ConvConfig -from .chat_module import ChatConfig diff --git a/python/mlc_chat/base.py b/python/mlc_chat/base.py index e8393eecf7..8980330977 100644 --- a/python/mlc_chat/base.py +++ b/python/mlc_chat/base.py @@ -1,5 +1,4 @@ """Load MLC LLM library and _ffi_api functions.""" - import ctypes import os import sys @@ -15,7 +14,9 @@ def _load_mlc_llm_lib(): if sys.platform.startswith("win32") and sys.version_info >= (3, 8): for path in libinfo.get_dll_directories(): os.add_dll_directory(path) + # pylint: disable=protected-access lib_name = "mlc_llm" if tvm._ffi.base._RUNTIME_ONLY else "mlc_llm_module" + # pylint: enable=protected-access lib_path = libinfo.find_lib_path(lib_name, optional=False) return ctypes.CDLL(lib_path[0]), lib_path[0] @@ -46,6 +47,7 @@ def get_delta_message(curr_message: str, new_message: str) -> str: def set_global_random_seed(seed): + """Set global random seed for python, numpy, torch and tvm.""" if "numpy" in sys.modules: sys.modules["numpy"].random.seed(seed) if "torch" in sys.modules: diff --git a/python/mlc_chat/callback.py b/python/mlc_chat/callback.py index faf2dbd953..0ef3fe580b 100644 --- a/python/mlc_chat/callback.py +++ b/python/mlc_chat/callback.py @@ -1,5 +1,7 @@ """Namespace of callback functions in Python API.""" -#! 
pylint: disable=unused-import, invalid-name, unnecessary-pass +# pylint: disable=unused-import, invalid-name, unnecessary-pass +from queue import Queue +from typing import Optional from .base import get_delta_message @@ -74,3 +76,46 @@ def delta_callback(self, delta_message: str): def stopped_callback(self): r"""Stream an additional '\n' when generation ends.""" print() + + +class StreamIterator(DeltaCallback): + """Stream the output using an iterator. + A queue stores the delta messages""" + + def __init__(self, callback_interval: int = 2, timeout: Optional[float] = None): + r"""Initialize the callback class with callback interval and queue timeout. + + Parameters + ---------- + callback_interval : int + The refresh rate of the streaming process. + timeout : Optional[float] + Timeout for put and get from the delta messages queue + """ + super().__init__() + self.delta_messages: Queue[str] = Queue() + self.callback_interval = callback_interval + self.timeout = timeout + + def delta_callback(self, delta_message: str): + r"""Stream the delta message to iterator (adding). + + Parameters + ---------- + delta_message : str + The delta message (the part that has not been added to queue yet). + """ + self.delta_messages.put(delta_message, timeout=self.timeout) + + def stopped_callback(self): + """Using None as the stop signal for the iterator""" + self.delta_messages.put(None, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.delta_messages.get(timeout=self.timeout) + if value: + return value + raise StopIteration() diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 9e35224801..058557c182 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -1,18 +1,19 @@ """The Python API for MLC chat.""" -#! pylint: disable=unused-import, invalid-name +#! pylint: disable=too-many-lines import inspect import json import logging import os import sys +import warnings from dataclasses import asdict, dataclass, fields from enum import Enum -from typing import List, Optional +from typing import List, Optional, Tuple, Union import tvm -from tvm.runtime import disco +from tvm.runtime import disco # pylint: disable=unused-import -from . import callback +from .interface.openai_api import ChatMessage # pylint: disable=line-too-long _PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb" @@ -20,7 +21,7 @@ @dataclass -class ConvConfig: +class ConvConfig: # pylint: disable=too-many-instance-attributes r"""A dataclass that represents user-defined partial configuration for conversation template. This is an attribute of :class:`mlc_chat.ChatConfig`, which can then be passed in to the @@ -82,7 +83,7 @@ def __post_init__(self): @dataclass -class ChatConfig: +class ChatConfig: # pylint: disable=too-many-instance-attributes r"""A dataclass that represents user-defined partial configuration for the chat config file. @@ -90,7 +91,7 @@ class ChatConfig: :class:`mlc_chat.ChatModule` instance to override the default setting in ``mlc-chat-config.json`` under the model folder. - Since the configuraiton is partial, everything will be ``Optional``. + Since the configuration is partial, everything will be ``Optional``. Note that we will exploit this class to also represent ``mlc-chat-config.json`` during intermediate processing. 
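For illustration, one way the new StreamIterator callback added to python/mlc_chat/callback.py above can be consumed (a sketch, not part of the patch; the model name is a placeholder for any locally compiled model):

    from threading import Thread

    from mlc_chat import ChatModule
    from mlc_chat.callback import StreamIterator

    cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")  # placeholder model
    stream = StreamIterator(callback_interval=2)
    generation = Thread(
        target=cm.generate,
        kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream},
    )
    generation.start()
    # Delta messages are pulled off the queue until stopped_callback() enqueues None.
    for delta in stream:
        print(delta, end="", flush=True)
    generation.join()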
@@ -130,14 +131,19 @@ class ChatConfig: For additional information on top-p sampling, please refer to this blog post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. mean_gen_len : Optional[int] + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. max_gen_len : Optional[int] + The maximum number of tokens to be generated in each round. Would simply + stop generating after this number is exceeded. shift_fill_factor : Optional[float] + The fraction of maximum window size to shift when it is exceeded. tokenizer_files : Optional[List[str]] List of tokenizer files of the model. conv_config : Optional[ConvConfig] The partial overriding configuration for conversation template. Will first load the predefined template with the name specified in ``conv_template`` - and then override some of the configuraitons specified in ``conv_config``. + and then override some of the configurations specified in ``conv_config``. model_category : Optional[str] The category of the model's architecture (e.g. ``llama``, ``gpt_neox``, ``rwkv``). model_name : Optional[str] @@ -165,12 +171,94 @@ class ChatConfig: max_window_size: Optional[int] = None @classmethod - def _from_json(chat_config_cls, json_obj: dict): - return chat_config_cls( + def _from_json(cls, json_obj: dict): + return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters}) + + +@dataclass +class GenerationConfig: # pylint: disable=too-many-instance-attributes + r"""A dataclass that represents user-defined generation configuration. + + An instance of ``GenerationConfig`` can be passed in to the generate function + of a :class:`mlc_chat.ChatModule` instance to override the default generation + setting in ``mlc-chat-config.json`` and ``ChatConfig`` under the model folder. + + Once the generation ends, ``GenerationConfig`` is discarded, since the values + will only override the ``ChatConfig`` generation settings during one generation, + unless it is recurrently passed to the generate function. This allows changing generation + settings over time, without overriding ``ChatConfig`` permanently. + + Since the configuration is partial, everything will be ``Optional``. + + Parameters + ---------- + temperature : Optional[float] + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. + presence_penalty : Optional[float] + Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood + to talk about new topics. Negative values can increase the likelihood of + repetition. + frequency_penalty : Optional[float] + Number between -2.0 and 2.0. Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. Negative values can increase the likelihood of + repetition. + repetition_penalty : Optional[float] + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. 
The ideal choice of repetition penalty may vary among models. Only + active when presence_penalty and frequency_penalty are both 0.0. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). + top_p : Optional[float] + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. + mean_gen_len : Optional[int] + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. + max_gen_len : Optional[int] + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. + n : Optional[int] + This parameter determines the number of text samples to generate. The default + value is ``1``. Note that this parameter is only used when ``stream`` is set to + ``False``. + stop : Optional[Union[str, List[str]]] + When ``stop`` is encountered, the model will stop generating output. + It can be a string or a list of strings. If it is a list of strings, the model + will stop generating output when any of the strings in the list is encountered. + Note that this parameter does not override the default stop string of the model. + """ + + temperature: Optional[float] = None + repetition_penalty: Optional[float] = None + top_p: Optional[float] = None + mean_gen_len: Optional[int] = None + max_gen_len: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + n: Optional[int] = None # pylint: disable=invalid-name + stop: Optional[Union[str, List[str]]] = None + + @classmethod + def _from_chat_config(cls, chat_config_obj: ChatConfig): + return cls( **{ - k: v - for k, v in json_obj.items() - if k in inspect.signature(chat_config_cls).parameters + f.name: getattr(chat_config_obj, f.name) + for f in fields(chat_config_obj) + if f.name in inspect.signature(cls).parameters } ) @@ -178,20 +266,21 @@ def _from_json(chat_config_cls, json_obj: dict): class PlaceInPrompt(Enum): """The place of an input message in a prompt.""" - # The input message should have role names and corresponding seperators appended both prior to it and after it, - # making it a complete prompt. - All = 0 - # The input message is only the beginning part of a prompt, no role name and separator should be appended after - # the message since there will be future messages appended after the message. - Begin = 1 - # The input message is in the middle of a prompt, nothing should be appended before or after the message. - Middle = 2 - # The input message is the ending part of a prompt, no role name and separator should be appended prior to it - # since the message is concatenated to some prior messages. - End = 3 - - -def _get_model_path(model: str) -> (str, str): + # The input message should have role names and corresponding separators appended both prior to + # it and after it, making it a complete prompt. 
+ All = 0 # pylint: disable=invalid-name + # The input message is only the beginning part of a prompt, no role name and separator should + # be appended after the message since there will be future messages appended after the message. + Begin = 1 # pylint: disable=invalid-name + # The input message is in the middle of a prompt, nothing should be appended before or after + # the message. + Middle = 2 # pylint: disable=invalid-name + # The input message is the ending part of a prompt, no role name and separator should be + # appended prior to it since the message is concatenated to some prior messages. + End = 3 # pylint: disable=invalid-name + + +def _get_model_path(model: str) -> Tuple[str, str]: """Use user-provided argument ``model`` to search for a valid model path. We define "valid" as having an ``mlc-chat-config.json`` right under the folder. @@ -225,8 +314,8 @@ def _get_model_path(model: str) -> (str, str): for candidate in candidate_paths: chat_file = os.path.join(candidate, "mlc-chat-config.json") if os.path.isfile(chat_file): - logging.info(f"Using model folder: {os.path.abspath(candidate)}") - logging.info(f"Using mlc chat config: {os.path.abspath(chat_file)}") + logging.info("Using model folder: %s", os.path.abspath(candidate)) + logging.info("Using mlc chat config: %s", os.path.abspath(chat_file)) return candidate, chat_file # Failed to find a valid model_path, analyzing error for user @@ -241,7 +330,7 @@ def _get_model_path(model: str) -> (str, str): if found_folder: # Error 1: there is a folder, but not an mlc-llm model folder (E1) - err_msg = ( + raise FileNotFoundError( "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n" "Specifically, we cannot find `mlc-chat-config.json`, a required file. You should " "provide a path that contains the file.\n" @@ -251,21 +340,16 @@ def _get_model_path(model: str) -> (str, str): f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on " "how to load a model." ) - raise FileNotFoundError(err_msg) - else: - # Error 2: cannot find a folder (E0) - all_paths_str = "" - for path in candidate_paths: - all_paths_str += f"- {path}\n" - err_msg = ( - "Cannot find the model folder. We searched over the following possible paths:\n" - f"{all_paths_str}" - "You can try to pass in `model=/path/to/your-model-path`, and confirm " - "that it contains `mlc-chat-config.json`, among other essential files.\n" - f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an " - "example on how to load a model." - ) - raise FileNotFoundError(err_msg) + # Error 2: cannot find a folder (E0) + all_paths_str = "".join(f"- {path}\n" for path in candidate_paths) + raise FileNotFoundError( + "Cannot find the model folder. We searched over the following possible paths:\n" + f"{all_paths_str}" + "You can try to pass in `model=/path/to/your-model-path`, and confirm " + "that it contains `mlc-chat-config.json`, among other essential files.\n" + f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an " + "example on how to load a model." + ) def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfig]) -> ChatConfig: @@ -284,24 +368,63 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi ``ChatConfig`` corresponding to ``config_file_path``, overriden by ``user_chat_config``. 
""" final_chat_config = None - with open(config_file_path, mode="rt", encoding="utf-8") as f: - json_object = json.load(f) - final_chat_config = ChatConfig._from_json(json_object) + with open(config_file_path, mode="rt", encoding="utf-8") as file: + json_object = json.load(file) + final_chat_config = ChatConfig._from_json(json_object) # pylint: disable=protected-access if user_chat_config is not None: # We override using user's chat config for field in fields(user_chat_config): field_name = field.name field_value = getattr(user_chat_config, field_name) if field_value is not None: - setattr(final_chat_config, field_name, field_value) + if field_name == "model_lib": + warn_msg = ( + 'WARNING: Do not override "model_lib" in ChatConfig. ' + "This override will be ignored. Please use ChatModule.model_lib_path to " + "override the full model library path instead." + ) + warnings.warn(warn_msg) + else: + setattr(final_chat_config, field_name, field_value) return final_chat_config -def _get_lib_module_path( +def _get_generation_config( + user_chat_config: ChatConfig, user_generation_config: Optional[GenerationConfig] +) -> GenerationConfig: + """Read in the config file in model path, then potentially override with user input. + + Parameters + ---------- + user_chat_config : ChatConfig + ``ChatConfig`` that contain the generation settings to be overriden. + user_generation_config : Optional[GenerationConfig] + User's input, a partial ``GenerationConfig`` to override the ``ChatConfig``. + + Returns + ------ + final_generation_config : GenerationConfig + ``GenerationConfig`` corresponding to ``user_chat_config``, overriden by + ``user_generation_config``. + """ + # pylint: disable=protected-access + final_generation_config = GenerationConfig._from_chat_config(user_chat_config) + # pylint: enable=protected-access + if user_generation_config is not None: + # We override using user's chat config + for field in fields(user_generation_config): + field_name = field.name + field_value = getattr(user_generation_config, field_name) + if field_value is not None: + setattr(final_generation_config, field_name, field_value) + return final_generation_config + + +def _get_lib_module_path( # pylint: disable=too-many-arguments model: str, model_path: str, chat_config: ChatConfig, - lib_path: Optional[str], + model_lib_path: Optional[str], device_name: str, config_file_path: str, ) -> str: @@ -315,7 +438,7 @@ def _get_lib_module_path( Model path found by `_get_model_path`. chat_config : ChatConfig Chat config after potential overrides. Returned by ``_get_chat_config``. - lib_path : Optional[str] + model_lib_path : Optional[str] User's input. Supposedly a full path to model library. Prioritized to use. device_name : str User's input. Used to construct the library model file name. @@ -324,24 +447,22 @@ def _get_lib_module_path( Returns ------ - lib_path : str + model_lib_path : str The path pointing to the model library we find. Raises ------ FileNotFoundError: if we cannot find a valid model library file. """ - # 1. Use user's lib_path if provided - if lib_path is not None: - if os.path.isfile(lib_path): - logging.info(f"Using library model: {lib_path}") - return lib_path - else: - err_msg = ( - f"The `lib_path` you passed in is not a file: {lib_path}.\nPlease checkout " - f"{_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on how to load a model." - ) - raise FileNotFoundError(err_msg) + # 1. 
Use user's model_lib_path if provided + if model_lib_path is not None: + if os.path.isfile(model_lib_path): + logging.info("Using library model: %s", model_lib_path) + return model_lib_path + raise FileNotFoundError( + f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\n" + f"Please refer to {_PYTHON_GET_STARTED_TUTORIAL_URL} as tutorial on model loading." + ) # 2. Generate all possible file names according to OS candidate_lib_names = [] @@ -380,7 +501,7 @@ def _get_lib_module_path( # 4. Search for model library for candidate in candidate_paths: if os.path.isfile(candidate): - logging.info(f"Using library model: {os.path.abspath(candidate)}\n") + logging.info("Using library model: %s", os.path.abspath(candidate)) return candidate # 5. Error @@ -394,21 +515,23 @@ def _get_lib_module_path( err_msg += f"- {candidate}\n" err_msg += ( "If you would like to directly specify the model library path, you may " - "consider passing in the `lib_path` parameter.\n" + "consider passing in the `ChatModule.model_lib_path` parameter.\n" f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example " "on how to load a model." ) raise FileNotFoundError(err_msg) -def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_template: str) -> str: +def _convert_chat_config_to_json_str( + chat_config: Optional[ChatConfig], conv_template: Optional[str] +) -> str: """Convert user's input ChatConfig to a json string, omitting ``None`` fields. Parameters ---------- chat_config : Optional[ChatConfig] User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``. - conv_template : str + conv_template : Optional[str] The ``conv_template`` that will be used after considering potential override. Returns @@ -425,23 +548,91 @@ def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_tem # Only want to keep entries that are not None; otherwise, we would override things to None assert hasattr(ChatConfig, "conv_config") # in case dataclass attribute name changes chat_dict = {} - for k, v in asdict(chat_config).items(): - if k == "conv_config" and v is not None: + for key, value in asdict(chat_config).items(): + if key == "conv_config" and value is not None: # conv template is another dict, do the same thing conv_dict = {} - for conv_k, conv_v in v.items(): + for conv_k, conv_v in value.items(): if conv_v is not None: conv_dict[conv_k] = conv_v - chat_dict[k] = conv_dict + chat_dict[key] = conv_dict continue - - if v is not None: - chat_dict[k] = v + if value is not None: + chat_dict[key] = value return json.dumps(chat_dict) -def _detect_local_device(device_id: int = 0): +def _convert_generation_config_to_json_str(generation_config: Optional[GenerationConfig]) -> str: + """Convert user's input GenerationConfig to a json string. + + Parameters + ---------- + generation_config : Optional[GenerationConfig] + User's input. A partial GenerationConfig for overriding ChatConfig generation settings. + + Returns + ------ + json_str : str + A JSON string that corresponds to user's ``generation_config`` input. + Returns "" if ``generation_config`` unspecified. + """ + if generation_config is None: + return "" + return json.dumps(asdict(generation_config)) + + +def _parse_device_str(device: str) -> Tuple[tvm.runtime.Device, str]: + """Parse the input device identifier into device name and id. + + Parameters + ---------- + device : str + The device identifier to parse. + It can be "device_name" (e.g., "cuda") or + "device_name:device_id" (e.g., "cuda:1"). 
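To make the accepted forms concrete, a few illustrative inputs for the parsing helper introduced here (a sketch, not part of the patch; actual availability depends on which TVM runtimes are enabled locally):

    _parse_device_str("cuda")      # -> (tvm.cuda(0), "cuda")
    _parse_device_str("metal:1")   # -> (tvm.metal(1), "metal")
    _parse_device_str("auto")      # -> whatever _detect_local_device(0) finds, e.g. (tvm.vulkan(0), "vulkan")
    _parse_device_str("npu")       # -> raises ValueError (unsupported device name)
    _parse_device_str("cuda:0:0")  # -> raises ValueError (malformed identifier)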
+ + Returns + ------- + dev : tvm.runtime.Device + The device. + + device_name : str + The name of the device. + """ + device_err_msg = ( + f"Invalid device name: {device}. Please enter the device in the form " + "'device_name:device_id' or 'device_name', where 'device_name' needs to be " + "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'." + ) + device_args = device.split(":") + if len(device_args) == 1: + device_name, device_id = device_args[0], 0 + elif len(device_args) == 2: + device_name, device_id = device_args[0], int(device_args[1]) + elif len(device_args) > 2: + raise ValueError(device_err_msg) + + if device_name == "cuda": + device = tvm.cuda(device_id) + elif device_name == "metal": + device = tvm.metal(device_id) + elif device_name == "vulkan": + device = tvm.vulkan(device_id) + elif device_name == "rocm": + device = tvm.rocm(device_id) + elif device_name == "opencl": + device = tvm.opencl(device_id) + elif device_name == "auto": + device, device_name = _detect_local_device(device_id) + logging.info("System automatically detected device: %s", device_name) + else: + raise ValueError(device_err_msg) + + return device, device_name + + +def _detect_local_device(device_id: int = 0) -> Tuple[tvm.runtime.Device, str]: """Automatically detect the local device if user does not specify. Parameters @@ -451,8 +642,11 @@ def _detect_local_device(device_id: int = 0): Returns ------ - dev : Device + dev : tvm.runtime.Device The local device. + + device_name : str + The name of the device. """ if tvm.metal().exist: return tvm.metal(device_id), "metal" @@ -464,14 +658,14 @@ def _detect_local_device(device_id: int = 0): return tvm.vulkan(device_id), "vulkan" if tvm.opencl().exist: return tvm.opencl(device_id), "opencl" - logging.info( - "None of the following device is detected: metal, rocm, cuda, vulkan, opencl. Switch to llvm instead." + "None of the following device is detected: metal, rocm, cuda, vulkan, opencl. " + "Switch to llvm instead." ) return tvm.cpu(device_id), "llvm" -class ChatModule: +class ChatModule: # pylint: disable=too-many-instance-attributes r"""The ChatModule for MLC LLM. Examples @@ -520,7 +714,7 @@ class ChatModule: A ``ChatConfig`` instance partially filled. Will be used to override the ``mlc-chat-config.json``. - lib_path : Optional[str] + model_lib_path : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. @@ -531,42 +725,15 @@ def __init__( model: str, device: str = "auto", chat_config: Optional[ChatConfig] = None, - lib_path: Optional[str] = None, + model_lib_path: Optional[str] = None, ): - device_err_msg = ( - f"Invalid device name: {device}. Please enter the device in the form " - "'device_name:device_id' or 'device_name', where 'device_name' needs to be " - "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'." - ) - - # 0. Retrieve device_name and device_id (if any, default 0) from device arg - device_args = device.split(":") - if len(device_args) == 1: - device_name, device_id = device_args[0], 0 - elif len(device_args) == 2: - device_name, device_id = device_args[0], int(device_args[1]) - elif len(device_args) > 2: - raise ValueError(device_err_msg) - - # 1. 
Get self.device - if device_name == "cuda": - self.device = tvm.cuda(device_id) - elif device_name == "metal": - self.device = tvm.metal(device_id) - elif device_name == "vulkan": - self.device = tvm.vulkan(device_id) - elif device_name == "rocm": - self.device = tvm.rocm(device_id) - elif device_name == "opencl": - self.device = tvm.opencl(device_id) - elif device_name == "auto": - self.device, device_name = _detect_local_device(device_id) - logging.info(f"System automatically detected device: {device_name}") - else: - raise ValueError(device_err_msg) + # 0. Get device: + # Retrieve device_name and device_id (if any, default 0) from device arg + self.device, device_name = _parse_device_str(device) device_type = self.device.device_type + device_id = self.device.device_id - # 2. Populate chat module and their functions + # 1. Populate chat module and their functions fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create") assert fcreate_chat_mod is not None chat_mod = fcreate_chat_mod(device_type, device_id) @@ -592,42 +759,55 @@ def __init__( self._get_role0_func = chat_mod["get_role0"] self._get_role1_func = chat_mod["get_role1"] - # 3. Look up model_path + # 2. Look up model_path self.model_path, self.config_file_path = _get_model_path(model) - # 4. Instantiate chat_config + # 3. Instantiate chat_config self.chat_config = _get_chat_config(self.config_file_path, chat_config) - # 5. Look up model library - self.lib_path = _get_lib_module_path( - model, self.model_path, self.chat_config, lib_path, device_name, self.config_file_path + # 4. Look up model library + self.model_lib_path = _get_lib_module_path( + model, + self.model_path, + self.chat_config, + model_lib_path, + device_name, + self.config_file_path, ) - # 6. Call reload + # 5. Call reload user_chat_config_json_str = _convert_chat_config_to_json_str( self.chat_config, self.chat_config.conv_template ) - self._reload(self.lib_path, self.model_path, user_chat_config_json_str) - - # 7. Save default config values. - self.default_chat_config = asdict(self.chat_config) - if "conv_config" in self.default_chat_config: - self.default_chat_config.pop("conv_config") - self.default_conv_config = json.loads(self._get_config_json())["conv_config"] - - def generate(self, prompt: str, progress_callback=None) -> str: - r"""A high-level method that returns the full response from the chat module given a user prompt. - User can optionally specify which callback method to use upon receiving the response. By default, - no callback will be applied. + self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str) + + def generate( + self, + prompt: Union[str, List[ChatMessage]], + generation_config: Optional[GenerationConfig] = None, + progress_callback=None, + ) -> Union[str, List[str]]: + r"""A high-level method that returns the full response from the chat module given a user + prompt. User can optionally specify which callback method to use upon receiving the + response. By default, no callback will be applied. Parameters ---------- - prompt : str + prompt : Union[str, List[ChatMessage]] The user input prompt, i.e. a question to ask the chat module. + It can also be the whole conversation history (list of messages with role and content) + eg: ```[ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. 
How about you?"), + ChatMessage(role="user", content="I'm good too."), + ]``` + generation_config: Optional[GenerationConfig] + The generation config object to override the ChatConfig generation settings. progress_callback: object - The optional callback method used upon receiving a newly generated message from the chat module. - See `mlc_chat/callback.py` for a full list of available callback classes. Currently, only - streaming to stdout callback method is supported, see `Examples` for more detailed usage. + The optional callback method used upon receiving a newly generated message from the + chat module. See `mlc_chat/callback.py` for a full list of available callback classes. + Currently, only streaming to stdout callback method is supported, see `Examples` for + more detailed usage. Returns ------- @@ -643,31 +823,44 @@ def generate(self, prompt: str, progress_callback=None) -> str: # the chat module streaming to stdout piece by piece, and in the end we receive the # full response as a single string `output`. - from mlc_chat import ChatModule, callback + from mlc_chat import ChatModule, GenerationConfig, callback cm = ChatModule(xxx) prompt = "what's the color of banana?" - output = cm.generate(prompt, callback.StreamToStdout(callback_interval=2)) + output = cm.generate( + prompt, GenerationConfig(temperature=0.8), callback.StreamToStdout(callback_interval=2) + ) print(output) """ - self._prefill(prompt) + new_msgs = [] + num_return_sequences = 1 + return_str = True + if (generation_config is not None) and (generation_config.n is not None): + num_return_sequences = generation_config.n + return_str = False + else: + num_return_sequences = 1 - if not progress_callback: - while not self._stopped(): - self._decode() - new_msg = self._get_message() - return new_msg - - # apply callback with a rate of callback_interval - i, new_msg = 0, "" - while not self._stopped(): - self._decode() - if i % progress_callback.callback_interval == 0 or self._stopped(): - new_msg = self._get_message() - progress_callback(new_msg) - i += 1 - progress_callback(stopped=True) + for _ in range(num_return_sequences): + self.reset_chat() + self._prefill(prompt, generation_config=generation_config) - return new_msg + if not progress_callback: + while not self._stopped(): + self._decode(generation_config=generation_config) + new_msg = self._get_message() + new_msgs.append(new_msg) + else: + # apply callback with a rate of callback_interval + i, new_msg = 0, "" + while not self._stopped(): + self._decode(generation_config=generation_config) + if i % progress_callback.callback_interval == 0 or self._stopped(): + new_msg = self._get_message() + progress_callback(new_msg) + i += 1 + progress_callback(stopped=True) + new_msgs.append(new_msg) + return new_msgs[0] if return_str else new_msgs def reset_chat(self, chat_config: Optional[ChatConfig] = None): r"""Reset the chat session, clear all chat history, and potentially @@ -696,46 +889,7 @@ def reset_chat(self, chat_config: Optional[ChatConfig] = None): # Second argument is `partial_update = True` self._load_json_override_func(user_chat_config_json_str, True) - def update_chat_config(self, new_chat_config: ChatConfig): - r"""Update the chat config, or use the currently used default values if - values are None. - - Parameters - ---------- - chat_config : ChatConfig - A ``ChatConfig`` instance partially filled. The chat module will - override the default values with it. 
- - Note - ---- - This is inteneded for use in the completions api to allow users to specify - config values and use defaults if they are not passed to the request. - """ - - new_chat_config_dict = asdict(new_chat_config) - - # Override chat config values if they are present. Use default values if not. - config_updates_dict = {} - for k, default_value in self.default_chat_config.items(): - new_value = new_chat_config_dict.get(k) - config_updates_dict[k] = new_value if new_value else default_value - - # Add conv_config values if there are ones. - new_conv_config_dict = new_chat_config_dict.get("conv_config") - if new_conv_config_dict: - conv_config_updates_dict = {} - for k, default_value in self.default_conv_config.items(): - new_value = new_conv_config_dict.get(k) - conv_config_updates_dict[k] = new_value if new_value else default_value - config_updates_dict["conv_config"] = conv_config_updates_dict - - # Current logic does not allow partial ChatConfig without specifying the - # conv_template. Hence we use the conv_template after considering potential overrides. - user_chat_config_json_str = json.dumps(config_updates_dict) - # Second argument is `partial_update = True` - self._load_json_override_func(user_chat_config_json_str, True) - - def embed_text(self, input: str): + def embed_text(self, input: str): # pylint: disable=redefined-builtin r"""Given a text input, returns its embedding in the LLM. Parameters @@ -759,7 +913,7 @@ def embed_text(self, input: str): return self._embed_func(input, PlaceInPrompt.Middle.value) def stats(self, verbose=False) -> str: - r"""Get the runtime stats of the encoding step, decoding step, (and embedding step if exists) + r"""Get the runtime stats of the encoding step, decoding step (and embedding step if exists) of the chat module in text form. Returns @@ -769,8 +923,7 @@ def stats(self, verbose=False) -> str: """ if verbose: return self._verbose_runtime_stats_text_func() - else: - return self._runtime_stats_text_func() + return self._runtime_stats_text_func() def benchmark_generate(self, prompt: str, generate_length: int) -> str: r"""Controlled generation with input prompt and fixed number of @@ -844,28 +997,77 @@ def _unload(self): def _prefill( self, - input: str, + input: Union[str, List[ChatMessage]], # pylint: disable=redefined-builtin decode_next_token: bool = True, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, + generation_config: Optional[GenerationConfig] = None, ): r"""Run prefill stage for a given input and optionally decode the first output token. User can decide where to place the input in the prompt. Parameters ---------- - input : str - The user input string. + input : Union[str, List[ChatMessage]] + The user input prompt, i.e. a question to ask the chat module. + It can also be the whole conversation history (list of messages with role and content) + eg: ```[ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ]``` decode_next_token : bool Whether to decode the next token after prefilling. place_in_prompt: PlaceInPrompt The place of the input message in the prompt. See `class PlaceInPrompt` for details. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. 
""" - self._prefill_func(input, decode_next_token, place_in_prompt.value) + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + if isinstance(input, list): + # Populate conversation.messages using load_json_override + if len(input) > 1: + conv_config = json.loads(self._get_config_json())["conv_config"] + messages = [] + role0 = self._get_role_0() + role1 = self._get_role_1() + for _, msg in enumerate(input[:-1]): + role = msg.role + content = msg.content + if role == "user": + messages.append([role0, content]) + elif role == "assistant": + messages.append([role1, content]) + else: + raise ValueError("Only user and assistant roles are supported.") + if not input[-1].role == "user": + raise ValueError("Last message should be from user.") + conv_config["messages"] = messages + conv_config["offset"] = 0 + # Otherwise, the offset will be set to the length of the conversation, + # which means history will be retained even after calling reset_chat + self._load_json_override( + json.dumps({"conv_config": conv_config}), + partial_update=True, + ) + input_str = input[-1].content + else: + input_str = input + + self._prefill_func( + input_str, decode_next_token, place_in_prompt.value, generation_config_str + ) - def _embed(self, input: str, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All): - r"""A more fine-grained embedding API. Given a text input, get the embedding of the tokenized prompt. - User can decide where to place the input in the prompt. This functionality usually aids the subsequent - call to :func:`_prefill_with_embed`. + def _embed( + self, + input: str, # pylint: disable=redefined-builtin + place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, + generation_config: Optional[GenerationConfig] = None, + ): + r"""A more fine-grained embedding API. Given a text input, get the embedding of the + tokenized prompt. User can decide where to place the input in the prompt. This functionality + usually aids the subsequent call to :func:`_prefill_with_embed`. Parameters ---------- @@ -873,15 +1075,25 @@ def _embed(self, input: str, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All) The user input string. place_in_prompt: PlaceInPrompt The place of the input message in the prompt. See `class PlaceInPrompt` for details. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. Returns ------- embedding : tvm.runtime.NDArray The embedding of the text. """ - return self._embed_func(input, place_in_prompt.value) + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) - def _prefill_with_embed(self, embedding: tvm.runtime.NDArray, decode_next_token: bool = True): + return self._embed_func(input, place_in_prompt.value, generation_config_str) + + def _prefill_with_embed( + self, + embedding: tvm.runtime.NDArray, + decode_next_token: bool = True, + generation_config: Optional[GenerationConfig] = None, + ): r"""Given an embedding, run the prefill stage and optionally decode the first output token. Parameters @@ -890,14 +1102,26 @@ def _prefill_with_embed(self, embedding: tvm.runtime.NDArray, decode_next_token: The embedding of user input. decode_next_token : bool Whether to decode the next token after prefilling. 
+ generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. """ - self._prefill_with_embed_func(embedding, decode_next_token) + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) - def _decode(self): + self._prefill_with_embed_func(embedding, decode_next_token, generation_config_str) + + def _decode(self, generation_config: Optional[GenerationConfig] = None): r"""Decode the next token, the decoding result is stored in a buffer and can be retrieved by :func:`get_message`. + + Parameters + ---------- + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. """ - self._decode_func() + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + self._decode_func(generation_config_str) def _stopped(self) -> bool: r"""Check if the stop condition is met for the current round. @@ -940,8 +1164,8 @@ def _load_json_override(self, config_str: str, partial_update: bool = False): config_str : str A json config string that partially specifies some of the options. partial_update : bool - Whether it's a partial update or full update, if set to true, we perform a partial update - on some of the provided options; if set to false, all options must be provided. + Whether it's a partial update or full update. If set to true, we perform a partial + update on some of the provided options; if set to false, all options must be provided. """ self._load_json_override_func(config_str, partial_update) diff --git a/python/mlc_chat/cli/benchmark.py b/python/mlc_chat/cli/benchmark.py index bcbb4eca53..0a4d5d97f3 100644 --- a/python/mlc_chat/cli/benchmark.py +++ b/python/mlc_chat/cli/benchmark.py @@ -1,7 +1,7 @@ """A command line tool for benchmarking a chat model.""" import argparse -from mlc_chat import ChatModule +from mlc_chat import ChatConfig, ChatModule parser = argparse.ArgumentParser(description="Benchmark an MLC LLM ChatModule.") parser.add_argument( @@ -13,6 +13,21 @@ the model folder over possible paths.""", required=True, ) +parser.add_argument( + "--model-lib", + type=str, + help="""The compiled model library. In MLC LLM, an LLM is compiled to a shared or static + library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat, + as the runtime of MLC LLM, depends on the compiled model library to generate tokens. 
+ """, + required=False, +) +parser.add_argument( + "--num-shards", + type=int, + help="Number of GPUs to be used.", + required=False, +) parser.add_argument( "--device", type=str, @@ -40,7 +55,14 @@ def main(): """The main function that runs the benchmarking.""" args = parser.parse_args() - chat_module = ChatModule(model=args.model, device=args.device) + chat_module = ChatModule( + model=args.model, + device=args.device, + chat_config=ChatConfig( + num_shards=args.num_shards, + ), + model_lib_path=args.model_lib, + ) output = chat_module.benchmark_generate(args.prompt, generate_length=args.generate_length) print(f"Generated text:\n{output}\n") print(f"Statistics: {chat_module.stats(verbose=True)}") diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py new file mode 100644 index 0000000000..31b639a68f --- /dev/null +++ b/python/mlc_chat/cli/compile.py @@ -0,0 +1,133 @@ +"""Command line entrypoint of compilation.""" +import argparse +import logging +from pathlib import Path +from typing import Union + +from mlc_chat.compiler import ( # pylint: disable=redefined-builtin + MODELS, + QUANT, + OptimizationFlags, + compile, +) + +from ..support.auto_config import detect_config, detect_model_type +from ..support.auto_target import detect_target_and_host + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +def main(): + """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" + + def _parse_config(path: Union[str, Path]) -> Path: + try: + return detect_config(path) + except ValueError as err: + raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + parent = path.parent + if not parent.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") + return path + + parser = argparse.ArgumentParser("MLC LLM Compiler") + parser.add_argument( + "--config", + type=_parse_config, + required=True, + help="Path to config.json file or to the directory that contains config.json, which is " + "a HuggingFace standard that defines model architecture, for example, " + "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json", + ) + parser.add_argument( + "--quantization", + type=str, + required=True, + choices=list(QUANT.keys()), + help="Quantization format.", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help="Model architecture, for example, llama. If not set, it is inferred " + "from the config.json file.", + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="The GPU device to compile the model to. If not set, it is inferred from locally " + "available GPUs.", + ) + parser.add_argument( + "--host", + type=str, + default="auto", + choices=[ + "auto", + "arm", + "arm64", + "aarch64", + "x86-64", + ], + help="The host CPU ISA to compile the model to. If not set, it is inferred from the " + "local CPU.", + ) + parser.add_argument( + "--opt", + type=OptimizationFlags.from_str, + default="", + help="Optimization flags. MLC LLM maintains a predefined set of optimization flags, " + "denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, " + "and O3 represents extreme optimization that could potentially break the system. 
" + "Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. " + '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"', + ) + parser.add_argument( + "--prefix-symbols", + type=str, + default="", + help='Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". ' + "This is useful when compiling multiple models into a single library to avoid symbol " + "conflicts. Differet from objcopy, this takes no effect for shared library.", + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help="The name of the output file. The suffix determines if the output file is a " + "shared library or a static library. Available suffixes: " + "1) Linux: .so (shared), .tar (static); " + "2) macOS: .dylib (shared), .tar (static); " + "3) Windows: .dll (shared), .tar (static); " + "4) Android, iOS: .tar (static); " + "5) Web: .wasm (web assembly)", + ) + parsed = parser.parse_args() + target, build_func = detect_target_and_host(parsed.device, parsed.host) + parsed.model_type = detect_model_type(parsed.model_type, parsed.config) + compile( + config=parsed.config, + quantization=parsed.quantization, + model_type=parsed.model_type, + target=target, + opt=parsed.opt, + build_func=build_func, + prefix_symbols=parsed.prefix_symbols, + output=parsed.output, + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py new file mode 100644 index 0000000000..4905e8ac91 --- /dev/null +++ b/python/mlc_chat/compiler/__init__.py @@ -0,0 +1,10 @@ +""" +A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency, +but users could optionally import it if they want to use the compiler. +""" +from . 
import compiler_pass +from .compile import CompileArgs, compile # pylint: disable=redefined-builtin +from .flags_optimization import OptimizationFlags +from .model import MODEL_PRESETS, MODELS, Model +from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping +from .quantization import QUANT diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py new file mode 100644 index 0000000000..5b77a94f81 --- /dev/null +++ b/python/mlc_chat/compiler/compile.py @@ -0,0 +1,69 @@ +"""Python entrypoint of compilation.""" +import dataclasses +from io import StringIO +from pathlib import Path +from typing import Callable + +from tvm import IRModule, relax +from tvm.target import Target + +from ..compiler.model import Model +from ..support.style import bold +from .flags_optimization import OptimizationFlags + + +@dataclasses.dataclass +class CompileArgs: # pylint: disable=too-many-instance-attributes + """Arguments to MLC LLM's compiler.""" + + config: Path + quantization: str + model: Model + target: Target + opt: OptimizationFlags + build_func: Callable[[IRModule, "CompileArgs"], None] + prefix_symbols: str + output: Path + + +def _echo_args(args: CompileArgs) -> None: + out = StringIO() + print(f"{bold('Compiling with arguments:')}", file=out) + print(f" {bold('--config'):<25} {args.config}", file=out) + print(f" {bold('--quantization'):<25} {args.quantization}", file=out) + print(f" {bold('--model-type'):<25} {args.model.name}", file=out) + print(f" {bold('--target'):<25} {args.target.export()}", file=out) + print(f" {bold('--opt'):<25} {args.opt}", file=out) + print(f" {bold('--output'):<25} {args.output}", file=out) + print(out.getvalue().rstrip()) + + +def _compile(args: CompileArgs): + model_config = args.model.config.from_file(args.config) + model = args.model.model(model_config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + with args.target: + mod = relax.get_pipeline("mlc_llm")(mod) + mod.show(black_format=False) + for name, param in named_params: + print(f"{name}: {param.shape} {param.dtype}") + + +def compile( # pylint: disable=too-many-arguments,redefined-builtin + config: Path, + quantization, + model_type: Model, + target: Target, + opt: OptimizationFlags, + build_func: Callable[[IRModule, CompileArgs], None], + prefix_symbols: str, + output: Path, +): + """Compile a model given its configuration and quantization format to a specific target.""" + args = CompileArgs( + config, quantization, model_type, target, opt, build_func, prefix_symbols, output + ) + _echo_args(args) + _compile(args) diff --git a/python/mlc_chat/compiler/compiler_pass/__init__.py b/python/mlc_chat/compiler/compiler_pass/__init__.py new file mode 100644 index 0000000000..762ba8c1e0 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/__init__.py @@ -0,0 +1,2 @@ +"""Compiler passes used in MLC LLM.""" +from . 
import pipeline as _pipeline diff --git a/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py b/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py new file mode 100644 index 0000000000..71848ba546 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py @@ -0,0 +1,31 @@ +"""A compiler pass that cleans up undesired TIR attrs.""" +from typing import List + +import tvm +from tvm.ir.module import IRModule + + +@tvm.transform.module_pass(opt_level=0, name="CleanUpTIRAttrs") +class CleanUpTIRAttrs: # pylint: disable=too-few-public-methods + """A compiler pass that cleans up undesired TIR attrs.""" + + def __init__(self, attrs: List[str]): + self.attrs = attrs + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for g_var in list(mod.functions): + func = mod[g_var] + changed = False + for attr in self.attrs: + if func.attrs is not None and attr in func.attrs: + func = func.without_attr(attr) + changed = True + break + if changed: + mod[g_var] = func + return mod diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py new file mode 100644 index 0000000000..0e02f2ae5a --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py @@ -0,0 +1,81 @@ +"""A compiler pass that fuses decode + matmul + elementwise.""" +import tvm +from tvm import IRModule, relax +from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise") +class FuseDecodeMatmulEwise: # pylint: disable=too-few-public-methods + """A compiler pass that fuses decode + matmul + elementwise.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for n_aux_tensor in [1, 2, 3, 4]: + for match_ewise in [0, 1, 2, 6]: + if match_ewise == 6 and n_aux_tensor != 4: + continue + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_matmul", + *_pattern(match_ewise, n_aux_tensor), + ) + ] + )(mod) + mod = relax.transform.FuseTIR()(mod) + return mod + + +def _pattern(match_ewise: int, n_aux_tensor: int): + # pylint: disable=invalid-name + w_scaled = wildcard() + x = wildcard() + w = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([w_scaled] + [wildcard() for _ in range(n_aux_tensor)]), + add_constraint=False, + ) + matmul = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([x, w] + [wildcard() for _ in range(match_ewise)]), + add_constraint=False, + ) + # pylint: enable=invalid-name + annotations = { + "w_scaled": w_scaled, + "x": x, + "w": w, + "matmul": matmul, + } + + def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["w"] + if not isinstance(call, relax.Call): + return False + g_var = call.args[0] + if not isinstance(g_var, relax.GlobalVar): + return False + return g_var.name_hint.startswith("decode") or g_var.name_hint.startswith("fused_decode") + + def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["matmul"] + if not isinstance(call, relax.Call): + return False + g_var = call.args[0] + if not isinstance(g_var, relax.GlobalVar): + return False + return ( + g_var.name_hint.startswith("matmul") + or g_var.name_hint.startswith("fused_matmul") + or g_var.name_hint.startswith("NT_matmul") + or 
g_var.name_hint.startswith("fused_NT_matmul") + ) + + def _check(ctx: relax.transform.PatternCheckContext) -> bool: + return _check_decoding(ctx) and _check_matmul(ctx) + + return matmul, annotations, _check diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py new file mode 100644 index 0000000000..96678fa951 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py @@ -0,0 +1,83 @@ +"""A compiler pass that fuses decode + take.""" +import tvm +from tvm import IRModule, relax, tir +from tvm.relax.dpl.pattern import ( + GlobalVarPattern, + TuplePattern, + is_const, + is_op, + wildcard, +) + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") +class FuseDecodeTake: # pylint: disable=too-few-public-methods + """A compiler pass that fuses decode + take.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for n_aux_tensor in [2, 3]: + for match_tir_vars in [False, True]: + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_take", + *_pattern(n_aux_tensor, match_tir_vars), + ) + ] + )(mod) + mod = relax.transform.FuseTIR()(mod) + for g_var, func in mod.functions.items(): + name = g_var.name_hint + if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)): + mod = tvm.IRModule({"main": func}) + sch = tir.Schedule(mod) + sch.compute_inline("decode") + mod[g_var] = sch.mod["main"] + return mod + + +def _pattern(n_aux_tensor: int, match_tir_vars: bool): + decode = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([wildcard() for _ in range(n_aux_tensor)]), + add_constraint=False, + ) + indices = ~is_const() + if match_tir_vars: + call_tir_args_take = [ + GlobalVarPattern(), + TuplePattern([decode, indices]), + wildcard(), + ] + else: + call_tir_args_take = [ + GlobalVarPattern(), + TuplePattern([decode, indices]), + ] + take = is_op("relax.call_tir")( + *call_tir_args_take, + add_constraint=False, + ) + annotations = { + "take": take, + "decode": decode, + "indices": indices, + } + + def _check(ctx: relax.transform.PatternCheckContext) -> bool: + take = ctx.annotated_expr["take"] + decode = ctx.annotated_expr["decode"] + if not isinstance(decode, relax.expr.Call): + return False + if not isinstance(take.args[0], relax.GlobalVar) or not isinstance( + decode.args[0], relax.GlobalVar + ): + return False + return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint + + return take, annotations, _check diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py new file mode 100644 index 0000000000..99bcb1b602 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py @@ -0,0 +1,109 @@ +"""A compiler pass that fuses transpose + dequantize.""" +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTranspose") +class FuseDecodeTranspose: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + dequantize.""" + + def __init__(self, skip_gemm: bool) -> None: + self.skip_gemm = skip_gemm + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + 
return _DecodeTransposeFuser(mod, skip_gemm=self.skip_gemm).transform() + + +@mutator +class _DecodeTransposeFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__( + self, + mod: IRModule, + skip_gemm: bool, + ): + super().__init__(mod) + self.mod = mod + self.skip_gemm = skip_gemm + + def transform(self) -> IRModule: + """Entry point""" + for g_var, func in self.mod.functions.items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + call = self.visit_expr_post_order(call) + if call.op != tvm.ir.Op.get("relax.matmul"): + return call + # Do not fuse decode-transpose for GeMM + if self.skip_gemm and ( + call.args[0].struct_info.ndim < 2 + or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm) + or call.args[0].struct_info.shape[-2].value != 1 + ): + return call + + matmul_rhs = self.lookup_binding(call.args[1]) + if ( + not isinstance(matmul_rhs, relax.Call) + or matmul_rhs.op != tvm.ir.Op.get("relax.permute_dims") + or matmul_rhs.args[0].struct_info.ndim != 2 + or matmul_rhs.attrs.axes is not None + ): + return call + + transpose_input = self.lookup_binding(matmul_rhs.args[0]) + if ( + not isinstance(transpose_input, relax.Call) + or transpose_input.op != tvm.ir.Op.get("relax.call_tir") + or not transpose_input.args[0].name_hint.startswith("decode") + or not isinstance(transpose_input.struct_info, relax.TensorStructInfo) + ): + return call + + decode_tir_func = self.mod[transpose_input.args[0]] + assert isinstance(decode_tir_func, tir.PrimFunc) + if ( # pylint: disable=too-many-boolean-expressions + len(decode_tir_func.body.block.alloc_buffers) != 1 + or not isinstance(decode_tir_func.body.block.body, tir.SeqStmt) + or len(decode_tir_func.body.block.body) != 2 + or not isinstance(decode_tir_func.body.block.body[1], tir.For) + or not isinstance(decode_tir_func.body.block.body[1].body.body, tir.BlockRealize) + or decode_tir_func.body.block.body[1].body.body.block.name_hint != "T_transpose" + ): + return call + + new_func_buffers = [decode_tir_func.buffer_map[var] for var in decode_tir_func.params] + new_func_buffers[-1] = decode_tir_func.body.block.alloc_buffers[0] + new_func = tir.PrimFunc( + params=new_func_buffers, + body=tir.BlockRealize( + iter_values=[], + predicate=True, + block=tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=decode_tir_func.body.block.body[0], + ), + ), + ) + # Call `renew_defs` for deep-copy to avoid IR node duplication in + # different PrimFuncs of an IRModule. 
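(A usage sketch for this pass as a whole, not part of the patch: like the other `module_pass`-decorated classes in this directory, an instance can be applied directly to an `IRModule`.)

```python
import tvm

from mlc_chat.compiler.compiler_pass.fuse_decode_transpose import FuseDecodeTranspose


def fuse_decode_transpose(mod: tvm.IRModule) -> tvm.IRModule:
    # skip_gemm=True leaves multi-token (GeMM) matmuls untouched, matching the
    # seq_len check in visit_call_ above; skip_gemm=False fuses all of them,
    # which is how the pass is used in the "mlc_llm" pipeline.
    return FuseDecodeTranspose(skip_gemm=True)(mod)
```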
+ new_func = tir.stmt_functor.renew_defs(new_func) + g_var = self.builder_.add_func(new_func, func_name="decode") + decoded_matmul_rhs = self.builder_.emit( + relax.call_tir(g_var, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info) + ) + return relax.op.matmul(call.args[0], decoded_matmul_rhs, out_dtype=call.attrs.out_dtype) diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py b/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py new file mode 100644 index 0000000000..ac1de41377 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py @@ -0,0 +1,153 @@ +"""A compiler pass that fuses transpose + matmul.""" +import tvm +from tvm import IRModule, relax, te, tir +from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul") +class FuseTransposeMatmul: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + matmul.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + mod = relax.transform.FuseOpsByPattern( + [ + ( + "transpose_matmul_fuse", + *_pattern(), + ), + ] + )(mod) + + transpose_matmul_codegen = _TransposeMatmulFuser(mod) + for g_var in mod.functions: + func = mod[g_var] + if isinstance(func, relax.Function): + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(g_var, func) + return transpose_matmul_codegen.builder_.get() + + +def _pattern(): + """Pattern for transpose + matmul.""" + # pylint: disable=invalid-name + w = wildcard() + x = wildcard() + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.matmul")(x, wT) + # pylint: enable=invalid-name + annotations = {"o": o, "w": w, "x": x, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + transpose_call = context.annotated_expr["wT"] + ndim = transpose_call.args[0].struct_info.ndim + if ndim == -1: + return False + if ndim == 2 and transpose_call.attrs.axes is None: + return True + axes = list(range(ndim)) + axes[-1], axes[-2] = axes[-2], axes[-1] + return list(transpose_call.attrs.axes) == axes + + return o, annotations, _check + + +# pylint: disable=missing-docstring,invalid-name + + +@mutator +class _TransposeMatmulFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod): + super().__init__(mod) + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + out_dtype = None + + def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + bT_shape = list(b.shape) + bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] + bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) + output_shape = self.builder_.normalize( + relax.op.matmul(a_relax, bT_relax) + ).struct_info.shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in 
range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + dim_equal = a_dim == b_dim + if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + else: + a_indices.append(idx_spatial[i]) + b_indices.append(idx_spatial[i]) + + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + b_indices.append(idx_reduce) + + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + "Composite" in function.attrs + and function.attrs["Composite"] == "transpose_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + return self.builder_.call_te( + te_transposed_matmul, + call.args[1], + call.args[0], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) diff --git a/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py new file mode 100644 index 0000000000..dc8eaa5bdc --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py @@ -0,0 +1,196 @@ +"""A compiler pass that lifts TIR-level global allocation to Relax.""" +from typing import Dict, List, Tuple + +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc") +class LiftTIRGlobalBufferAlloc: # pylint: disable=too-few-public-methods + """A compiler pass that lifts TIR-level global allocation to Relax.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + return _TIRGlobalAllocRewriter(mod).transform() + + +@mutator +class _TIRGlobalAllocRewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule): + super().__init__(mod) + self.mod = mod + self.gv2new_tensor_sinfo: Dict[ + tvm.ir.GlobalVar, Tuple[List[relax.TensorStructInfo], tir.PrimFunc] + ] = {} + + def transform(self) -> IRModule: + """Entry point of the transformation""" + for g_var, func in self.mod.functions.items(): + if isinstance(func, tir.PrimFunc): + updated_func, tensor_sinfo_list = remove_global_buf_alloc(func) + if len(tensor_sinfo_list) > 0: + self.gv2new_tensor_sinfo[g_var] = (tensor_sinfo_list, func) + self.builder_.update_func(g_var, updated_func) + + self.mod = self.builder_.get() + for g_var, func in self.mod.functions.items(): + if not isinstance(func, relax.Function): + continue + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + 
self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed + call = self.visit_expr_post_order(call) + if ( + call.op != tvm.ir.Op.get("relax.call_tir") + or call.args[0] not in self.gv2new_tensor_sinfo + ): + return call + + g_var = call.args[0] + tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var] + + assert len(call.sinfo_args) == 1 + if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo): + tensor_sinfo, success = _resolve_tir_var_mapping(func_before_update, call, tensor_sinfo) + if not success: + # Cannot resolve TIR var mapping. Fall back to no lifting. + self.builder_.update_func(g_var, func_before_update) + self.gv2new_tensor_sinfo.pop(g_var) + return call + + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + new_call = relax.Call( + call.op, + args=call.args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], + attrs=call.attrs, + ) + emitted_tuple = self.builder_.emit(new_call) + return relax.TupleGetItem(emitted_tuple, 0) + assert isinstance(call.sinfo_args[0], relax.TupleStructInfo) + return relax.Call( + call.op, + args=call.args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)], + attrs=call.attrs, + ) + + +def remove_global_buf_alloc( + func: tir.PrimFunc, +) -> Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]: + """Remove the global buffer allocation for a given TIR PrimFunc.""" + assert isinstance(func.body, tir.BlockRealize) + params = list(func.params) + buffer_map = dict(func.buffer_map) + tensor_sinfo = [] + alloc_buffers = [] + + insertion_point = len(params) + while params[insertion_point - 1].dtype != "handle": + insertion_point -= 1 + assert insertion_point >= 1 + + prev_root_block = func.body.block + for buf_alloc in func.body.block.alloc_buffers: + if buf_alloc.scope() == "global": + param = tir.Var("var_" + buf_alloc.name, "handle") + params.insert(insertion_point, param) + insertion_point += 1 + buffer_map[param] = buf_alloc + tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype)) + else: + alloc_buffers.append(buf_alloc) + + if len(tensor_sinfo) == 0: + return func, [] + + assert len(prev_root_block.iter_vars) == 0 + assert len(prev_root_block.reads) == 0 + assert len(prev_root_block.writes) == 0 + assert len(prev_root_block.match_buffers) == 0 + assert prev_root_block.name_hint == "root" + assert prev_root_block.init is None + root_block = tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=prev_root_block.body, + alloc_buffers=alloc_buffers, + annotations=prev_root_block.annotations, + ) + + updated_func = tir.PrimFunc( + params=params, + body=tir.BlockRealize(iter_values=[], predicate=True, block=root_block), + ret_type=func.ret_type, + buffer_map=buffer_map, + attrs=func.attrs, + ) + return updated_func, tensor_sinfo + + +def _has_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool: + assert isinstance(tensor_sinfo.shape, relax.ShapeExpr) + for dim in tensor_sinfo.shape.values: + if not isinstance(dim, tir.IntImm): + return True + return False + + +def _resolve_tir_var_mapping( # pylint: disable=too-many-locals + func: tir.PrimFunc, + call: relax.Call, + tensor_sinfo: List[relax.TensorStructInfo], +) -> Tuple[List[relax.TensorStructInfo], bool]: + """Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function""" + var_map: Dict[tir.Var, tir.PrimExpr] = {} + + n_arg = 
len(call.args[1].fields) + for i in range(n_arg): + buffer_shape = func.buffer_map[func.params[i]].shape + arg_shape = call.args[1][i].struct_info.shape.values + assert len(buffer_shape) == len(arg_shape) + for v_l, v_r in zip(buffer_shape, arg_shape): + if isinstance(v_l, tir.Var): + var_map[v_l] = v_r + elif not isinstance(v_l, tir.IntImm): + return [], False + + ret_tensors = call.sinfo_args[0] + ret_tensors = ( + [ret_tensors] + if isinstance(ret_tensors, relax.TensorStructInfo) + else list(ret_tensors.fields) + ) + for i, ret_tensor in enumerate(ret_tensors): + buffer_shape = func.buffer_map[func.params[n_arg + i]].shape + ret_tensor_shape = ret_tensor.shape.values + assert len(buffer_shape) == len(ret_tensor_shape) + for v_l, v_r in zip(buffer_shape, ret_tensor_shape): + if isinstance(v_l, tir.Var): + var_map[v_l] = v_r + elif not isinstance(v_l, tir.IntImm): + return [], False + + updated_tensor_sinfo = [] + for sinfo in tensor_sinfo: + if not _has_symbolic_var(sinfo): + updated_tensor_sinfo.append(sinfo) + continue + new_shape = [] + for dim in sinfo.shape.values: + new_shape.append(tir.stmt_functor.substitute(dim, var_map)) + updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype)) + return updated_tensor_sinfo, True diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler/compiler_pass/pipeline.py new file mode 100644 index 0000000000..349a5af0f0 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/pipeline.py @@ -0,0 +1,49 @@ +"""The compilation pipeline for LLM applications.""" +import tvm +from tvm import dlight as dl +from tvm.relax import register_pipeline # pylint: disable=no-name-in-module + +from .clean_up_tir_attrs import CleanUpTIRAttrs +from .fuse_decode_matmul_ewise import FuseDecodeMatmulEwise +from .fuse_decode_take import FuseDecodeTake +from .fuse_decode_transpose import FuseDecodeTranspose +from .fuse_transpose_matmul import FuseTransposeMatmul +from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc + + +@register_pipeline("mlc_llm") +def _mlc_llm_pipeline(): + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + seq = tvm.transform.Sequential( + [ + # Phase 1. Passes on high-level operator graph + FuseDecodeTranspose(skip_gemm=False), + FuseTransposeMatmul(), + # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + # Phase 3. Passes on TIR + FuseDecodeMatmulEwise(), + FuseDecodeTake(), + tvm.relax.transform.DeadCodeElimination(), + CleanUpTIRAttrs(["op_pattern"]), + # Phase 4. 
Low-level Optimizations + dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + LiftTIRGlobalBufferAlloc(), + tvm.tir.transform.ForceNarrowIndexToInt32(), + ] + ) + mod = seq(mod._move()) # pylint: disable=protected-access + return mod + + return _pipeline diff --git a/python/mlc_chat/compiler/flags_optimization.py b/python/mlc_chat/compiler/flags_optimization.py new file mode 100644 index 0000000000..704903b419 --- /dev/null +++ b/python/mlc_chat/compiler/flags_optimization.py @@ -0,0 +1,77 @@ +"""Optimization flags""" +import argparse +import dataclasses +from io import StringIO + + +@dataclasses.dataclass +class OptimizationFlags: + """Optiization flags""" + + cutlass_attn: bool = True + cutlass_norm: bool = True + cublas_gemm: bool = False + cudagraph: bool = False + + def __repr__(self) -> str: + out = StringIO() + print(f"cutlass_attn={int(self.cutlass_attn)}", file=out, end="") + print(f";cutlass_norm={int(self.cutlass_norm)}", file=out, end="") + print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") + print(f";cudagraph={int(self.cudagraph)}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "OptimizationFlags": + """Parse optimization flags from a string.""" + + if source in OPT_FLAG_PRESET: + return OPT_FLAG_PRESET[source] + + def boolean(value: str) -> bool: + if value == "0": + return False + if value == "1": + return True + raise ValueError(f"Invalid boolean value: {value}") + + parser = argparse.ArgumentParser(description="optimization flags") + parser.add_argument("--cutlass_attn", type=boolean, default=True) + parser.add_argument("--cutlass_norm", type=boolean, default=True) + parser.add_argument("--cublas_gemm", type=boolean, default=False) + parser.add_argument("--cudagraph", type=boolean, default=False) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return OptimizationFlags( + cutlass_attn=results.cutlass_attn, + cutlass_norm=results.cutlass_norm, + cublas_gemm=results.cublas_gemm, + cudagraph=results.cudagraph, + ) + + +OPT_FLAG_PRESET = { + "O0": OptimizationFlags( + cutlass_attn=False, + cutlass_norm=False, + cublas_gemm=False, + cudagraph=False, + ), + "O1": OptimizationFlags( + cutlass_attn=False, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=False, + ), + "O2": OptimizationFlags( + cutlass_attn=True, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=False, + ), + "O3": OptimizationFlags( + cutlass_attn=True, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=True, + ), +} diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py new file mode 100644 index 0000000000..a42dda9a09 --- /dev/null +++ b/python/mlc_chat/compiler/model/__init__.py @@ -0,0 +1,2 @@ +"""Model definition for the compiler.""" +from .model import MODEL_PRESETS, MODELS, Model diff --git a/python/mlc_chat/compiler/model/llama_config.py b/python/mlc_chat/compiler/model/llama_config.py new file mode 100644 index 0000000000..113acd456f --- /dev/null +++ b/python/mlc_chat/compiler/model/llama_config.py @@ -0,0 +1,108 @@ +"""Common configuration for Llama models.""" +import dataclasses +from typing import Any, Dict + +from ...support.config import ConfigBase + + +@dataclasses.dataclass +class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Llama model.""" + + hidden_act: str + hidden_size: int + intermediate_size: int + 
num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + max_sequence_length: int = 2048 + position_embedding_base: int = 10000 + num_key_value_heads: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + head_dim: int = 0 + + def __post_init__(self): + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.num_attention_heads % self.num_key_value_heads == 0 + assert self.head_dim * self.num_attention_heads == self.hidden_size + + @staticmethod + def from_predefined(name: str) -> "LlamaConfig": + """Create a LlamaConfig from a predefined configuration.""" + return LlamaConfig.from_dict(CONFIG[name]) + + +CONFIG = { + "llama2_7b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "llama2_13b": { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pad_token_id": 0, + "pretraining_tp": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "llama2_70b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, +} diff --git a/mlc_llm/models/llama.py b/python/mlc_chat/compiler/model/llama_model.py similarity index 84% rename from mlc_llm/models/llama.py rename to python/mlc_chat/compiler/model/llama_model.py index 40df48180e..1106b38c56 100644 --- a/mlc_llm/models/llama.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -1,41 +1,19 @@ -"""Implementation for Llama2 architecture""" -import dataclasses +""" +Implementation for Llama2 architecture. 
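As a small illustration of how the presets above feed `LlamaConfig` (a sketch relying only on the fields and helpers defined in `llama_config.py`, not part of the patch):

```python
from mlc_chat.compiler.model.llama_config import LlamaConfig

config = LlamaConfig.from_predefined("llama2_7b")
# head_dim is derived in __post_init__; the other fields come from the preset.
assert config.num_key_value_heads == config.num_attention_heads == 32
assert config.head_dim == config.hidden_size // config.num_attention_heads == 128
```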
+TODO: add docstring +""" import math -from typing import Any, Dict, Optional +from typing import Optional from tvm import te, tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from .model_config_base import ModelConfig +from .llama_config import LlamaConfig # pylint: disable=invalid-name,missing-docstring -@dataclasses.dataclass -class LlamaConfig(ModelConfig): # pylint: disable=too-many-instance-attributes - hidden_act: str - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_hidden_layers: int - rms_norm_eps: float - vocab_size: int - max_sequence_length: int = 2048 - position_embedding_base: int = 10000 - num_key_value_heads: int = 0 - kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - head_dim: int = 0 - - def __post_init__(self): - if self.num_key_value_heads == 0: - self.num_key_value_heads = self.num_attention_heads - if self.head_dim == 0: - self.head_dim = self.hidden_size // self.num_attention_heads - assert self.num_attention_heads % self.num_key_value_heads == 0 - assert self.head_dim * self.num_attention_heads == self.hidden_size - - class RotaryEmbedding(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() @@ -117,14 +95,16 @@ def forward( # pylint: disable=too-many-locals self.k_cache.append(op.squeeze(k, axis=0)) self.v_cache.append(op.squeeze(v, axis=0)) - k = op.reshape(self.k_cache.view(total_seq_len), (t, b, h_kv, d)) - v = op.reshape(self.v_cache.view(total_seq_len), (t, b, h_kv, d)) + k = op.reshape(self.k_cache.view(total_seq_len), (b, t, h_kv, d)) + v = op.reshape(self.v_cache.view(total_seq_len), (b, t, h_kv, d)) if h_kv != h_q: k = k.repeat(h_q // h_kv, axis=2) v = v.repeat(h_q // h_kv, axis=2) - attn_weights = op.matmul( # [b, h, s, t] - q.permute_dims([0, 2, 1, 3]), # [b, h, s, d] - k.permute_dims([1, 2, 3, 0]), # [b, h, d, t] + q = q.permute_dims([0, 2, 1, 3]) # [b, h, s, d] + k = k.permute_dims([0, 2, 1, 3]) # [b, h, t, d] + v = v.permute_dims([0, 2, 1, 3]) # [b, h, t, d] + attn_weights = op.matmul( + q, k.permute_dims([0, 1, 3, 2]) # [b, h, s, d] x [b, h, d, t] = [b, h, s, t] ) / math.sqrt(d) dtype = attn_weights.dtype attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask) @@ -133,10 +113,7 @@ def forward( # pylint: disable=too-many-locals else: attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) return self.o_proj( - op.matmul( # [b, h, s, d] - attn_weights, # [b, h, s, t] - v.permute_dims([1, 2, 0, 3]), # [b, h, t, d] - ) + op.matmul(attn_weights, v) # [b, h, s, t] x [b, h, t, d] = [b, h, s, d] .permute_dims([0, 2, 1, 3]) # [b, s, h, d] .reshape((b, s, h_q * d)) ) @@ -178,11 +155,11 @@ def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor class LlamaForCasualLM(nn.Module): - def __init__(self, config: LlamaConfig, dtype: str = "float32"): + def __init__(self, config: LlamaConfig): self.model = LlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.vocab_size = config.vocab_size - self.dtype = dtype + self.dtype = "float32" def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) @@ -239,14 +216,26 @@ def get_default_spec(self): "prefill": { "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, }, "decode": { "inputs": nn.spec.Tensor([batch_size, 1], "int32"), "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": 
"packed", + }, }, "softmax_with_temperature": { "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, }, } return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/mlc_llm/models/llama_param_map.py b/python/mlc_chat/compiler/model/llama_parameter.py similarity index 70% rename from mlc_llm/models/llama_param_map.py rename to python/mlc_chat/compiler/model/llama_parameter.py index 3737893702..4c68fdc899 100644 --- a/mlc_llm/models/llama_param_map.py +++ b/python/mlc_chat/compiler/model/llama_parameter.py @@ -2,16 +2,17 @@ This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ -import numpy as np +from typing import Callable, Dict, List -from mlc_llm.param_loader import ParameterMapping +import numpy as np -from .llama import LlamaConfig, LlamaForCasualLM +from ..parameter import ExternMapping +from .llama_config import LlamaConfig +from .llama_model import LlamaForCasualLM -def hf_torch(model_config: LlamaConfig) -> ParameterMapping: - """ - Returns a parameter mapping that maps from the names of MLC LLM parameters to +def huggingface(model_config: LlamaConfig, _) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to the names of HuggingFace PyTorch parameters. Parameters @@ -21,15 +22,15 @@ def hf_torch(model_config: LlamaConfig) -> ParameterMapping: Returns ------- - param_map : ParameterMapping + param_map : ExternMapping The parameter mapping from MLC to HuggingFace PyTorch. """ model = LlamaForCasualLM(model_config) _, named_params = model.export_tvm(spec=model.get_default_spec()) parameter_names = {name for name, _ in named_params} - name_map = {} - map_func = {} + param_map: Dict[str, List[str]] = {} + map_func: Dict[str, Callable] = {} unused_params = set() for i in range(model_config.num_hidden_layers): @@ -37,24 +38,24 @@ def hf_torch(model_config: LlamaConfig) -> ParameterMapping: attn = f"model.layers.{i}.self_attn" assert f"{attn}.qkv_proj.weight" in parameter_names map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0) - name_map[f"{attn}.qkv_proj.weight"] = ( + param_map[f"{attn}.qkv_proj.weight"] = [ f"{attn}.q_proj.weight", f"{attn}.k_proj.weight", f"{attn}.v_proj.weight", - ) + ] # Add gates in MLP mlp = f"model.layers.{i}.mlp" assert f"{mlp}.gate_up_proj.weight" in parameter_names map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0) - name_map[f"{mlp}.gate_up_proj.weight"] = ( + param_map[f"{mlp}.gate_up_proj.weight"] = [ f"{mlp}.gate_proj.weight", f"{mlp}.up_proj.weight", - ) + ] # inv_freq is not used in the model unused_params.add(f"{attn}.rotary_emb.inv_freq") for name in parameter_names: if name not in map_func: map_func[name] = lambda x: x - name_map[name] = (name,) - return ParameterMapping(name_map, map_func, unused_params) + param_map[name] = [name] + return ExternMapping(param_map, map_func, unused_params) diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py new file mode 100644 index 0000000000..a263ba0c4d --- /dev/null +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -0,0 +1,101 @@ +""" +Quantization specs for Llama2 architecture. 
+TODO: add docstring +""" +from typing import Callable, Dict, List, Optional + +import tvm +from tvm.runtime import NDArray + +from ..parameter import QuantizeMapping +from ..quantization import QuantizeConfig +from ..quantization.group_quantizer import te_quantize as te_group_quantize +from .llama_config import LlamaConfig +from .llama_model import LlamaForCasualLM + + +def huggingface_group_quantize( + model_config: LlamaConfig, + quantize_config: QuantizeConfig, + target: Optional[tvm.target.Target] = None, +) -> QuantizeMapping: + """Returns a parameter mapping that maps a parameter in MLC LLM's model + definition to its eventual names and values after quantization. + + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + quantize_config : GroupQuantizeConfig + The configuration of the group quantization. + target : Optional[tvm.target.Target] + The target device to run the quantization on, by default None, which + means the quantization will be run on CPU. + + Returns + ------- + quantize_map : QuantizeMapping + The parameter mapping from a parameter in MLC LLM's model definition to + its eventual names and values after quantization. + """ + + def group_quantize( + param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None + ): + if target is None or target.kind.name == "llvm": + target = tvm.target.Target("llvm") + device = tvm.cpu() + elif target.kind.name == "cuda": + device = tvm.cuda() + else: + raise ValueError(f"Invalid target device: {target}") + param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param") + weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore + param_tensor, config + ) + s = tvm.te.create_schedule( # pylint: disable=invalid-name + [compute.op for compute in [weight_compute, scale_compute] + other_computes] + ) + if target.kind.name == "cuda": + # thread_binding for cuda + for compute in [weight_compute, scale_compute] + other_computes: + xo, xi = s[compute].split(compute.op.axis[0], 256) # pylint: disable=invalid-name + s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x")) + s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x")) + f_quantize = tvm.build( + s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target + ) + weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device) + scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device) + f_quantize(param.copyto(device), weight, scale) + return weight, scale + + # Param check + assert ( + quantize_config.kind == "group_quantize" + ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}" + assert ( + quantize_config.name == "q4f16_1" + ), """Only support q4f16_1 quantization scheme for now.""" + + # Fetch model parameter & names + model = LlamaForCasualLM(model_config) + _, named_params = model.export_tvm(spec=model.get_default_spec()) + parameter_names = {name for name, _ in named_params} + + # Init mappings + param_map: Dict[str, List[str]] = {} + map_func: Dict[str, Callable] = {} + + # Dispatch quantization scheme + # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py + for name in parameter_names: + if "norm.weight" not in name and "embed" not in name: + param_map[name] = [f"{name}_quantized", f"{name}_scale"] + map_func[name] = lambda x: group_quantize(x, quantize_config, target=target) + else: + # skip these parameters + param_map[name] = [name] + 
map_func[name] = lambda x: [x] + + return QuantizeMapping(param_map, map_func) diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py new file mode 100644 index 0000000000..3027a39500 --- /dev/null +++ b/python/mlc_chat/compiler/model/model.py @@ -0,0 +1,65 @@ +"""A centralized registry of all existing model architures and their configurations.""" +import dataclasses +from typing import Any, Callable, Dict + +from tvm.relax.frontend import nn + +from ..parameter import ExternMapping, QuantizeMapping +from ..quantization.quantization import QuantizeConfig +from . import llama_config, llama_model, llama_parameter + +ModelConfig = Any +"""A ModelConfig is an object that represents a model architecture. It is required to have +a class method `from_file` with the following signature: + + def from_file(cls, path: Path) -> ModelConfig: + ... +""" + +FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping] +FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping] + + +@dataclasses.dataclass +class Model: + """All about a model architecture: its configuration, its parameter loader and quantization. + + Parameters + ---------- + name : str + The name of the model. + + model : Callable[[ModelConfig], nn.Module] + A method that creates the `nn.Module` that represents the model from `ModelConfig`. + + config : ModelConfig + A class that has a `from_file` class method, whose signature is "Path -> ModelConfig". + + source : Dict[str, FuncGetExternMap] + A dictionary that maps the name of a source format to parameter mapping. + + quantize: Dict[str, FuncGetQuantMap] + A dictionary that maps the name of a quantization method to quantization mapping. + """ + + name: str + config: ModelConfig + model: Callable[[ModelConfig], nn.Module] + source: Dict[str, FuncGetExternMap] + quantize: Dict[str, FuncGetQuantMap] + + +MODELS: Dict[str, Model] = { + "llama": Model( + name="llama", + model=llama_model.LlamaForCasualLM, + config=llama_config.LlamaConfig, + source={ + "huggingface-torch": llama_parameter.huggingface, + "huggingface-safetensor": llama_parameter.huggingface, + }, + quantize={}, + ) +} + +MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG diff --git a/python/mlc_chat/compiler/parameter/__init__.py b/python/mlc_chat/compiler/parameter/__init__.py new file mode 100644 index 0000000000..f119b01f91 --- /dev/null +++ b/python/mlc_chat/compiler/parameter/__init__.py @@ -0,0 +1,6 @@ +""" +A subpackage of the compiler that represents mapping between external parameters, quantized +parameters and parameters in MLC-defined models. 
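A short usage sketch of the `MODELS` registry defined in `model.py` above, assuming the package is importable as laid out in this patch (not part of the patch itself):

```python
from mlc_chat.compiler import MODEL_PRESETS, MODELS

llama = MODELS["llama"]
config = llama.config.from_dict(MODEL_PRESETS["llama2_7b"])
# Resolve the HuggingFace-PyTorch parameter mapping for this configuration.
param_map = llama.source["huggingface-torch"](config, None)
print(sorted(param_map.param_map)[:3])  # a few of the MLC parameter names in the mapping
```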
+""" +from .huggingface_loader import HuggingFaceLoader +from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_chat/compiler/parameter/huggingface_loader.py b/python/mlc_chat/compiler/parameter/huggingface_loader.py new file mode 100644 index 0000000000..ed91255c81 --- /dev/null +++ b/python/mlc_chat/compiler/parameter/huggingface_loader.py @@ -0,0 +1,193 @@ +"""A weight loader for HuggingFace's PyTorch format""" + +import gc +import json +import logging +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import Dict, Iterator, List, Optional, Tuple + +import numpy as np +from tqdm import tqdm +from tvm.runtime import NDArray +from tvm.runtime.ndarray import array as as_ndarray + +from .mapping import ExternMapping, QuantizeMapping +from .stats import Stats +from .utils import ( + ParamQuantizer, + check_parameter_usage, + load_safetensor_shard, + load_torch_shard, +) + +logger = logging.getLogger(__name__) + + +class HuggingFaceLoader: # pylint: disable=too-few-public-methods + """A loader loading HuggingFace's PyTorch/SafeTensor format and converts them + to MLC's parameters. + + Attributes + ---------- + stats : Stats + Statistics of the loading process. + + extern_param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor. + + torch_to_path : Dict[str, Path] + A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it, + or the path meaning all parameters are stored in a single file. + + cached_files : Dict[Path, Dict[str, np.ndarray]] + A cache of the loaded files. The key is the path of the file, and the value is a mapping + from parameter name to the parameter value. + + quantize_param_map : Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters. + """ + + stats: Stats + cached_files: Dict[Path, Dict[str, np.ndarray]] + torch_to_path: Dict[str, Path] + extern_param_map: ExternMapping + quantize_param_map: Optional[QuantizeMapping] + + def __init__( + self, + path: Path, + extern_param_map: ExternMapping, + quantize_param_map: Optional[QuantizeMapping] = None, + ) -> None: + """Create a parameter loader from HuggingFace PyTorch format. + + Parameters + ---------- + path : pathlib.Path + Path to either a JSON indexing file, or a PyTorch bin file. + 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` + or `model.safetensors.index.json` in the repo, which contains a `weight_map` that + maps each PyTorch parameter to the file containing the weight. + 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo, + which contains all the parameters. + 3) For safetensor file, it is usually `model.safetensors` in the repo, + which contains all the parameters. + + extern_param_map : ExternMapping + Maps an MLC parameter to a list of PyTorch/SafeTensor parameters. + + quantize_param_map: Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters, default to None, which + means no quantization. 
+ """ + assert path.is_file() + self.stats = Stats() + self.extern_param_map = extern_param_map + self.cached_files = {} + self.torch_to_path = {} + self.quantize_param_map = quantize_param_map + if path.suffix in (".bin", ".safetensors"): + self._load_file(path) + for name in self.cached_files[path].keys(): + self.torch_to_path[name] = path + elif path.suffix == ".json": + with path.open("r", encoding="utf-8") as in_file: + torch_weight_map = json.load(in_file)["weight_map"] + for torch_name, path_str in torch_weight_map.items(): + self.torch_to_path[torch_name] = path.parent / path_str + else: + raise FileNotFoundError(f"Unknown file suffix: {path}") + check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) + + def load(self) -> Iterator[Tuple[str, NDArray]]: + """Load the parameters and yield the MLC parameter and its value.""" + mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) + for mlc_name in tqdm(mlc_names): + param = self._load_mlc_param(mlc_name) + if self.quantize_param_map: + with self.stats.timer("quant_time_sec"): + quantized_params = ParamQuantizer(self.quantize_param_map).quantize( + mlc_name, param + ) + for quantized_name, quantized_param in quantized_params: + logger.info( + ' Quantized Parameter: "%s", shape: %s, dtype: %s', + quantized_name, + quantized_param.shape, + quantized_param.dtype, + ) + yield quantized_name, quantized_param + else: + yield mlc_name, param + cached_files = list(self.cached_files.keys()) + for path in cached_files: + self._unload_file(path) + self.stats.log_time_info("HF") + self.stats.log_mem_usage() + + def _load_mlc_param(self, mlc_name: str) -> np.ndarray: + torch_names = self.extern_param_map.param_map[mlc_name] + files_required = {self.torch_to_path[p] for p in torch_names} + files_existing = set(self.cached_files.keys()) + files_to_load = files_required - files_existing + files_to_unload = files_existing - files_required + + # Step 1. When there is some file to unloaded: + # - If no pending file load: unloading is deferred as there is no gain in peak memory usage; + # - Need to load files: unload immediately to save memory and make space for the new files. + if files_to_load: + for path in files_to_unload: + self._unload_file(path) + # Step 2. Load all the files needed + for path in files_to_load: + self._load_file(path) + # Step 3. Collect all torch parameters in order + torch_params = [self.cached_files[self.torch_to_path[i]][i] for i in torch_names] + # Step 4. Apply the mapping function + with self.stats.timer("map_time_sec"): + param = self.extern_param_map.map_func[mlc_name](*torch_params) + logger.info(' Parameter: "%s", shape: %s, dtype: %s', mlc_name, param.shape, param.dtype) + param = as_ndarray(param) + return param + + def _load_file(self, path: Path) -> None: + logger.info("Loading HF parameters from: %s", path) + load_func = load_safetensor_shard if path.suffix == ".safetensors" else load_torch_shard + with self.stats.timer("load_time_sec"): + result = {} + for name, param in load_func(path): + result[name] = param + self.stats.mem_add(param.nbytes) + self.cached_files[path] = result + + def _unload_file(self, path: Path) -> None: + logger.info("Unloading HF weight file: %s", path) + with self.stats.timer("load_time_sec"): + for _, param in self.cached_files[path].items(): + self.stats.mem_rm(param.nbytes) + del self.cached_files[path] + gc.collect() + + +def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]: + # Step 1. 
Build a map from path to torch parameters + path_to_torch: Dict[Path, List[str]] = defaultdict(list) + for torch_name, path in torch_to_path.items(): + path_to_torch[path].append(torch_name) + # Step 2. Build a map from torch parameters to MLC parameters + torch_to_mlc = defaultdict(list) + for mlc_name, torch_names in param_map.param_map.items(): + for torch_name in torch_names: + torch_to_mlc[torch_name].append(mlc_name) + # Step 3. Construct the ordering that ensures file locality + order = OrderedDict() + for _, torch_names in path_to_torch.items(): + for torch_name in torch_names: + for mlc_name in torch_to_mlc[torch_name]: + if mlc_name not in order: + order[mlc_name] = 1 + return list(order.keys()) + + +__all__ = ["HuggingFaceLoader"] diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/parameter/mapping.py new file mode 100644 index 0000000000..aab674cfa8 --- /dev/null +++ b/python/mlc_chat/compiler/parameter/mapping.py @@ -0,0 +1,87 @@ +"""Parameter mapping for converting different LLM implementations to MLC LLM.""" +import dataclasses +from typing import Callable, Dict, List, Set, Union + +import numpy as np +from tvm.runtime import NDArray + +MapFuncVariadic = Union[ + Callable[[], np.ndarray], + Callable[[np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray], +] + + +@dataclasses.dataclass +class ExternMapping: + """Mapping from a parameter name in MLC LLM's model definition to its potential source, + for example, from MLC parameter "model.layers.2.post_attention_layernorm.weight" to PyTorch's + parameter correspondingly. + + Parameters + ---------- + param_map : Dict[str, List[str]] + A dictionary that maps the name of a parameter to its source. For example, + in Llama2, the source of MLC parameter "model.layers.0.self_attn.qkv_proj.weight" from + huggingface torch are: + + - "model.layers.0.self_attn.q_proj.weight" + - "model.layers.0.self_attn.k_proj.weight" + - "model.layers.0.self_attn.v_proj.weight" + + map_func : Dict[str, Callable[[np.ndarray, ...], np.ndarray]] + A dictionary that maps the name of a parameter to a function that combines the source + parameters into the MLC parameter. For example, for the above example, the function + would be: `lambda q, k, v: np.concatenate([q, k, v], axis=0)`. + + unused_params : Set[str] + Parameter names in the source weights that are not used in the MLC LLM model definition. + """ + + param_map: Dict[str, List[str]] + map_func: Dict[str, MapFuncVariadic] + unused_params: Set[str] = dataclasses.field(default_factory=set) + + +@dataclasses.dataclass +class QuantizeMapping: + """Mapping from a parameter in MLC LLM's model definition to its eventual names and values after + quantization. In certain group quantization, for example, `qkv_proj.weight` is mapped to + `qkv_proj.weight_quantized` and `qkv_proj.weight_scale` respectively. If a parameter's name is + not in the mapping, it is assumed to be unchanged, i.e. not quantized. + + Parameters + ---------- + param_map : Dict[str, List[str]] + A dictionary that maps the name of a parameter to its destination. 
For example, + in certain group quantization, the destinations of MLC parameter "qkv_proj.weight` are: + + - "qkv_proj.weight_quantized" + - "qkv_proj.weight_scale" + + map_func : Dict[str, Callable[NDArray, List[NDArray]]] + A dictionary that maps the name of a parameter to a function that splits the MLC parameter + into the destination parameters. + + Notes + ----- + There are two forms of weight conversion in MLC LLM, one is A) on-the-fly quantization to the + raw fp16/bf16/fp32 weights from HuggingFace, and the other is B) loading pre-quantized weights + from an external framework, e.g. AutoGPTQ, AutoAWQ. From the perspective of parameter + correspondence. + + - In case A), it is recommended that the weight loader take both `ExternMapping` and + `QuantizeMapping` as input, and do quantiaztion on the fly as a raw parameter being + loaded into RAM; + - In case B), a pass over `nn.Module` is recommended to take place first to converts parameters + from its non-quantized form to the quantized one, and then only `ExternMapping` is + used to convert the quantized parameters into the desired form. + """ + + param_map: Dict[str, List[str]] + map_func: Dict[str, Callable[[NDArray], List[NDArray]]] + + +__all__ = ["ExternMapping", "QuantizeMapping"] diff --git a/python/mlc_chat/compiler/parameter/stats.py b/python/mlc_chat/compiler/parameter/stats.py new file mode 100644 index 0000000000..9f5d1e16fa --- /dev/null +++ b/python/mlc_chat/compiler/parameter/stats.py @@ -0,0 +1,86 @@ +"""Statistics of the loading process of parameter loaders""" +import dataclasses +import logging +import time +from contextlib import contextmanager + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Stats: + """Statistics of the loading process of parameter loaders. + + Attributes + ---------- + load_time_sec : float + Time used in loading the parameters. + + map_time_sec : float + Time used in applying the mapping function, i.e. `ExternMapping.map_func`. + + quant_time_sec : float + Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`. + + current_memory_gb : float + The current RAM usage in GB. + + total_memory_gb : float + The total size data loaded from disk in GB. + + max_memory_gb : float + The maximum RAM usage in GB. 
+ """ + + load_time_sec: float = 0.0 + map_time_sec: float = 0.0 + quant_time_sec: float = 0.0 + + current_memory_gb: float = 0.0 + total_memory_gb: float = 0.0 + max_memory_gb: float = 0.0 + + def timer(self, attr): + """A context manager to time the scope and add the time to the attribute.""" + + @contextmanager + def timed_scope(): + start_time = time.time() + yield + elapsed_time = time.time() - start_time + setattr(self, attr, getattr(self, attr) + elapsed_time) + + return timed_scope() + + def mem_add(self, nbytes: int): + """Add the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb += mem_gb + self.total_memory_gb += mem_gb + self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb) + + def mem_rm(self, nbytes: int): + """Remove the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb -= mem_gb + + def log_time_info(self, weight_format: str): + """Log the time used in loading, pre-quantization and quantization.""" + logger.info( + "Time used: " + "%s loading: %.3f sec; " + "Pre-quantization mapping: %.3f sec; " + "Quantization: %.3f sec", + weight_format, + self.load_time_sec, + self.map_time_sec, + self.quant_time_sec, + ) + + def log_mem_usage(self): + """Log the Memory usage information.""" + logger.info( + "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB", + self.total_memory_gb, + self.max_memory_gb, + ) diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/parameter/utils.py new file mode 100644 index 0000000000..a2789cee55 --- /dev/null +++ b/python/mlc_chat/compiler/parameter/utils.py @@ -0,0 +1,88 @@ +"""Common utilities for loading parameters""" +# pylint: disable=too-few-public-methods +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Iterator, Set, Tuple + +import numpy as np + +from .mapping import ExternMapping + +if TYPE_CHECKING: + from tvm.runtime import NDArray + + from ..parameter import QuantizeMapping + +logger = logging.getLogger(__name__) + + +class ParamQuantizer: + """A parameter quantizer that quantizes given mlc-llm parameters""" + + quantize_map: "QuantizeMapping" + + def __init__(self, quantize_map: "QuantizeMapping") -> None: + self.quantize_map = quantize_map + + def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]: + """Apply quantization to the given parameters + + Parameters + ---------- + name : str + The name of the parameter + param : NDArray + The parameter to be quantized + + Returns + ------- + List[Tuple[str, NDArray]] + The quantized parameters, each with its name + """ + + assert name in self.quantize_map.param_map + quantized_names = self.quantize_map.param_map[name] + quantized_params = self.quantize_map.map_func[name](param) + return zip(quantized_names, quantized_params) + + +def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]): + """Check that all external parameters have been used and are stored in the weights file.""" + used_extern_names = set(sum(param_map.param_map.values(), [])) + # Check 1. All extern parameters in the weight files are used unless explicitly specified + unused_extern_names = extern_weights - used_extern_names - param_map.unused_params + if unused_extern_names: + logger.warning( + "Unused extern parameters: %s", + ", ".join(sorted(unused_extern_names)), + ) + # Check 2. 
All extern parameters required are stored in the weight files + nonexistent_extern_names = used_extern_names - extern_weights + if nonexistent_extern_names: + raise ValueError( + "The following extern parameters do not exist in the weight files:\n " + + "\n ".join(sorted(nonexistent_extern_names)), + ) + + +def load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]: + """Load and yield PyTorch format parameters.""" + import torch # pylint: disable=import-outside-toplevel + + for name, param in torch.load(path, map_location=torch.device("cpu")).items(): + param = param.detach().cpu() + dtype = str(param.dtype) + if dtype == "torch.bfloat16": + param = param.float() + param = param.numpy() + yield name, param + + +def load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]: + """Load and yield SafeTensor format parameters.""" + import safetensors # pylint: disable=import-outside-toplevel,import-error + + with safetensors.safe_open(path, framework="numpy", device="cpu") as in_file: + for name in in_file.keys(): + param = in_file.get_tensor(name) + yield name, param diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py new file mode 100644 index 0000000000..a932119f9c --- /dev/null +++ b/python/mlc_chat/compiler/quantization/__init__.py @@ -0,0 +1,2 @@ +"""A subpackage for quantization and dequantization algorithms""" +from .quantization import QUANT, QuantizeConfig diff --git a/python/mlc_chat/compiler/quantization/group_quantizer.py b/python/mlc_chat/compiler/quantization/group_quantizer.py new file mode 100644 index 0000000000..b95c946abd --- /dev/null +++ b/python/mlc_chat/compiler/quantization/group_quantizer.py @@ -0,0 +1,70 @@ +"""A group quantizer for on the fly parameter quantization""" +# pylint: disable=too-few-public-methods + +from typing import List, Tuple + +from tvm import te, tir + +from .quantization import QuantizeConfig + + +def te_quantize( + weight: te.Tensor, config: QuantizeConfig +) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]: + """Group quantization for weight tensor, defined in tensor expression.""" + # pylint: disable=too-many-locals + assert len(weight.shape) == 2 + n, m = weight.shape # pylint: disable=invalid-name + # compute scale per group + r = te.reduce_axis((0, config.group_size), name="r") # pylint: disable=invalid-name + num_group = tir.ceildiv(m, config.group_size) + scale_shape = (n, num_group) + max_abs = te.compute( + shape=scale_shape, + fcompute=lambda i, j: te.max( + tir.if_then_else( + j * config.group_size + r < weight.shape[1], + te.abs(weight[i, j * config.group_size + r]), + tir.const(1e-4, config.weight_dtype), + ), + axis=r, + ), + name="max_abs_value", + ) + scale = te.compute( + (n, m), + lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype), + name="scale", + ) + + # compute scaled weight + tir_max_int = tir.const(config.max_int_value, config.weight_dtype) + tir_zero = tir.const(0, config.weight_dtype) + tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype) + scaled_weight = te.compute( + shape=weight.shape, + fcompute=lambda i, j: tir.min( + tir.max( + tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int), + tir_zero, + ), + tir_max_int_2, + ).astype(config.storage_dtype), + ) + + # compute quantized weight per storage + r = te.reduce_axis((0, config.num_elem_per_storage), name="r") # pylint: disable=invalid-name + num_storage = config.num_storage_per_group * num_group + 
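    # Illustration of the packing step below, assuming (for example) 4-bit quantization
    # stored in 32-bit integers: num_elem_per_storage = 32 // 4 = 8, so eight quantized
    # elements e0..e7 of a row are packed into a single storage word as
    #     e0 | (e1 << 4) | (e2 << 8) | ... | (e7 << 28),
    # which is what the shift-and-sum reduction expresses. The `where` clause masks out
    # columns beyond m in the last, possibly partial, group.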
quantized_weight_shape = (n, num_storage) + quantized_weight = te.compute( + shape=quantized_weight_shape, + fcompute=lambda i, j: tir.sum( + scaled_weight[i, j * config.num_elem_per_storage + r] + << (r * config.quantize_dtype_bits), + axis=r, + where=j * config.num_elem_per_storage + r < m, + ), + name="weight", + ) + return quantized_weight, scale, [max_abs, scaled_weight] + # pylint: enable=too-many-locals diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py new file mode 100644 index 0000000000..c1ba794063 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -0,0 +1,22 @@ +"""A centralized registry of all existing quantization methods and their configurations.""" +from typing import Any, Dict + +QuantizeConfig = Any +"""A QuantizeConfig is an object that represents an quantization algorithm. It is required to +have the following fields: + + name : str + The name of the quantization algorithm, for example, "q4f16_1". + + kind : str + The kind of quantization algorithm, for example, "group_quant", "faster_transformer". + +It is also required to have the following method: + + def quantize(self, module: nn.Module) -> nn.Module: + ... +""" + +QUANT: Dict[str, QuantizeConfig] = { + "q4f16_1": None, +} diff --git a/python/mlc_chat/embeddings/openai.py b/python/mlc_chat/embeddings/openai.py index 5795ed8158..ad6b750b0b 100644 --- a/python/mlc_chat/embeddings/openai.py +++ b/python/mlc_chat/embeddings/openai.py @@ -1,17 +1,15 @@ +# pylint: disable=missing-docstring from __future__ import annotations -from langchain.embeddings import OpenAIEmbeddings -from langchain.embeddings.openai import embed_with_retry, async_embed_with_retry - import logging -from typing import ( - List, - Optional, - Sequence, - Tuple, -) +from typing import Iterable, List, Optional, Sequence, Tuple import numpy as np +from langchain.embeddings import OpenAIEmbeddings # pylint: disable=import-error +from langchain.embeddings.openai import ( # pylint: disable=import-error + async_embed_with_retry, + embed_with_retry, +) logger = logging.getLogger(__name__) @@ -25,13 +23,13 @@ def _chunk_tokens(self, texts: Sequence[str]) -> Tuple[List[List], List[int]]: ) try: - import tiktoken - except ImportError: + import tiktoken # pylint: disable=import-outside-toplevel + except ImportError as err: raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for OpenAIEmbeddings. " "Please install it with `pip install tiktoken`." 
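Returning to the quantization registry above: the `QuantizeConfig` contract in `quantization/quantization.py` only requires a `name`, a `kind`, and a `quantize(module)` method. A minimal placeholder satisfying that contract might look like the sketch below; the class name and behavior are illustrative, and the real `q4f16_1` entry is still `None` in this patch:

    import dataclasses

    @dataclasses.dataclass
    class GroupQuantizeConfig:  # hypothetical; not the eventual implementation
        name: str = "q4f16_1"
        kind: str = "group_quant"

        def quantize(self, module):
            # A real implementation would rewrite the module's parameters here;
            # this sketch returns the module unchanged.
            return module

    QUANT = {"q4f16_1": GroupQuantizeConfig()}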
- ) + ) from err tokens = [] indices = [] @@ -62,10 +60,10 @@ def _batch_embed( ) -> List[List[float]]: batched_embeddings: List[List[float]] = [] _chunk_size = chunk_size or self.chunk_size - _iter = range(0, len(inputs), _chunk_size) + _iter: Iterable = range(0, len(inputs), _chunk_size) if self.show_progress_bar: try: - from tqdm.auto import tqdm + from tqdm import tqdm # pylint: disable=import-outside-toplevel _iter = tqdm(_iter) except ImportError: @@ -85,10 +83,10 @@ async def _abatch_embed( ) -> List[List[float]]: batched_embeddings: List[List[float]] = [] _chunk_size = chunk_size or self.chunk_size - _iter = range(0, len(inputs), _chunk_size) + _iter: Iterable = range(0, len(inputs), _chunk_size) if self.show_progress_bar: try: - from tqdm.auto import tqdm + from tqdm import tqdm # pylint: disable=import-outside-toplevel _iter = tqdm(_iter) except ImportError: @@ -105,8 +103,12 @@ async def _abatch_embed( # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb - def _get_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + def _get_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, ) -> List[List[float]]: tokens, indices = self._chunk_tokens(texts) batched_embeddings = self._batch_embed(tokens, chunk_size=chunk_size) @@ -121,9 +123,9 @@ def _get_len_safe_embeddings( self, input="", **self._invocation_params, - )[ - "data" - ][0]["embedding"] + )["data"][ + 0 + ]["embedding"] for _result, num_tokens in zip(results, num_tokens_in_batch): if len(_result) == 0: average = empty_average @@ -136,8 +138,12 @@ def _get_len_safe_embeddings( # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb - async def _aget_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + async def _aget_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, ) -> List[List[float]]: tokens, indices = self._chunk_tokens(texts) batched_embeddings = await self._abatch_embed(tokens, chunk_size=chunk_size) @@ -155,7 +161,9 @@ async def _aget_len_safe_embeddings( input="", **self._invocation_params, ) - )["data"][0]["embedding"] + )[ + "data" + ][0]["embedding"] for _result, num_tokens in zip(results, num_tokens_in_batch): if len(_result) == 0: average = empty_average diff --git a/python/mlc_chat/gradio.py b/python/mlc_chat/gradio.py index 5975a8681d..1ab6ae6dc0 100644 --- a/python/mlc_chat/gradio.py +++ b/python/mlc_chat/gradio.py @@ -1,11 +1,10 @@ """Gradio interface for MLC Chat.""" -# pylint: disable=import-error, import-outside-toplevel, invalid-name, line-too-long, protected-access -# too-many-instance-attributes, too-many-locals, unused-import - -from typing import Dict +# pylint: disable=import-error,invalid-name,too-many-instance-attributes,too-many-locals import argparse -import os import glob +import os +from typing import Dict, Optional + import gradio as gr from .chat_module import ChatModule @@ -47,17 +46,19 @@ def _get_all_available_models_under_dir(artifact_path: str) -> Dict[str, str]: Note ---- We only search for folders under the artifact_path, without recursive search for subfolders. 
- For each folder, we count it as a valid MLC model folder if either it contains a `mlc-chat-config.json` - file, or it contains a `params` folder which contains a `mlc-chat-config.json` file. We will map - the name of a valid folder to its full path to the folder containing `mlc-chat-config.json`. + For each folder, we count it as a valid MLC model folder if either it contains an + `mlc-chat-config.json` file, or it contains a `params` folder which contains an + `mlc-chat-config.json` file. We will map the name of a valid folder to its full path to the + folder containing `mlc-chat-config.json`. """ # step 0. retrieve the absolute path of artifact_path search_dir = os.path.abspath(artifact_path) if not os.path.exists(search_dir): err_msg = ( - f"The artifact path {artifact_path} you provided is neither a valid full path nor a valid path ", - "relative to the current working directory. Please provide a correct artifact path.", + f"The artifact path {artifact_path} you provided is neither a valid full path nor a " + "valid path relative to the current working directory. Please provide a correct " + "artifact path.", ) raise FileNotFoundError(err_msg) @@ -77,9 +78,9 @@ def _get_all_available_models_under_dir(artifact_path: str) -> Dict[str, str]: class GradioModule: - r"""The Gradio module for MLC Chat. Different from ChatModule Python API, Gradio module allows users - to load in a directory of models, watch the streaming in web browser, and switch between models more - easily to compare performance. + r"""The Gradio module for MLC Chat. Different from ChatModule Python API, Gradio module allows + users to load in a directory of models, watch the streaming in web browser, and switch between + models more easily to compare performance. Note: Multimodality will be supported soon, i.e. allowing users to upload an image to chat. 
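The folder-detection rule described above can be summarized with a simplified sketch. The real implementation in `gradio.py` uses `glob` and may differ in details; the helper name here is made up:

    import os
    from typing import Dict

    def find_mlc_model_dirs(artifact_path: str) -> Dict[str, str]:
        """Map each valid model folder name to the directory holding mlc-chat-config.json."""
        models: Dict[str, str] = {}
        for name in os.listdir(artifact_path):
            folder = os.path.join(artifact_path, name)
            if os.path.isfile(os.path.join(folder, "mlc-chat-config.json")):
                models[name] = folder
            elif os.path.isfile(os.path.join(folder, "params", "mlc-chat-config.json")):
                models[name] = os.path.join(folder, "params")
        return models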
""" @@ -87,7 +88,7 @@ class GradioModule: def __init__(self, artifact_path: str = "dist", device: str = "auto"): self.artifact_path = artifact_path self.device_str = device - self.chat_mod = None + self.chat_mod: Optional[ChatModule] = None self.model_dict = _get_all_available_models_under_dir(artifact_path) def gradio_reload_model(self, model_name: str): @@ -132,6 +133,7 @@ def gradio_answer(self, chatbot, stream_interval): Note: Below is a low-level implementation of generate() API, since it's easier to yield without delta callback.""" prompt = chatbot[-1][0] + # pylint: disable=protected-access self.chat_mod._prefill(prompt) i, new_msg = 0, "" while not self.chat_mod._stopped(): @@ -141,6 +143,7 @@ def gradio_answer(self, chatbot, stream_interval): chatbot[-1][1] = new_msg yield chatbot i += 1 + # pylint: enable=protected-access def gradio_stats(self): """Get runtime statistics.""" @@ -148,8 +151,14 @@ def gradio_stats(self): def launch_gradio( - artifact_path: str = "dist", device: str = "auto", port: int = 7860, share: bool = False, host: str = "127.0.0.1"): - r"""Launch the gradio interface with a given port, creating a publically sharable link if specified.""" + artifact_path: str = "dist", + device: str = "auto", + port: int = 7860, + share: bool = False, + host: str = "127.0.0.1", +): + r"""Launch the gradio interface with a given port, creating a publically sharable link if + specified.""" # create a gradio module mod = GradioModule(artifact_path, device) @@ -230,7 +239,7 @@ def launch_gradio( stats_button.click(mod.gradio_stats, [], [stats_output]) # launch to the web - demo.launch(share=share, enable_queue=True, server_port=port,server_name=host) + demo.launch(share=share, enable_queue=True, server_port=port, server_name=host) if __name__ == "__main__": diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index 11a72d8ba6..654b1646bc 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -1,11 +1,14 @@ +# pylint: disable=missing-docstring,fixme,import-error,too-few-public-methods """ -Adapted from FastChat's OpenAI protocol: https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py +Adapted from FastChat's OpenAI protocol: +https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py """ -from typing import Literal, Optional, List, Dict, Any, Union -from pydantic import BaseModel, Field -import shortuuid import time +from typing import Any, Dict, List, Literal, Optional, Union + +import shortuuid +from pydantic import BaseModel, Field class ChatMessage(BaseModel): @@ -13,32 +16,43 @@ class ChatMessage(BaseModel): content: str name: str | None = None + class ChatCompletionRequest(BaseModel): model: str messages: list[ChatMessage] stream: bool | None = False temperature: float = None top_p: float = None + # TODO: replace by presence_penalty and frequency_penalty repetition_penalty: float = None mean_gen_len: int = None + # TODO: replace by max_tokens max_gen_len: int = None - # TODO: Implement support for the following fields - # n: Optional[int] = 1 + presence_penalty: float = None + frequency_penalty: float = None + n: int = None + stop: Union[str, List[str]] = None + # TODO: Implement support for the OpenAI API parameters + # function [] + # function_call # stop: Optional[Union[str, List[str]]] = None - # presence_penalty: Optional[float] = 0.0 - # frequency_penalty: Optional[float] = 0.0 + # max_tokens: Optional[int] + # logit_bias # 
user: Optional[str] = None + class UsageInfo(BaseModel): prompt_tokens: int = 0 completion_tokens: int | None = 0 total_tokens: int = 0 + class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage finish_reason: Literal["stop", "length"] | None = None + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion" @@ -47,40 +61,48 @@ class ChatCompletionResponse(BaseModel): # TODO: Implement support for the following fields usage: UsageInfo | None = None + class DeltaMessage(BaseModel): role: str | None = None content: str | None = None + class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage finish_reason: Literal["stop", "length"] | None = None + class ChatCompletionStreamResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) choices: list[ChatCompletionResponseStreamChoice] + class CompletionRequest(BaseModel): model: str prompt: str | list[str] + stream: bool | None = False temperature: float = None repetition_penalty: float = None top_p: float = None mean_gen_len: int = None + # TODO: replace by max_tokens max_gen_len: int = None - system_prompt: str = None - chat_roles: List[str] = None - messages: List[List[str]] = None - offset: str = None - separator_style: int = None - seps: List[str] = None - role_msg_sep: str = None - role_empty_sep: str = None - stop_str: str = None - stop_tokens: List[int] = None - add_bos: bool = None + presence_penalty: float = None + frequency_penalty: float = None + # TODO: Implement support for the OpenAI API parameters + # suffix + # max_tokens: Optional[int] + # n: Optional[int] = 1 + # logprobs + # echo + # stop: Optional[Union[str, List[str]]] = None + # best_of + # logit_bias + # user: Optional[str] = None + class CompletionResponseChoice(BaseModel): index: int @@ -88,6 +110,7 @@ class CompletionResponseChoice(BaseModel): logprobs: int | None = None finish_reason: Literal["stop", "length"] | None = None + class CompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") object: str = "text_completion" @@ -95,11 +118,26 @@ class CompletionResponse(BaseModel): choices: list[CompletionResponseChoice] usage: UsageInfo + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + choices: List[CompletionResponseStreamChoice] + + class EmbeddingsRequest(BaseModel): model: Optional[str] = None input: Union[str, List[Any]] user: Optional[str] = None + class EmbeddingsResponse(BaseModel): object: str = "list" data: List[Dict[str, Any]] diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index c341b3be8c..2816d9bec7 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -1,31 +1,44 @@ +# pylint: disable=missing-docstring,fixme,import-error import argparse import asyncio -import json -import os -import subprocess -import sys +import dataclasses from contextlib import asynccontextmanager +from typing import Dict -from mlc_chat.chat_module import ChatConfig, ConvConfig - +import numpy as np import uvicorn from fastapi import FastAPI 
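These pydantic models define the wire format consumed by the REST server whose diff follows. As a quick illustration of constructing a request object (the model name is a placeholder):

    from mlc_chat.interface.openai_api import ChatCompletionRequest, ChatMessage

    request = ChatCompletionRequest(
        model="Llama-2-7b-chat-hf-q4f16_1",  # placeholder model name
        messages=[ChatMessage(role="user", content="What is the meaning of life?")],
        stream=False,
        temperature=0.7,
        top_p=0.95,
        n=1,
    )
    print(request.json(exclude_unset=True))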
-from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware - -from dataclasses import dataclass, field, fields +from fastapi.responses import StreamingResponse +from mlc_chat.chat_module import GenerationConfig from .base import set_global_random_seed from .chat_module import ChatModule -from .interface.openai_api import * +from .interface.openai_api import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + DeltaMessage, + EmbeddingsRequest, + EmbeddingsResponse, + UsageInfo, +) -import numpy as np -@dataclass +@dataclasses.dataclass class RestAPIArgs: - """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API server.""" + """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API + server.""" - model: str = field( + model: str = dataclasses.field( metadata={ "help": ( """ @@ -38,7 +51,7 @@ class RestAPIArgs: ) } ) - lib_path: str = field( + lib_path: str = dataclasses.field( default=None, metadata={ "help": ( @@ -46,19 +59,9 @@ class RestAPIArgs: The full path to the model library file to use (e.g. a ``.so`` file). """ ) - } + }, ) - config_overrides_path: str = field( - default=None, - metadata={ - "help": ( - """ - The full path to the model config file to use for overriding the default (e.g. a ``.json`` file). - """ - ) - } - ) - device: str = field( + device: str = dataclasses.field( default="auto", metadata={ "help": ( @@ -70,9 +73,9 @@ class RestAPIArgs: is provided, it will be set to 0 by default. """ ) - } + }, ) - host: str = field( + host: str = dataclasses.field( default="127.0.0.1", metadata={ "help": ( @@ -80,9 +83,9 @@ class RestAPIArgs: The host at which the server should be started, defaults to ``127.0.0.1``. """ ) - } + }, ) - port: int = field( + port: int = dataclasses.field( default=8000, metadata={ "help": ( @@ -90,9 +93,9 @@ class RestAPIArgs: The port on which the server should be started, defaults to ``8000``. """ ) - } + }, ) - random_seed: int = field( + random_seed: int = dataclasses.field( default=None, metadata={ "help": ( @@ -101,14 +104,14 @@ class RestAPIArgs: no seed is set. 
""" ) - } + }, ) def convert_args_to_argparser() -> argparse.ArgumentParser: """Convert from RestAPIArgs to an equivalent ArgumentParser.""" args = argparse.ArgumentParser("MLC Chat REST API") - for field in fields(RestAPIArgs): + for field in dataclasses.fields(RestAPIArgs): name = field.name.replace("_", "-") field_name = f"--{name}" # `kwargs` contains `help`, `choices`, and `action` @@ -121,37 +124,26 @@ def convert_args_to_argparser() -> argparse.ArgumentParser: return args -session = {} +session: Dict[str, ChatModule] = {} @asynccontextmanager -async def lifespan(app: FastAPI): - chat_config_overrides = None - if ARGS.config_overrides_path and os.path.isfile(ARGS.config_overrides_path): - with open(ARGS.config_overrides_path, mode="rt", encoding="utf-8") as f: - json_object = json.load(f) - chat_config_overrides = ChatConfig._from_json(json_object) +async def lifespan(_app: FastAPI): if ARGS.random_seed is not None: set_global_random_seed(ARGS.random_seed) chat_mod = ChatModule( model=ARGS.model, device=ARGS.device, - lib_path=ARGS.lib_path, - chat_config=chat_config_overrides + model_lib_path=ARGS.lib_path, ) session["chat_mod"] = chat_mod - yield - session.clear() -app = FastAPI(lifespan=lifespan) - -origins = [ - "*", -] +origins = ["*"] +app = FastAPI(lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -160,85 +152,100 @@ async def lifespan(app: FastAPI): allow_headers=["*"], ) -class AsyncChatCompletionStream: + +class AsyncCompletionStream: + def __init__(self, generation_config: GenerationConfig): + self.generation_config = generation_config + def __aiter__(self): return self async def get_next_msg(self): + # pylint: disable=protected-access if not session["chat_mod"]._stopped(): - session["chat_mod"]._decode() + session["chat_mod"]._decode(generation_config=self.generation_config) msg = session["chat_mod"]._get_message() return msg - else: - raise StopAsyncIteration + # pylint: enable=protected-access + raise StopAsyncIteration async def __anext__(self): if not session["chat_mod"]._stopped(): task = asyncio.create_task(self.get_next_msg()) msg = await task return msg - else: - raise StopAsyncIteration + raise StopAsyncIteration @app.post("/v1/chat/completions") -async def request_completion(request: ChatCompletionRequest): +async def request_chat_completion(request: ChatCompletionRequest): """ Creates model response for the given chat conversation. + The messages field contains a list of messages (describing the conversation history). eg: + ```"messages": [{"role": "user", "content": "What's my name?"}, + {"role": "assistant", "content": "Your name is Llama."}, + {"role": "user", "content": "No, that's your name. My name is X."}, + {"role": "assistant", "content": "Ah, my apologies! Your name is X! "}, + {"role": "user", "content": "What is the meaning of life?"}, + ] + ``` + ] """ - - chat_config = ChatConfig( + generation_config = GenerationConfig( temperature=request.temperature, repetition_penalty=request.repetition_penalty, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, top_p=request.top_p, mean_gen_len=request.mean_gen_len, max_gen_len=request.max_gen_len, + n=request.n, + stop=request.stop, ) - session["chat_mod"].update_chat_config(chat_config) - if len(request.messages) > 1: - raise ValueError( - """ - The /v1/chat/completions endpoint currently only supports single message prompts. 
- Please ensure your request contains only one message - """) + session["chat_mod"].reset_chat() # Reset previous history, KV cache, etc. if request.stream: - - session["chat_mod"]._prefill(input=request.messages[0].content) + session["chat_mod"]._prefill( # pylint: disable=protected-access + input=request.messages, + generation_config=generation_config, + ) async def iter_response(): prev_txt = "" - async for content in AsyncChatCompletionStream(): + async for content in AsyncCompletionStream(generation_config=generation_config): if content: + valid_content = content.replace("�", "") chunk = ChatCompletionStreamResponse( choices=[ ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage( - role="assistant", content=content[len(prev_txt) :] + role="assistant", content=valid_content[len(prev_txt) :] ), finish_reason="stop", ) ] ) - prev_txt = content + prev_txt = valid_content yield f"data: {chunk.json(exclude_unset=True)}\n\n" return StreamingResponse(iter_response(), media_type="text/event-stream") - else: - msg = session["chat_mod"].generate(prompt=request.messages[0].content) - return ChatCompletionResponse( - choices=[ - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=msg), - finish_reason="stop", - ) - ], - # TODO: Fill in correct usage info - usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) + msg = session["chat_mod"].generate(prompt=request.messages, generation_config=generation_config) + if isinstance(msg, str): + msg = [msg] + return ChatCompletionResponse( + choices=[ + ChatCompletionResponseChoice( + index=index, + message=ChatMessage(role="assistant", content=msg[index]), + finish_reason="stop", + ) + for index in range(len(msg)) + ], + # TODO: Fill in correct usage info + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) @app.post("/v1/completions") @@ -247,31 +254,17 @@ async def request_completion(request: CompletionRequest): Creates a completion for a given prompt. """ - conv_config = ConvConfig( - system=request.system_prompt, - roles=request.chat_roles, - messages=request.messages, - offset=request.offset, - separator_style=request.separator_style, - seps=request.seps, - role_msg_sep=request.role_msg_sep, - role_empty_sep=request.role_empty_sep, - stop_str=request.stop_str, - stop_tokens=request.stop_tokens, - add_bos=request.add_bos, - ) - - chat_config = ChatConfig( + generation_config = GenerationConfig( temperature=request.temperature, repetition_penalty=request.repetition_penalty, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, top_p=request.top_p, mean_gen_len=request.mean_gen_len, max_gen_len=request.max_gen_len, - conv_config=conv_config, ) session["chat_mod"].reset_chat() - session["chat_mod"].update_chat_config(chat_config) # Langchain's load_qa_chain.run expects the input to be a list with the query if isinstance(request.prompt, list): if len(request.prompt) > 1: @@ -279,13 +272,36 @@ async def request_completion(request: CompletionRequest): """ The /v1/completions endpoint currently only supports single message prompts. 
Please ensure your request contains only one message - """) + """ + ) prompt = request.prompt[0] else: prompt = request.prompt - msg = session["chat_mod"].generate(prompt=prompt) + if request.stream: + session["chat_mod"]._prefill( # pylint: disable=protected-access + input=prompt, + generation_config=generation_config, + ) + async def iter_response(): + prev_txt = "" + async for content in AsyncCompletionStream(generation_config=generation_config): + if content: + chunk = CompletionStreamResponse( + choices=[ + CompletionResponseStreamChoice( + index=0, + text=content[len(prev_txt) :], + finish_reason="stop", + ) + ] + ) + prev_txt = content + yield f"data: {chunk.json(exclude_unset=True)}\n\n" + + return StreamingResponse(iter_response(), media_type="text/event-stream") + msg = session["chat_mod"].generate(prompt=prompt, generation_config=generation_config) return CompletionResponse( choices=[CompletionResponseChoice(index=0, text=msg)], # TODO: Fill in correct usage info @@ -299,13 +315,13 @@ async def request_embeddings(request: EmbeddingsRequest): Gets embedding for some text. """ inps = [] - if type(request.input) == str: + if isinstance(request.input, str): inps.append(request.input) - elif type(request.input) == list: + elif isinstance(request.input, list): inps = request.input else: assert f"Invalid input type {type(request.input)}" - + data = [] for i, inp in enumerate(inps): session["chat_mod"].reset_chat() @@ -315,12 +331,7 @@ async def request_embeddings(request: EmbeddingsRequest): data.append({"object": "embedding", "embedding": norm_emb.tolist(), "index": i}) # TODO: Fill in correct usage info return EmbeddingsResponse( - data=data, - usage=UsageInfo( - prompt_tokens=0, - completion_tokens=0, - total_tokens=0 - ) + data=data, usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0) ) @@ -340,6 +351,14 @@ async def read_stats(): return session["chat_mod"].stats() +@app.get("/verbose_stats") +async def read_stats_verbose(): + """ + Get the verbose runtime stats. + """ + return session["chat_mod"].stats(verbose=True) + + ARGS = convert_args_to_argparser().parse_args() if __name__ == "__main__": uvicorn.run("mlc_chat.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) diff --git a/python/mlc_chat/support/__init__.py b/python/mlc_chat/support/__init__.py new file mode 100644 index 0000000000..ca5d7a6b5b --- /dev/null +++ b/python/mlc_chat/support/__init__.py @@ -0,0 +1,4 @@ +""" +Common utilities used in the Python package. Do not import anything by default, +as they may introduce unnecessary dependencies. +""" diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py new file mode 100644 index 0000000000..61a84b4041 --- /dev/null +++ b/python/mlc_chat/support/auto_config.py @@ -0,0 +1,101 @@ +"""Help function for detecting the model configuration file `config.json`""" +import json +import logging +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, Union + +from .style import green + +if TYPE_CHECKING: + from mlc_chat.compiler import Model # pylint: disable=unused-import + +logger = logging.getLogger(__name__) + +FOUND = green("Found") + + +def detect_config(config: Union[str, Path]) -> Path: + """Detect and return the path that points to config.json. If `config` is a directory, + it looks for config.json below it. + + Parameters + --------- + config : Union[str, pathlib.Path] + The preset name of the model, or the path to `config.json`, or the directory containing + `config.json`. 
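As an illustration of the two call styles (the preset name and local path are placeholders; real preset names come from `MODEL_PRESETS`):

    from pathlib import Path
    from mlc_chat.support.auto_config import detect_config

    # From a preset name registered in MODEL_PRESETS (dumped to a temporary config.json):
    cfg_path = detect_config("llama2_7b")
    # From a local directory that contains config.json:
    cfg_path = detect_config(Path("dist/models/Llama-2-7b-chat-hf"))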
+ + Returns + ------- + config_json_path : pathlib.Path + The path points to config.json. + """ + from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel + MODEL_PRESETS, + ) + + if isinstance(config, str) and config in MODEL_PRESETS: + logger.info("%s preset model: %s", FOUND, config) + content = MODEL_PRESETS[config] + temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with + suffix=".json", + delete=False, + ) + logger.info("Dumping config to: %s", temp_file.name) + config_path = Path(temp_file.name) + with config_path.open("w", encoding="utf-8") as config_file: + json.dump(content, config_file, indent=2) + else: + config_path = Path(config) + if not config_path.exists(): + raise ValueError(f"{config_path} does not exist.") + + if config_path.is_dir(): + # search config.json under config path + config_json_path = config_path / "config.json" + if not config_json_path.exists(): + raise ValueError(f"Fail to find config.json under {config_path}.") + else: + config_json_path = config_path + + logger.info("%s model configuration: %s", FOUND, config_json_path) + return config_json_path + + +def detect_model_type(model_type: str, config: Path) -> "Model": + """Detect the model type from the configuration file. If `model_type` is "auto", it will be + inferred from the configuration file. Otherwise, it will be used as the model type, and sanity + check will be performed. + + Parameters + ---------- + model_type : str + The model type, for example, "llama". + + config : pathlib.Path + The path to config.json. + + Returns + ------- + model : mlc_chat.compiler.Model + The model type. + """ + + from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel + MODELS, + Model, + ) + + if model_type == "auto": + with open(config, "r", encoding="utf-8") as config_file: + cfg = json.load(config_file) + if "model_type" not in cfg: + raise ValueError( + f"'model_type' not found in: {config}. " + f"Please explicitly specify `--model-type` instead" + ) + model_type = cfg["model_type"] + logger.info("%s Model type: %s", FOUND, model_type) + if model_type not in MODELS: + raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}") + return MODELS[model_type] diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py new file mode 100644 index 0000000000..f31e813410 --- /dev/null +++ b/python/mlc_chat/support/auto_target.py @@ -0,0 +1,308 @@ +"""Helper functioms for target auto-detection.""" +import logging +from typing import TYPE_CHECKING, Callable, Optional, Tuple + +from tvm import IRModule, relax +from tvm._ffi import register_func +from tvm.contrib import tar, xcode +from tvm.target import Target + +from .style import green, red + +if TYPE_CHECKING: + from mlc_chat.compiler.compile import CompileArgs + + +logger = logging.getLogger(__name__) + +# TODO: add help message on how to specify the target manually # pylint: disable=fixme +# TODO: include host detection logic below after the new TVM build is done. # pylint: disable=fixme +HELP_MSG = """TBD""" +FOUND = green("Found") +NOT_FOUND = red("Not found") +BuildFunc = Callable[[IRModule, "CompileArgs"], None] + + +def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, BuildFunc]: + """Detect the configuration for the target device and its host, for example, target GPU and + the host CPU. + + Parameters + ---------- + target_hint : str + The hint for the target device. + + host_hint : str + The hint for the host CPU. 
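A typical call looks like the sketch below; "auto" triggers the detection logic that follows, while a preset hint such as "metal:x86-64" picks a predefined target:

    from mlc_chat.support.auto_target import detect_target_and_host

    # Let both the device and its host CPU be auto-detected:
    target, build_func = detect_target_and_host("auto", "auto")
    # Or pin a preset device with an explicit host architecture:
    target, build_func = detect_target_and_host("metal:x86-64", "x86-64")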
+ """ + target, build_func = _detect_target_gpu(target_hint) + if target.host is None: + target = Target(target, host=_detect_target_host(host_hint)) + return target, build_func + + +def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]: + if hint in ["iphone", "android", "webgpu", "mali", "opencl"]: + hint += ":generic" + if hint == "auto": + logger.info("Detecting potential target devices: %s", ", ".join(AUTO_DETECT_DEVICES)) + target: Optional[Target] = None + for device in AUTO_DETECT_DEVICES: + device_target = _detect_target_from_device(device + ":0") + if device_target is not None and target is None: + target = device_target + if target is None: + raise ValueError("No GPU target detected. Please specify explicitly") + return target, _build_default() + if hint in AUTO_DETECT_DEVICES: + target = _detect_target_from_device(hint + ":0") + if target is None: + raise ValueError(f"No GPU target detected from device: {hint}") + return target, _build_default() + if hint in PRESET: + preset = PRESET[hint] + target = Target(preset["target"]) # type: ignore[index] + build = preset.get("build", _build_default) # type: ignore[attr-defined] + return target, build() + if _is_device(hint): + logger.info("Detecting target device: %s", hint) + target = Target.from_device(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + try: + logger.info("Try creating device target from string: %s", hint) + target = Target(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + except Exception as err: + logger.info("%s: Failed to create target", NOT_FOUND) + raise ValueError(f"Invalid target: {hint}") from err + + +def _detect_target_host(hint: str) -> Target: + """Detect the host CPU architecture.""" + # cpu = codegen.llvm_get_system_cpu() + # triple = codegen.llvm_get_system_triple() + # vendor = codegen.llvm_get_system_x86_vendor() + if hint == "auto": + hint = "x86-64" + if hint == "x86-64": + hint = "x86_64" + return Target({"kind": "llvm", "mtriple": f"{hint}-unknown-unknown"}) + + +def _is_device(device: str): + if " " in device: + return False + if device.count(":") != 1: + return False + return True + + +def _add_prefix_symbol(mod: IRModule, prefix: str, is_system_lib: bool) -> IRModule: + if is_system_lib and prefix: + mod = mod.with_attr("system_lib_prefix", prefix) + elif is_system_lib: + logger.warning("--prefix-symbols is not specified when building a static library") + elif prefix: + logger.warning( + "--prefix-symbols is specified, but it will not take any effect " + "when building the shared library" + ) + return mod + + +def _detect_target_from_device(device: str) -> Optional[Target]: + try: + target = Target.from_device(device) + except ValueError: + logger.info("%s: target device: %s", NOT_FOUND, device) + return None + logger.info( + '%s configuration of target device "%s": %s', + FOUND, + device, + target.export(), + ) + return target + + +def _build_metal_x86_64(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=False) + assert output.suffix == ".dylib" + relax.build( + mod, + target=args.target, + ).export_library( + str(output), + fcompile=xcode.create_dylib, + sdk="macosx", + arch="x86_64", + ) + + return build + + +def _build_iphone(): + @register_func("tvm_callback_metal_compile", override=True) + def compile_metal(src, target): + if target.libs: + return xcode.compile_metal(src, sdk=target.libs[0]) + 
return xcode.compile_metal(src) + + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=True) + assert output.suffix == ".tar" + relax.build( + mod, + target=args.target, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_android(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=True) + assert output.suffix == ".tar" + relax.build( + mod, + target=args.target, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_webgpu(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=True) + assert output.suffix == ".wasm" + relax.build( + mod, + target=args.target, + system_lib=True, + ).export_library( + str(output), + ) + + return build + + +def _build_default(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + if output.suffix in [".tar", ".lib"]: + system_lib = True + elif output.suffix in [".so", ".dylib", ".dll"]: + system_lib = False + else: + logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix) + system_lib = False + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=system_lib) + relax.build( + mod, + target=args.target, + system_lib=system_lib, + ).export_library( + str(output), + ) + + return build + + +AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"] + +PRESET = { + "iphone:generic": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "libs": ["iphoneos"], + "host": { + "kind": "llvm", + "mtriple": "arm64-apple-darwin", + }, + }, + "build": _build_iphone, + }, + "android:generic": { + "target": { + "kind": "opencl", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android, + }, + "metal:x86-64": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + }, + "build": _build_metal_x86_64, + }, + "webgpu:generic": { + "target": { + "kind": "webgpu", + "host": { + "kind": "llvm", + "mtriple": "wasm32-unknown-unknown-wasm", + }, + }, + "build": _build_webgpu, + }, + "opencl:generic": { + "target": { + "kind": "opencl", + }, + }, + "mali:generic": { + "target": { + "kind": "opencl", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-gnu", + }, + }, + }, + "metal:generic": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + }, + }, + "vulkan:generic": { + "target": { + "kind": "vulkan", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "supports_float16": 1, + "supports_int16": 1, + "supports_int8": 1, + "supports_8bit_buffer": 1, + "supports_16bit_buffer": 1, + "supports_storage_buffer_storage_class": 1, + }, + }, +} diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py new file mode 100644 index 0000000000..042e7b5366 --- /dev/null +++ b/python/mlc_chat/support/auto_weight.py @@ -0,0 +1,132 @@ +"""Help functions for detecting weight paths and weight formats.""" +import json +import logging +from pathlib import Path +from typing import Tuple + +from .style import 
green, red + +logger = logging.getLogger(__name__) + +FOUND = green("Found") +NOT_FOUND = red("Not found") + + +def detect_weight( + weight_path: Path, + config_json_path: Path, + weight_format: str = "auto", +) -> Tuple[Path, str]: + """Detect the weight directory, and detect the weight format. + + Parameters + --------- + weight_path : pathlib.Path + The path to weight files. If `weight_path` is not None, check if it exists. Otherwise, find + `weight_path` in `config.json` or use the same directory as `config.json`. + + config_json_path: pathlib.Path + The path to `config.json`. + + weight_format : str + The hint for the weight format. If it is "auto", guess the weight format. + Otherwise, check the weights are in that format. + Available weight formats: + - auto (guess the weight format) + - PyTorch (validate via checking pytorch_model.bin.index.json) + - SafeTensor (validate via checking model.safetensors.index.json) + - AWQ + - GGML/GGUF + + Returns + ------- + weight_path : pathlib.Path + The path that points to the weights. + + weight_format : str + The valid weight format. + """ + if weight_path is None: + assert ( + config_json_path is not None and config_json_path.exists() + ), "Please provide config.json path." + + # 1. Find the weight_path in config.json + with open(config_json_path, encoding="utf-8") as i_f: + config = json.load(i_f) + if "weight_path" in config: + weight_path = Path(config["weight_path"]) + logger.info('Found "weight_path" in config.json: %s', weight_path) + if not weight_path.exists(): + raise ValueError(f"weight_path doesn't exist: {weight_path}") + else: + # 2. Find the weights file in the same directory as config.json + weight_path = config_json_path.parent + else: + if not weight_path.exists(): + raise ValueError(f"weight_path doesn't exist: {weight_path}") + + logger.info("%s weights from directory: %s", FOUND, weight_path) + + # check weight format + # weight_format = "auto", guess the weight format. + # otherwise, check the weight format is valid. + if weight_format == "auto": + weight_format = _guess_weight_format(weight_path) + + if weight_format not in AVAILABLE_WEIGHT_FORMAT: + raise ValueError( + f"Available weight format list: {AVAILABLE_WEIGHT_FORMAT}, but got {weight_format}" + ) + if weight_format in CHECK_FORMAT_METHODS: + check_func = CHECK_FORMAT_METHODS[weight_format] + if not check_func(weight_path): + raise ValueError(f"The weight is not in {weight_format} format.") + return weight_path, weight_format + + +def _guess_weight_format(weight_path: Path): + possible_formats = [] + for weight_format, check_func in CHECK_FORMAT_METHODS.items(): + if check_func(weight_path): + possible_formats.append(weight_format) + + if len(possible_formats) == 0: + raise ValueError( + "Fail to detect weight format. Use `--weight-format` to manually specify the format." + ) + + selected_format = possible_formats[0] + logger.info( + "Using %s format now. 
Use `--weight-format` to manually specify the format.", + selected_format, + ) + return selected_format + + +def _check_pytorch(weight_path: Path): + pytorch_json_path = weight_path / "pytorch_model.bin.index.json" + result = pytorch_json_path.exists() + if result: + logger.info("%s Huggingface PyTorch: %s", FOUND, pytorch_json_path) + else: + logger.info("%s Huggingface PyTorch", NOT_FOUND) + return result + + +def _check_safetensor(weight_path: Path): + safetensor_json_path = weight_path / "model.safetensors.index.json" + result = safetensor_json_path.exists() + if result: + logger.info("%s SafeTensor: %s", FOUND, safetensor_json_path) + else: + logger.info("%s SafeTensor", NOT_FOUND) + return result + + +CHECK_FORMAT_METHODS = { + "PyTorch": _check_pytorch, + "SafeTensor": _check_safetensor, +} + +AVAILABLE_WEIGHT_FORMAT = ["PyTorch", "SafeTensor", "GGML", "GGUF", "AWQ"] diff --git a/mlc_llm/models/model_config_base.py b/python/mlc_chat/support/config.py similarity index 62% rename from mlc_llm/models/model_config_base.py rename to python/mlc_chat/support/config.py index 85ac46dfc2..9e42b815bc 100644 --- a/mlc_llm/models/model_config_base.py +++ b/python/mlc_chat/support/config.py @@ -1,18 +1,24 @@ """ -Utilities that handle model configuration. Model configuration is usually a JSON file in HuggingFace -that contains the model's hyperparameters. For instance, Vicuna-13b-v1.5-16k contains the following -config file: https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json +A common base class for configuration. A configuration could be initialized from its constructor, +a JSON string or a JSON file, and irrelevant fields during initialization are automatically moved +to the `kwargs` field. + +Take model configuration as an example: it is usually a JSON file in HuggingFace that contains +the model's hyperparameters. For instance, Vicuna-13b-v1.5-16k contains the following +[JSON file](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json). +The base class allows us to load the configuration from this JSON file, moving irrelevant fields +into `kwargs`, such as `transformers_version` and `use_cache`. """ import dataclasses import json from pathlib import Path from typing import Any, Dict, Type, TypeVar -ConfigClass = TypeVar("ConfigClass", bound="ModelConfig") +ConfigClass = TypeVar("ConfigClass", bound="ConfigBase") -class ModelConfig: - """Base class for model configurations, providing a common interface for loading configs from a +class ConfigBase: + """Base class for configurations, providing a common interface for loading configs from a JSON file or a dict. It requires the subclasses to be dataclasses, and has an `kwargs` field that stores the extra fields that are not defined in the dataclass. """ @@ -31,10 +37,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass: cfg : ConfigClass An instance of the config object. 
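To make the behavior concrete, here is a small sketch of a `ConfigBase` subclass; the config class and field values are made up for illustration, and unknown keys such as `transformers_version` end up in `kwargs`:

    import dataclasses
    from typing import Any, Dict

    from mlc_chat.support.config import ConfigBase

    @dataclasses.dataclass
    class TinyModelConfig(ConfigBase):  # hypothetical config for illustration
        hidden_size: int = 0
        num_hidden_layers: int = 0
        kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    cfg = TinyModelConfig.from_dict(
        {"hidden_size": 4096, "num_hidden_layers": 32, "transformers_version": "4.33.1"}
    )
    assert cfg.hidden_size == 4096
    assert cfg.kwargs == {"transformers_version": "4.33.1"}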
""" - field_names = [field.name for field in dataclasses.fields(cls)] + field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type] fields = {k: v for k, v in source.items() if k in field_names} kwargs = {k: v for k, v in source.items() if k not in field_names} - return cls(**fields, kwargs=kwargs) + return cls(**fields, kwargs=kwargs) # type: ignore[call-arg] @classmethod def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass: @@ -55,3 +61,6 @@ def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass: """ with source.open("r", encoding="utf-8") as in_file: return cls.from_dict(json.load(in_file)) + + +__all__ = ["ConfigBase"] diff --git a/python/mlc_chat/support/style.py b/python/mlc_chat/support/style.py new file mode 100644 index 0000000000..5b2272e1a0 --- /dev/null +++ b/python/mlc_chat/support/style.py @@ -0,0 +1,62 @@ +"""Printing styles.""" + +from enum import Enum + + +class Styles(Enum): + """Predefined set of styles to be used. + + Reference: + - https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit + - https://stackoverflow.com/a/17303428 + """ + + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + PURPLE = "\033[95m" + CYAN = "\033[96m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + + +def red(text: str) -> str: + """Return red text.""" + return f"{Styles.RED.value}{text}{Styles.END.value}" + + +def green(text: str) -> str: + """Return green text.""" + return f"{Styles.GREEN.value}{text}{Styles.END.value}" + + +def yellow(text: str) -> str: + """Return yellow text.""" + return f"{Styles.YELLOW.value}{text}{Styles.END.value}" + + +def blue(text: str) -> str: + """Return blue text.""" + return f"{Styles.BLUE.value}{text}{Styles.END.value}" + + +def purple(text: str) -> str: + """Return purple text.""" + return f"{Styles.PURPLE.value}{text}{Styles.END.value}" + + +def cyan(text: str) -> str: + """Return cyan text.""" + return f"{Styles.CYAN.value}{text}{Styles.END.value}" + + +def bold(text: str) -> str: + """Return bold text.""" + return f"{Styles.BOLD.value}{text}{Styles.END.value}" + + +def underline(text: str) -> str: + """Return underlined text.""" + return f"{Styles.UNDERLINE.value}{text}{Styles.END.value}" diff --git a/python/mlc_chat/support/tqdm.py b/python/mlc_chat/support/tqdm.py new file mode 100644 index 0000000000..9adceca480 --- /dev/null +++ b/python/mlc_chat/support/tqdm.py @@ -0,0 +1,38 @@ +"""Utils to better use tqdm""" +import contextlib +import inspect +import io + +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm as _redirect_logging + + +@contextlib.contextmanager +def _redirect_print(): + old_print = print + + def new_print(*args, **kwargs): + with io.StringIO() as output: + kwargs["file"] = output + kwargs["end"] = "" + old_print(*args, **kwargs) + content = output.getvalue() + tqdm.write(content) + + try: + inspect.builtins.print = new_print + yield + finally: + inspect.builtins.print = old_print + + +@contextlib.contextmanager +def redirect(): + """Redirect tqdm output to logging and print.""" + + with _redirect_logging(): + with _redirect_print(): + yield + + +__all__ = ["tqdm", "redirect"] diff --git a/python/setup.py b/python/setup.py index aa7394f1d2..af471c19c0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -2,7 +2,6 @@ """Setup MLC LLM package.""" import os import shutil -import sys from setuptools import find_packages, setup from setuptools.dist import Distribution @@ -16,7 +15,8 @@ def get_lib_path(): # 
Directly exec libinfo to get the right setup libinfo_py = os.path.join(CURRENT_DIR, "./mlc_chat/libinfo.py") libinfo = {"__file__": libinfo_py} - exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo) + with open(libinfo_py, "rb") as f: + exec(compile(f.read(), libinfo_py, "exec"), libinfo, libinfo) version = libinfo["__version__"] # conda installs libraries into env instead of packaging with pip @@ -35,10 +35,11 @@ def git_describe_version(original_version): """Get git describe version.""" ver_py = os.path.join(CURRENT_DIR, "..", "version.py") libver = {"__file__": ver_py} - exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) + with open(ver_py, "rb") as f: + exec(compile(f.read(), ver_py, "exec"), libver, libver) _, gd_version = libver["git_describe_version"]() if gd_version is not None and gd_version != original_version: - print("Use git describe based version %s" % gd_version) + print(f"Use git describe based version {gd_version}") return gd_version @@ -47,60 +48,66 @@ def git_describe_version(original_version): class BinaryDistribution(Distribution): + """This class is needed in order to create OS specific wheels.""" + def has_ext_modules(self): + """Return True for binary distribution.""" return True def is_pure(self): + """Return False for binary distribution.""" return False -setup_kwargs = {} -if not CONDA_BUILD: - with open("MANIFEST.in", "w") as fo: - for path in LIB_LIST: +def main(): + """The main entrypoint.""" + setup_kwargs = {} + if not CONDA_BUILD: + with open("MANIFEST.in", "w", encoding="utf-8") as fo: + for path in LIB_LIST: + if os.path.isfile(path): + shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_chat")) + _, libname = os.path.split(path) + fo.write(f"include mlc_chat/{libname}\n") + setup_kwargs = {"include_package_data": True} + + setup( + name="mlc_chat", + version=__version__, + description="MLC Chat: an universal runtime running LLMs", + url="https://llm.mlc.ai/", + author="MLC LLM Contributors", + license="Apache 2.0", + # See https://pypi.org/classifiers/ + classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + ], + keywords="machine learning", + zip_safe=False, + packages=find_packages(), + package_dir={"mlc_chat": "mlc_chat"}, + install_requires=["fastapi", "uvicorn", "shortuuid"], + distclass=BinaryDistribution, + **setup_kwargs, + ) + + def _remove_path(path): + if os.path.exists(path): if os.path.isfile(path): - shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_chat")) - _, libname = os.path.split(path) - fo.write(f"include mlc_chat/{libname}\n") - setup_kwargs = {"include_package_data": True} - - -setup( - name="mlc_chat", - version=__version__, - description="MLC Chat: an universal runtime running LLMs", - url="https://llm.mlc.ai/", - author="MLC LLM Contributors", - license="Apache 2.0", - # See https://pypi.org/classifiers/ - classifiers=[ - "License :: OSI Approved :: Apache Software License", - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - ], - keywords="machine learning", - zip_safe=False, - packages=find_packages(), - package_dir={"mlc_chat": "mlc_chat"}, - install_requires=["fastapi", "uvicorn", "shortuuid"], - distclass=BinaryDistribution, - **setup_kwargs, -) - - -def _remove_path(path): - if os.path.exists(path): - if 
os.path.isfile(path):
-            os.remove(path)
-        elif os.path.isdir(path):
-            shutil.rmtree(path)
-
-
-if not CONDA_BUILD:
-    # Wheel cleanup
-    os.remove("MANIFEST.in")
-    for path in LIB_LIST:
-        _, libname = os.path.split(path)
-        _remove_path(f"mlc_chat/{libname}")
+                os.remove(path)
+            elif os.path.isdir(path):
+                shutil.rmtree(path)
+
+    if not CONDA_BUILD:
+        # Wheel cleanup
+        os.remove("MANIFEST.in")
+        for path in LIB_LIST:
+            _, libname = os.path.split(path)
+            _remove_path(f"mlc_chat/{libname}")
+
+
+main()
diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py
index 31673df90d..f0b74fa5c2 100644
--- a/serve/mlc_serve/model/paged_cache_model.py
+++ b/serve/mlc_serve/model/paged_cache_model.py
@@ -327,7 +327,7 @@ def copy_to_worker_0(sess: di.Session, host_array):
 
 
 def get_tvm_model(artifact_path, model, quantization, num_shards, dev):
-    model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}-batched")
+    model_artifact_path = os.path.join(artifact_path, f"{model}-{quantization}")
     lib_path = os.path.join(model_artifact_path, f"{model}-{quantization}-cuda.so")
 
     if num_shards == 1:
diff --git a/site/img/multi-gpu/figure-1.svg b/site/img/multi-gpu/figure-1.svg
new file mode 100644
index 0000000000..d3083cf775
--- /dev/null
+++ b/site/img/multi-gpu/figure-1.svg
@@ -0,0 +1,247 @@
diff --git a/site/img/multi-gpu/figure-2.svg b/site/img/multi-gpu/figure-2.svg
new file mode 100644
index 0000000000..70d35f5037
--- /dev/null
+++ b/site/img/multi-gpu/figure-2.svg
@@ -0,0 +1,418 @@
diff --git a/site/img/multi-gpu/figure-3.svg b/site/img/multi-gpu/figure-3.svg
new file mode 100644
index 0000000000..078231fae6
--- /dev/null
+++ b/site/img/multi-gpu/figure-3.svg
@@ -0,0 +1,167 @@
diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc
index 214736320d..98d01a58ba 100644
--- a/tests/cpp/conv_unittest.cc
+++ b/tests/cpp/conv_unittest.cc
@@ -24,6 +24,4 @@ TEST(ConversationTest, ConversationJSONRoundTripTest) {
   _TestConversationJSONRoundTrip("LM");
 }
 
-TEST(ConversationTest, ConversationPartialUpdateTest) {
-  _TestConversationPartialUpdate();
-}
+TEST(ConversationTest, ConversationPartialUpdateTest) { _TestConversationPartialUpdate(); }
diff --git a/tests/debug/compare_lib.py b/tests/legacy-python/compare_lib.py
similarity index 93%
rename from tests/debug/compare_lib.py
rename to tests/legacy-python/compare_lib.py
index 9c2e35f014..5bcea1e699 100644
--- a/tests/debug/compare_lib.py
+++ b/tests/legacy-python/compare_lib.py
@@ -1,17 +1,14 @@
-from typing import List
-
 import argparse
-import os
 import json
+import os
+from typing import List
 
-import tvm
-from tvm import relax
-from tvm import rpc
-from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
 import numpy as np
-
 import torch
+import tvm
 from transformers import AutoTokenizer, LlamaTokenizer
+from tvm import relax, rpc
+from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
 
 from mlc_llm import utils
 
@@ -53,7 +50,7 @@ def compare(
 
         if self.time_eval and name not in self.time_eval_results:
             res = self.mod.time_evaluator(
-                name, self.device, number=20, repeat=3#, cache_flush_bytes=256 * 10**6
+                name, self.device, number=20, repeat=3  # ,
cache_flush_bytes=256 * 10**6 )(*new_args) self.time_eval_results[name] = (res.mean, 1) print(f"Time-eval result {name} on {self.device}: {res}") @@ -121,9 +118,7 @@ def __init__(self, args): ) ) self.cmp_device = tvm.device(args.cmp_device) - self.const_params_dict = utils.load_params( - args.artifact_path, self.primary_device - ) + self.const_params_dict = utils.load_params(args.artifact_path, self.primary_device) self.cmp_instrument = LibCompare( self.lib, self.cmp_device, @@ -134,9 +129,7 @@ def __init__(self, args): def deploy_to_pipeline(args) -> None: - with open( - os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r" - ) as f: + with open(os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r") as f: config = json.load(f) primary_device = tvm.device(args.primary_device) @@ -157,18 +150,14 @@ def deploy_to_pipeline(args) -> None: tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(), primary_device, ) - first_sampled_token = tvm.nd.array( - np.array([[6234]]).astype("int32"), primary_device - ) + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) kv_caches = state.vm["create_kv_cache"]() print("Running inference...") print("======================= Starts Encoding =======================") - logits, kv_caches = state.vm["prefill"]( - inputs, seq_len_shape, kv_caches, const_params - ) + logits, kv_caches = state.vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) print_as_table( sorted( state.cmp_instrument.time_eval_results.items(), diff --git a/tests/debug/dump_intermediate.py b/tests/legacy-python/dump_intermediate.py similarity index 95% rename from tests/debug/dump_intermediate.py rename to tests/legacy-python/dump_intermediate.py index 84cc8c74b1..52536ad760 100644 --- a/tests/debug/dump_intermediate.py +++ b/tests/legacy-python/dump_intermediate.py @@ -1,12 +1,12 @@ import argparse import os +import pickle import numpy as np import torch import tvm from transformers import AutoTokenizer from tvm import relax -import pickle from mlc_llm import utils @@ -77,12 +77,8 @@ def deploy_to_pipeline(args) -> None: ) print("Tokenizing...") - inputs = ( - tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() - ) - first_sampled_token = tvm.nd.array( - np.array([[6234]]).astype("int32"), primary_device - ) + inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) kv_caches = state.vm["create_kv_cache"]() diff --git a/tests/evaluate.py b/tests/legacy-python/evaluate.py similarity index 96% rename from tests/evaluate.py rename to tests/legacy-python/evaluate.py index f37fdabf5f..4a370c517c 100644 --- a/tests/evaluate.py +++ b/tests/legacy-python/evaluate.py @@ -58,9 +58,7 @@ def compare( repeat=3, )(*new_args).mean shapes = [arg.shape for arg in new_args] - total_bytes = sum( - arg.numpy().size * arg.numpy().itemsize for arg in new_args - ) + total_bytes = sum(arg.numpy().size * arg.numpy().itemsize for arg in new_args) self.time_eval_results[name] = (res, 1, shapes, total_bytes) else: record = self.time_eval_results[name] @@ -177,9 +175,7 @@ def deploy_to_pipeline(args) -> None: # pylint: 
disable=too-many-locals print("Profiling...") kv_caches = vm["create_kv_cache"]() - logits, kv_caches = vm["prefill"]( - inputs, seq_len_shape, kv_caches, const_params - ) + logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) print("======================= Encoding Profiling =======================") print_as_table( sorted( diff --git a/tests/legacy-python/test_batching_llama.py b/tests/legacy-python/test_batching_llama.py new file mode 100644 index 0000000000..ff11188e4b --- /dev/null +++ b/tests/legacy-python/test_batching_llama.py @@ -0,0 +1,160 @@ +# pylint: disable=invalid-name,missing-docstring +# Used as reference + +import argparse +import json +import os + +import numpy as np +import torch +import tvm +from transformers import LlamaTokenizer # type: ignore[import] +from tvm import relax +from tvm.runtime import ShapeTuple + +from mlc_llm import utils + +############################################################## +# Test file for e2e Llama with batching enabled by directly +# calling functions in VM. +# +# NOTE: the test will not be runnable until the attention +# compute function is integrated to Llama. This is left as +# an item that we will work on shortly in the future. +############################################################## + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, default="Llama-2-7b-chat-hf-q4f16_1") + args.add_argument("--device-name", type=str, default="auto") + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--prompt", type=str, default="What's the meaning of life?") + args.add_argument("--profile", action="store_true", default=False) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def sample_from_logits(vm, logits, device): + temperature = 0.7 + top_p = 0.95 + + num_sequence = logits.shape[0] + temperature_arr = tvm.nd.array(np.full((num_sequence,), temperature, dtype="float32"), device) + probs = vm["softmax_with_temperature"](logits, temperature_arr).numpy() + + sampled_tokens = [] + fsample_top_p_from_prob = tvm.get_global_func("vm.builtin.sample_top_p_from_prob") + for seq_id in range(num_sequence): + token = fsample_top_p_from_prob(tvm.nd.array(probs[seq_id]), top_p, np.random.sample()) + sampled_tokens.append(token) + return sampled_tokens + + +def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals + device = tvm.device(args.device_name) + const_params = utils.load_params(args.artifact_path, device) + ex = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.device_name}.so", + ) + ) + vm = relax.VirtualMachine(ex, device) + + with open( + os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), + "r", + encoding="utf-8", + ) as f: + config = json.load(f) + + assert config["model_category"] == "llama" + tokenizer = LlamaTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + + num_sequences = 4 + generated_tokens = [[], [], [], []] + prompts = [ + "What's the meaning of life?", + "Introduce the history of Pittsburgh to me.", + "Write a three-day Seattle travel plan.", + "What is Alaska famous of?", + ] + num_decode_steps = 256 + + print("Create KV cache...") + max_total_seq_len = 
16384 + page_size = 16 + kv_cache = vm["create_kv_cache"](ShapeTuple([num_sequences, max_total_seq_len, page_size])) + + fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence") + freset_append_length = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reset_append_lengths" + ) + freserve = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reserve_extra_length_for_append" + ) + fsync = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device") + + for seq_id in range(num_sequences): + print(f"Process seq {seq_id} for prefill...") + inputs = tvm.nd.array( + tokenizer(prompts[seq_id], return_tensors="pt").input_ids.to(torch.int32).numpy(), + device, + ) + seq_length = inputs.shape[1] + embedding = vm["embed"](inputs, const_params) + + seq_id_in_cache = fadd_sequence(kv_cache) + assert seq_id_in_cache == seq_id + + freset_append_length(kv_cache) + freserve(kv_cache, seq_id, seq_length) + fsync(kv_cache) + + print(f"Prefilling seq {seq_id}...") + logits, _ = vm["prefill_with_embed"](embedding, kv_cache, const_params) + + tokens = sample_from_logits(vm, logits, device) + assert len(tokens) == 1 + generated_tokens[seq_id].append(tokens[0]) + + print("Decoding...") + for step in range(num_decode_steps): + inputs = tvm.nd.array( + np.array( + [[generated_tokens[seq_id][-1]] for seq_id in range(num_sequences)], dtype="int32" + ), + device, + ) + embedding = vm["embed"](inputs, const_params) + freset_append_length(kv_cache) + for seq_id in range(num_sequences): + freserve(kv_cache, seq_id, 1) + fsync(kv_cache) + + logits, _ = vm["decode_with_embed"](embedding, kv_cache, const_params) + tokens = sample_from_logits(vm, logits, device) + assert len(tokens) == num_sequences + + for seq_id in range(num_sequences): + generated_tokens[seq_id].append(tokens[seq_id]) + + for seq_id in range(num_sequences): + output = tokenizer.decode(generated_tokens[seq_id]) + print("====================================================================") + print(f"Prompt {seq_id}: {prompts[seq_id]}") + print(f"Output: {output}") + print("\n\n") + + +if __name__ == "__main__": + ARGS = _parse_args() + deploy_to_pipeline(ARGS) diff --git a/tests/python/test_build_args.py b/tests/legacy-python/test_build_args.py similarity index 89% rename from tests/python/test_build_args.py rename to tests/legacy-python/test_build_args.py index 3805b29199..8f32d123b6 100644 --- a/tests/python/test_build_args.py +++ b/tests/legacy-python/test_build_args.py @@ -3,11 +3,12 @@ import dataclasses import unittest -from mlc_llm import BuildArgs, utils, core +from mlc_llm import BuildArgs, core, utils + def old_make_args(): """The exact old way of creating `ArgumentParser`, used to test whether - `BuildArgs` is equivalent to this. """ + `BuildArgs` is equivalent to this.""" args = argparse.ArgumentParser() args.add_argument( "--model", @@ -17,7 +18,7 @@ def old_make_args(): 'The name of the model to build. If it is "auto", we will ' 'automatically set the model name according to "--model-path", ' '"hf-path" or the model folders under "--artifact-path/models"' - ) + ), ) args.add_argument( "--hf-path", @@ -30,19 +31,16 @@ def old_make_args(): type=str, choices=[*utils.quantization_schemes.keys()], default=list(utils.quantization_schemes.keys())[0], - help="The quantization mode we use to compile." + help="The quantization mode we use to compile.", ) args.add_argument( "--max-seq-len", type=int, default=-1, - help="The maximum allowed sequence length for the model." 
+ help="The maximum allowed sequence length for the model.", ) args.add_argument( - "--target", - type=str, - default="auto", - help="The target platform to compile the model for." + "--target", type=str, default="auto", help="The target platform to compile the model for." ) args.add_argument( "--reuse-lib", @@ -51,10 +49,7 @@ def old_make_args(): help="Whether to reuse a previously generated lib.", ) args.add_argument( - "--artifact-path", - type=str, - default="dist", - help="Where to store the output." + "--artifact-path", type=str, default="dist", help="Where to store the output." ) args.add_argument( "--use-cache", @@ -66,13 +61,13 @@ def old_make_args(): "--debug-dump", action="store_true", default=False, - help="Whether to dump debugging files during compilation." + help="Whether to dump debugging files during compilation.", ) args.add_argument( "--debug-load-script", action="store_true", default=False, - help="Whether to load the script for debugging." + help="Whether to load the script for debugging.", ) args.add_argument( "--llvm-mingw", @@ -81,10 +76,7 @@ def old_make_args(): help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows.", ) args.add_argument( - "--system-lib", - action="store_true", - default=False, - help="A parameter to `relax.build`." + "--system-lib", action="store_true", default=False, help="A parameter to `relax.build`." ) args.add_argument( "--sep-embed", @@ -99,17 +91,20 @@ def old_make_args(): return args + # Referred to HfArgumentParserTest from https://github.com/huggingface/ # transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils # /test_hf_argparser.py#L143 class BuildArgsTest(unittest.TestCase): """Tests whether BuildArgs reaches parity with regular ArgumentParser.""" - def argparsers_equal(self, parse_a: argparse.ArgumentParser, - parse_b: argparse.ArgumentParser): + + def argparsers_equal(self, parse_a: argparse.ArgumentParser, parse_b: argparse.ArgumentParser): """ Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances. 
""" - self.assertEqual(len(parse_a._actions), len(parse_b._actions)) # pylint: disable=protected-access + self.assertEqual( + len(parse_a._actions), len(parse_b._actions) + ) # pylint: disable=protected-access for x, y in zip(parse_a._actions, parse_b._actions): # pylint: disable=protected-access xx = {k: v for k, v in vars(x).items() if k != "container"} yy = {k: v for k, v in vars(y).items() if k != "container"} @@ -175,5 +170,6 @@ def test_namespaces_are_equivalent_str_boolean_int(self): build_args_namespace = argparse.Namespace(**build_args_as_dict) self.assertNotEqual(build_args_namespace, parsed_args) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/python/test_build_model_from_args.py b/tests/legacy-python/test_build_model_from_args.py similarity index 64% rename from tests/python/test_build_model_from_args.py rename to tests/legacy-python/test_build_model_from_args.py index a5ce550e9f..c7990d63df 100644 --- a/tests/python/test_build_model_from_args.py +++ b/tests/legacy-python/test_build_model_from_args.py @@ -1,27 +1,25 @@ - import argparse import os import unittest from unittest.mock import MagicMock, mock_open, patch from mlc_llm import utils - from mlc_llm.core import build_model_from_args class MockMkdir(object): def __init__(self): self.received_args = None - + def __call__(self, *args): self.received_args = args -class BuildModelTest(unittest.TestCase): +class BuildModelTest(unittest.TestCase): def setUp(self): self._orig_mkdir = os.mkdir os.mkdir = MockMkdir() - + self.mock_args = argparse.Namespace() self.mock_args.quantization = utils.quantization_schemes["q8f16_1"] self.mock_args.debug_dump = False @@ -38,29 +36,36 @@ def setUp(self): self.mock_args.model = "/tmp/" self.mock_args.target_kind = "cuda" self.mock_args.max_seq_len = 2048 - + def tearDown(self): os.mkdir = self._orig_mkdir @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ {} ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_llama_model(self, mock_file): self.mock_args.model_category = "llama" build_model_from_args(self.mock_args) @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { - "use_parallel_residual": False, - "hidden_size": 32, - "intermediate_size": 32, - "num_attention_heads": 32, - "num_hidden_layers": 28, - "vocab_size": 1024, - "rotary_pct": 1, - "rotary_emb_base": 1, - "layer_norm_eps": 1, - } ])) + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "use_parallel_residual": False, + "hidden_size": 32, + "intermediate_size": 32, + "num_attention_heads": 32, + "num_hidden_layers": 28, + "vocab_size": 1024, + "rotary_pct": 1, + "rotary_emb_base": 1, + "layer_norm_eps": 1, + } + ] + ), + ) def test_gpt_neox_model(self, mock_file): self.mock_args.model_category = "gpt_neox" self.mock_args.model = "dolly-test" @@ -68,7 +73,7 @@ def test_gpt_neox_model(self, mock_file): build_model_from_args(self.mock_args) @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ {} ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_gpt_bigcode_model(self, mock_file): self.mock_args.model_category = "gpt_bigcode" self.mock_args.model = "gpt_bigcode" @@ -76,51 +81,62 @@ def test_gpt_bigcode_model(self, mock_file): build_model_from_args(self.mock_args) @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ 
{} ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_minigpt_model(self, mock_file): self.mock_args.model_category = "minigpt" self.mock_args.model = "minigpt4-7b" build_model_from_args(self.mock_args) - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { - "vocab_size": 1024, - "n_embd": 32, - "n_inner": 32, - "n_head": 32, - "n_layer": 28, - "bos_token_id": 28, - "eos_token_id": 1, - "rotary_dim": 1, - "tie_word_embeddings": 1, - } ])) + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "vocab_size": 1024, + "n_embd": 32, + "n_inner": 32, + "n_head": 32, + "n_layer": 28, + "bos_token_id": 28, + "eos_token_id": 1, + "rotary_dim": 1, + "tie_word_embeddings": 1, + } + ] + ), + ) def test_gptj_model(self, mock_file): self.mock_args.model_category = "gptj" self.mock_args.model = "gpt-j-" build_model_from_args(self.mock_args) - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { - "num_hidden_layers": 16, - "vocab_size": 1024, - "hidden_size": 16, - "intermediate_size": 32, - } ])) + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "num_hidden_layers": 16, + "vocab_size": 1024, + "hidden_size": 16, + "intermediate_size": 32, + } + ] + ), + ) def test_rwkv_model(self, mock_file): self.mock_args.model_category = "rwkv" self.mock_args.model = "rwkv-" build_model_from_args(self.mock_args) - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { } ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_chatglm_model(self, mock_file): self.mock_args.model_category = "chatglm" self.mock_args.model = "chatglm2" - build_model_from_args(self.mock_args) \ No newline at end of file + build_model_from_args(self.mock_args) diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py new file mode 100644 index 0000000000..9e75247c32 --- /dev/null +++ b/tests/python/model/test_llama.py @@ -0,0 +1,20 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest +from mlc_chat.compiler import MODELS + + +@pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) +def test_llama2_creation(model_name: str): + model_info = MODELS["llama"] + config = model_info.config.from_predefined(model_name) + model = model_info.model(config) + mod, named_params = model.export_tvm(spec=model.get_default_spec()) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_llama2_creation("llama2_7b") + test_llama2_creation("llama2_13b") + test_llama2_creation("llama2_70b") diff --git a/tests/python/parameter/test_group_quantizer.py b/tests/python/parameter/test_group_quantizer.py new file mode 100644 index 0000000000..4c16548b64 --- /dev/null +++ b/tests/python/parameter/test_group_quantizer.py @@ -0,0 +1,157 @@ +# pylint: disable=missing-docstring,too-many-instance-attributes +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Tuple, Union + +import numpy as np +import tvm +from mlc_chat.compiler import MODELS +from mlc_chat.compiler.model.llama_config import LlamaConfig +from mlc_chat.compiler.model.llama_quantization import huggingface_group_quantize +from mlc_chat.compiler.parameter import HuggingFaceLoader +from mlc_chat.support import tqdm +from tvm.runtime import NDArray + +if TYPE_CHECKING: + from 
tvm.relax.frontend import nn + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +def test_load_torch_llama_group_quantize(base_path: Union[str, Path], target: str = "llvm"): + @dataclass + class TestGroupQuantizeConfig: + name: str = "q4f16_1" + kind: str = "group_quantize" + group_size: int = 32 + weight_dtype: str = "float16" + max_int_value: int = 7 + storage_dtype: str = "uint32" + num_elem_per_storage: int = 8 + num_storage_per_group: int = 4 + quantize_dtype_bits: int = 4 + + def quantize(self, _: "nn.Module") -> "nn.Module": + raise NotImplementedError + + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "pytorch_model.bin.index.json" + + model = MODELS["llama"] + model_config = LlamaConfig.from_file(path_config) + quantize_config = TestGroupQuantizeConfig() + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-torch"](model_config, None), + quantize_param_map=huggingface_group_quantize( + model_config, + quantize_config, + target=tvm.target.Target(target), + ), + ) + with tqdm.redirect(): + for _name, _param in loader.load(): + ... + + +def test_group_quantize_vs_numpy(): + bits = { + "int4": 4, + "int8": 8, + "fp16": 16, + "fp32": 32, + "int32": 32, + "uint32": 32, + } + + # pylint: disable=unused-variable + def group_quantize_np( + w: NDArray, # pylint: disable=invalid-name + quantize_dtype: str = "int4", + storage_dtype: str = "uint32", + group_size: int = 32, + # symmetric: bool = True, + # transpose: bool = False, + ) -> Tuple[NDArray, NDArray]: + # pylint: disable=too-many-locals + def _pad_axis_by_factor(tensor: np.ndarray, axis: int, factor: int) -> np.ndarray: + dim = int(tensor.shape[axis]) + if dim % factor == 0: + return tensor + pad_width = [[0, 0] for i in tensor.shape] + pad_width[axis][1] = factor - (dim % factor) + return np.pad(tensor, pad_width, mode="constant", constant_values=0) + + def _clip( + x: np.ndarray, # pylint: disable=invalid-name + x_min: int, + x_max: int, + dtype: str, + ) -> np.ndarray: + return np.clip(x, a_min=x_min, a_max=x_max).astype(dtype) + + num_elem_per_storage = bits[storage_dtype] // bits[quantize_dtype] + assert group_size % num_elem_per_storage == 0 + num_storage_units = (group_size + num_elem_per_storage - 1) // num_elem_per_storage + + # using numpy for now + w = w.numpy() + + # Step 1. Tile `w`: [n, k'] -> [n, k, group_size] + w = _pad_axis_by_factor(w, axis=1, factor=group_size) + n, k = [int(v) for v in w.shape] # pylint: disable=invalid-name + assert k % group_size == 0, "Padding is not working properly" + k = k // group_size + w = w.reshape([n, k, group_size]) + + # Step 2. Calculate + if quantize_dtype.startswith("int"): + max_int_value = (2 ** (bits[quantize_dtype] - 1)) - 1 + # 1) `scale`: [n, k, group_size] -> [n, k] + scale = np.maximum(np.amax(w, axis=-1), 1e-4) / max_int_value + # 2) `w`: w / scale + + w = _clip( + np.round(w / scale[:, :, np.newaxis]).astype("int") + max_int_value, + x_min=0, + x_max=max_int_value * 2, + dtype=storage_dtype, + ) + else: + raise NotImplementedError + + # Step 3. 
Compress `w` to every `num_elem_per_storage` elements + res = np.zeros((n, k, num_storage_units), dtype=np.uint32) + for i in range(n): + for j in range(k): + for m in range(num_storage_units): # pylint: disable=invalid-name + for k in range(num_elem_per_storage): + res[i, j, m] += w[i, j, m * num_elem_per_storage + k] * 2**k + return tvm.nd.array(res), tvm.nd.array(scale) + # pylint: enable=too-many-locals + + +if __name__ == "__main__": + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-7b-hf", + target="llvm", + ) + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-7b-hf", + target="nvidia/nvidia-a100", + ) + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-13b-hf", + target="llvm", + ) + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-13b-hf", + target="nvidia/nvidia-a100", + ) diff --git a/tests/python/parameter/test_huggingface.py b/tests/python/parameter/test_huggingface.py new file mode 100644 index 0000000000..ecd8e16455 --- /dev/null +++ b/tests/python/parameter/test_huggingface.py @@ -0,0 +1,76 @@ +# pylint: disable=missing-docstring +import logging +from pathlib import Path +from typing import Union + +import pytest +from mlc_chat.compiler import MODELS + +# from mlc_chat.compiler.model.llama_config import LlamaConfig +# from mlc_chat.compiler.model.llama_parameter import huggingface +from mlc_chat.compiler.parameter import HuggingFaceLoader +from mlc_chat.support import tqdm + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +@pytest.mark.parametrize( + "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_torch_llama(base_path: Union[str, Path]): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "pytorch_model.bin.index.json" + + model = MODELS["llama"] + config = model.config.from_file(path_config) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-torch"](config, None), + ) + with tqdm.redirect(): + for _name, _param in loader.load(): + return # To reduce the time of the test + + +@pytest.mark.parametrize( + "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_safetensor_llama(base_path: Union[str, Path]): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "model.safetensors.index.json" + + model = MODELS["llama"] + config = model.config.from_file(path_config) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-safetensor"](config, None), + ) + with tqdm.redirect(): + for _name, _param in loader.load(): + return # To reduce the time of the test + + +if __name__ == "__main__": + test_load_torch_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_torch_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_torch_llama(base_path="./dist/models/Llama-2-70b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-70b-hf") diff --git a/tests/python/support/test_auto_config.py b/tests/python/support/test_auto_config.py new file mode 100644 index 
0000000000..540c544c22 --- /dev/null +++ b/tests/python/support/test_auto_config.py @@ -0,0 +1,45 @@ +# pylint: disable=missing-docstring +import json +import logging +import tempfile +from pathlib import Path + +import pytest +from mlc_chat.support.auto_config import detect_config + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="{asctime} {levelname} {filename}:{lineno}: {message}", +) + + +def _create_json_file(json_path, data): + with open(json_path, "w", encoding="utf-8") as i_f: + json.dump(data, i_f) + + +def test_detect_config(): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + config_json_path = base_path / "config.json" + _create_json_file(config_json_path, {}) + + assert detect_config(base_path) == config_json_path + assert detect_config(config_json_path) == config_json_path + + +def test_detect_config_fail(): + with pytest.raises(ValueError): + detect_config(Path("do/not/exist")) + + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + with pytest.raises(ValueError): + assert detect_config(base_path) + + +if __name__ == "__main__": + test_detect_config() + test_detect_config_fail() diff --git a/tests/python/support/test_auto_weight.py b/tests/python/support/test_auto_weight.py new file mode 100644 index 0000000000..2987135267 --- /dev/null +++ b/tests/python/support/test_auto_weight.py @@ -0,0 +1,126 @@ +# pylint: disable=missing-docstring +import json +import logging +import os +import tempfile +from pathlib import Path + +import pytest +from mlc_chat.support.auto_weight import detect_weight + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="{asctime} {levelname} {filename}:{lineno}: {message}", +) + + +def _create_json_file(json_path, data): + with open(json_path, "w", encoding="utf-8") as i_f: + json.dump(data, i_f) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), + ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), + ("GGML", None, "GGML"), + ("GGUF", None, "GGUF"), + ("AWQ", None, "AWQ"), + ("auto", "pytorch_model.bin.index.json", "PyTorch"), + ("auto", "model.safetensors.index.json", "SafeTensor"), + ], +) +def test_detect_weight(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + if index_filename is not None: + weight_index_file = base_path / index_filename + _create_json_file(weight_index_file, {}) + assert detect_weight(base_path, None, weight_format) == (base_path, result) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), + ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), + ("GGML", None, "GGML"), + ("GGUF", None, "GGUF"), + ("AWQ", None, "AWQ"), + ("auto", "pytorch_model.bin.index.json", "PyTorch"), + ("auto", "model.safetensors.index.json", "SafeTensor"), + ], +) +def test_detect_weight_in_config_json(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as config_dir, tempfile.TemporaryDirectory() as weight_dir: + config_path = Path(config_dir) + weight_path = Path(weight_dir) + config_json_path = config_path / "config.json" + _create_json_file(config_json_path, {"weight_path": weight_dir}) + if index_filename is not None: + weight_index_file = weight_path / index_filename + _create_json_file(weight_index_file, {}) + + assert 
detect_weight(None, config_json_path, weight_format) == (weight_path, result) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), + ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), + ("GGML", None, "GGML"), + ("GGUF", None, "GGUF"), + ("AWQ", None, "AWQ"), + ("auto", "pytorch_model.bin.index.json", "PyTorch"), + ("auto", "model.safetensors.index.json", "SafeTensor"), + ], +) +def test_detect_weight_same_dir_config_json(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + config_json_path = base_path / "config.json" + _create_json_file(config_json_path, {}) + if index_filename is not None: + weight_index_file = os.path.join(tmpdir, index_filename) + _create_json_file(weight_index_file, {}) + assert detect_weight(None, config_json_path, weight_format) == (base_path, result) + + +def test_find_weight_fail(): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + with pytest.raises(ValueError): + detect_weight(Path("do/not/exist"), base_path, "AWQ") + with pytest.raises(AssertionError): + detect_weight(None, Path("do/not/exist"), "AWQ") + + +if __name__ == "__main__": + test_detect_weight("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight("SafeTensor", "model.safetensors.index.json", "SafeTensor") + test_detect_weight("GGML", None, "GGML") + test_detect_weight("GGUF", None, "GGUF") + test_detect_weight("AWQ", None, "AWQ") + test_detect_weight("auto", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight("auto", "model.safetensors.index.json", "SafeTensor") + test_detect_weight_in_config_json("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_in_config_json("SafeTensor", "model.safetensors.index.json", "SafeTensor") + test_detect_weight_in_config_json("GGML", None, "GGML") + test_detect_weight_in_config_json("GGUF", None, "GGUF") + test_detect_weight_in_config_json("AWQ", None, "AWQ") + test_detect_weight_in_config_json("auto", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_in_config_json("auto", "model.safetensors.index.json", "SafeTensor") + test_detect_weight_same_dir_config_json("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_same_dir_config_json( + "SafeTensor", "model.safetensors.index.json", "SafeTensor" + ) + test_detect_weight_same_dir_config_json("GGML", None, "GGML") + test_detect_weight_same_dir_config_json("GGUF", None, "GGUF") + test_detect_weight_same_dir_config_json("AWQ", None, "AWQ") + test_detect_weight_same_dir_config_json("auto", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_same_dir_config_json("auto", "model.safetensors.index.json", "SafeTensor") + test_find_weight_fail() diff --git a/tests/python/test_model_llama.py b/tests/python/test_model_llama.py deleted file mode 100644 index 3019b85bc0..0000000000 --- a/tests/python/test_model_llama.py +++ /dev/null @@ -1,70 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import numpy as np -from tvm.relax.frontend.nn import spec - -from mlc_llm.models.llama import LlamaConfig, LlamaForCasualLM - - -def main(): - config = LlamaConfig( - hidden_act="silu", - hidden_size=256, - intermediate_size=688, - max_sequence_length=128, - num_attention_heads=8, - num_hidden_layers=8, - rms_norm_eps=1e-05, - vocab_size=4096, - position_embedding_base=10000, - ) - batch_size, total_seq_len, dtype = 1, 32, "float32" - - # Usecase 1. 
Define a model and export it to TVM's IRModule - model = LlamaForCasualLM(config) - model.to(dtype=dtype) - mod_spec = { - "prefill": { - "inputs": spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, - }, - "decode": { - "inputs": spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, - }, - "softmax_with_temperature": { - "logits": spec.Tensor([1, 1, config.vocab_size], "float32"), - "temperature": spec.Tensor([], "float32"), - }, - } - mod, _ = model.export_tvm(spec=mod_spec) - mod.show(black_format=False) - - # Usecase 2. JIT compile a model - for _, param in model.state_dict().items(): - param.data = np.random.rand(*param.shape).astype(param.dtype) - model = model.jit( - spec=mod_spec, - target="llvm", - device="cpu", - out_format="torch", - ) - - # Usecase 3. Run a model with PyTorch - import torch # pylint: disable=import-outside-toplevel - - result = model["prefill"]( - torch.from_numpy( - np.random.randint( - 0, - config.vocab_size, - size=(batch_size, total_seq_len), - dtype="int32", - ) - ), - total_seq_len, - ) - assert isinstance(result, torch.Tensor) - - -if __name__ == "__main__": - main() diff --git a/tests/python/test_param_loader_llama.py b/tests/python/test_param_loader_llama.py deleted file mode 100644 index 4c34c6964b..0000000000 --- a/tests/python/test_param_loader_llama.py +++ /dev/null @@ -1,32 +0,0 @@ -# pylint: disable=missing-docstring -import logging -from pathlib import Path - -from mlc_llm.models.llama import LlamaConfig -from mlc_llm.models.llama_param_map import hf_torch -from mlc_llm.param_loader import HFTorchLoader - -logging.basicConfig( - level=logging.DEBUG, - style="{", - datefmt="%Y-%m-%d %H:%M:%S", - format="{asctime} {levelname} {filename}:{lineno}: {message}", -) - - -def test_load_7b(): - prefix = Path("./dist/models/llama-2-7b-chat-hf/") - path_config = prefix / "config.json" - path_params = prefix / "pytorch_model.bin.index.json" - - model_config = LlamaConfig.from_file(path_config) - with HFTorchLoader( - config_path=path_params, - param_map=hf_torch(model_config), - ) as loader: - for name in loader.suggest_loading_order(): - loader.load_param(name=name) - - -if __name__ == "__main__": - test_load_7b() diff --git a/tests/python/test_update_config.py b/tests/python/test_update_config.py deleted file mode 100644 index ec92e2c5d3..0000000000 --- a/tests/python/test_update_config.py +++ /dev/null @@ -1,89 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from mlc_chat.chat_module import ChatConfig, ChatModule, ConvConfig - -class UpdateConfigTest(unittest.TestCase): - - - @patch("mlc_chat.chat_module.ChatModule.__init__") - def setUp(self, mock_init): - mock_init.return_value = None - self.cm_under_test = ChatModule("test") - default_conv_config = { - "prefix_tokens": [], - "role_empty_sep": "", - "role_msg_sep": "", - "seps": [""], - "stop_tokens": [2], - "offset": 0, - "separator_style": 1, - "messages": [], - "stop_str": "<\/s>", - "roles": ["Prompt", "Code"], - "system": "", - "add_bos": True, - "name": "codellama_completion" - } - default_chat_config = { - 'model_lib': 'default_model_lib', - 'local_id': 'default_local_id', - 'conv_template': 'codellama_completion', - 'temperature': 0.7, - 'repetition_penalty': 1.0, - 'top_p': 0.95, - 'mean_gen_len': 128, - 'max_gen_len': 512, - 'shift_fill_factor': 0.3, - 'tokenizer_files': ['tokenizer.json', 'tokenizer.model'], - 'conv_config': None, - 'model_category': 'llama', - 'model_name': 'default_model_name' - } - self.cm_under_test.default_chat_config 
= default_chat_config - self.cm_under_test.default_conv_config = default_conv_config - self.cm_under_test._load_json_override_func = MagicMock() - - def test_update_config(self): - expected_value = '{"model_lib": "default_model_lib", "local_id": "default_local_id", "conv_template": "codellama_completion", "temperature": 0.5, "repetition_penalty": 1.0, "top_p": 0.95, "mean_gen_len": 128, "max_gen_len": 512, "shift_fill_factor": 0.3, "tokenizer_files": ["tokenizer.json", "tokenizer.model"], "conv_config": {"prefix_tokens": [], "role_empty_sep": "", "role_msg_sep": "", "seps": [""], "stop_tokens": [2], "offset": 0, "separator_style": 1, "messages": [], "stop_str": "}", "roles": ["Prompt", "Code"], "system": "", "add_bos": true, "name": "codellama_completion"}, "model_category": "llama", "model_name": "default_model_name"}' - - conv_config = ConvConfig( - system=None, - roles=None, - messages=None, - offset=None, - separator_style=None, - seps=None, - role_msg_sep=None, - role_empty_sep=None, - stop_str="}", - stop_tokens=None, - add_bos=None, - ) - - chat_config = ChatConfig( - temperature=0.5, - repetition_penalty=None, - top_p=None, - mean_gen_len=None, - max_gen_len=None, - conv_config=conv_config, - ) - - self.cm_under_test.update_chat_config(chat_config) - self.cm_under_test._load_json_override_func.assert_called_once_with(expected_value.replace('\n', '').replace('\t', ''), True) - - def test_update_config_none_conv_config(self): - expected_value = '{"model_lib": "default_model_lib", "local_id": "default_local_id", "conv_template": "codellama_completion", "temperature": 0.5, "repetition_penalty": 1.0, "top_p": 0.95, "mean_gen_len": 128, "max_gen_len": 512, "shift_fill_factor": 0.3, "tokenizer_files": ["tokenizer.json", "tokenizer.model"], "conv_config": null, "model_category": "llama", "model_name": "default_model_name"}' - - chat_config = ChatConfig( - temperature=0.5, - repetition_penalty=None, - top_p=None, - mean_gen_len=None, - max_gen_len=None, - ) - - self.cm_under_test.update_chat_config(chat_config) - self.cm_under_test._load_json_override_func.assert_called_once_with(expected_value.replace('\n', '').replace('\t', ''), True) - \ No newline at end of file