diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000000..478f75e8fd
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,87 @@
+name: Lint
+on: [push, pull_request]
+env:
+  IMAGE: 'mlcaidev/ci-cpu:caab922'
+
+jobs:
+  isort:
+    name: Python / isort
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: ''
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/isort.sh
+
+  black:
+    name: Python / black
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: ''
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/black.sh
+
+  mypy:
+    name: Python / mypy
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: ''
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh
+
+  pylint:
+    name: Python / pylint
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: ''
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh
+
+  clang-format:
+    name: C++ / clang-format
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: ''
+          ref: ${{ github.event.pull_request.head.sha }}
+          fetch-depth: 0
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/clang-format.sh
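The five jobs above all run inside the pinned `mlcaidev/ci-cpu:caab922` image through the `ci/bash.sh` wrapper that this patch also adds. A rough sketch of reproducing a lint job locally (assuming Docker is installed; the image tag is the one pinned in the workflow):

```bash
# Fetch the docker wrapper the same way the workflow does
# (this patch also adds the same script under ci/bash.sh).
wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
chmod u+x ./ci/bash.sh

# Print the ci-lint conda environment, then run one of the lint tasks.
./ci/bash.sh mlcaidev/ci-cpu:caab922 "conda env export --name ci-lint"
./ci/bash.sh mlcaidev/ci-cpu:caab922 bash ./ci/task/black.sh
```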
diff --git a/3rdparty/tvm b/3rdparty/tvm
index e5ca38dd73..30b4fa3c13 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit e5ca38dd735ba4d30782a4a58bf6195861642eb0
+Subproject commit 30b4fa3c13fc80d5c9151a9dc445d22c57ced3e0
diff --git a/README.md b/README.md
index bb52c1c735..f20d1c8a93 100644
--- a/README.md
+++ b/README.md
@@ -4,18 +4,18 @@
[Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url]
-Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques.
+**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language model with native APIs and compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques.
**Universal deployment.** MLC LLM supports the following platforms and hardware:
- |
- AMD GPU |
- NVIDIA GPU |
- Apple M1/M2 GPU |
- Intel GPU |
+ |
+ AMD GPU |
+ NVIDIA GPU |
+ Apple GPU |
+ Intel GPU |
@@ -28,21 +28,18 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo
macOS |
- ✅ Metal |
+ ✅ Metal (dGPU) |
N/A |
✅ Metal |
- ✅ Metal |
+ ✅ Metal (iGPU) |
Web Browser |
- ✅ WebGPU |
- ✅ WebGPU |
- ✅ WebGPU |
- ✅ WebGPU |
+ ✅ WebGPU and WASM |
iOS / iPadOS |
- ✅ Metal on Apple M1/M2 GPU |
+ ✅ Metal on Apple A-series GPU |
Android |
@@ -52,8 +49,25 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo
+
+**Scalable.** MLC LLM scales universally across NVIDIA and AMD GPUs, on both cloud and consumer gaming GPUs. The figures below
+showcase single-batch decoding performance with a prefill length of 1 and a decode length of 256.
+
+Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX:
+
+
+
+
+
+Scaling of fp16 and 4-bit CodeLlama-34B and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs:
+
+
+
+
## News
+* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm is now official.
+* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 packages are [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package).
* [08/25/2023] CodeLlama support is up.
* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi.
* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use.
@@ -66,7 +80,55 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo
## Getting Started
-Please visit our [this page](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions.
+Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions.
+
+## Model Support
+
+MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can
+use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list.
+
+
+
+
+ Architecture |
+ Prebuilt Model Variants |
+
+
+
+
+ Llama |
+ Llama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored |
+
+
+ GPT-NeoX |
+ RedPajama |
+
+
+ GPT-J |
+ |
+
+
+ RWKV |
+ RWKV-raven |
+
+
+ MiniGPT |
+ |
+
+
+ GPTBigCode |
+ WizardCoder |
+
+
+ ChatGLM |
+ |
+
+
+ StableLM |
+ |
+
+
+
## Universal Deployment APIs
diff --git a/android/MLCChat/app/src/main/assets/app-config.json b/android/MLCChat/app/src/main/assets/app-config.json
index ddbcb793ae..fb7c4546b3 100644
--- a/android/MLCChat/app/src/main/assets/app-config.json
+++ b/android/MLCChat/app/src/main/assets/app-config.json
@@ -1,9 +1,14 @@
{
"model_libs": [
+ "Llama-2-7b-chat-hf-q4f16_0",
"Llama-2-7b-chat-hf-q4f16_1",
"RedPajama-INCITE-Chat-3B-v1-q4f16_1"
],
"model_list": [
+ {
+ "model_url": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_0/",
+ "local_id": "Llama-2-7b-chat-hf-q4f16_0"
+ },
{
"model_url": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/",
"local_id": "Llama-2-7b-chat-hf-q4f16_1"
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
index e1b5928019..f51d56ec10 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
@@ -2,6 +2,9 @@ package ai.mlc.mlcchat
import ai.mlc.mlcllm.ChatModule
import android.app.Application
+import android.content.ClipData
+import android.content.ClipboardManager
+import android.content.Context
import android.os.Environment
import android.widget.Toast
import androidx.compose.runtime.mutableStateOf
@@ -23,6 +26,8 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
val modelList = emptyList().toMutableStateList()
val chatState = ChatState()
val modelSampleList = emptyList().toMutableStateList()
+ private var showAlert = mutableStateOf(false)
+ private var alertMessage = mutableStateOf("")
private var appConfig = AppConfig(
emptyList(),
emptyList().toMutableList(),
@@ -44,13 +49,38 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
loadAppConfig()
}
+    fun supportedModelLibs(): List<String> {
+ return appConfig.modelLibs
+ }
+
+ fun isShowingAlert(): Boolean {
+ return showAlert.value
+ }
+
+ fun errorMessage(): String {
+ return alertMessage.value
+ }
+
+ fun dismissAlert() {
+ require(showAlert.value)
+ showAlert.value = false
+ }
+
+ fun copyError() {
+ require(showAlert.value)
+ val clipboard =
+ application.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager
+ clipboard.setPrimaryClip(ClipData.newPlainText("MLCChat", errorMessage()))
+ }
+
+ private fun issueAlert(error: String) {
+ showAlert.value = true
+ alertMessage.value = error
+ }
+
fun requestAddModel(url: String, localId: String?) {
if (localId != null && localIdSet.contains(localId)) {
- Toast.makeText(
- application,
- "localId: $localId has been occupied",
- Toast.LENGTH_SHORT
- ).show()
+ issueAlert("localId: $localId has been occupied")
} else {
downloadModelConfig(if (url.endsWith("/")) url else "$url/", localId, false)
}
@@ -58,11 +88,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
fun requestDeleteModel(localId: String) {
deleteModel(localId)
- Toast.makeText(
- application,
- "Model: $localId has been deleted",
- Toast.LENGTH_SHORT
- ).show()
+ issueAlert("Model: $localId has been deleted")
}
@@ -133,11 +159,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
private fun isModelConfigAllowed(modelConfig: ModelConfig): Boolean {
if (appConfig.modelLibs.contains(modelConfig.modelLib)) return true;
viewModelScope.launch {
- Toast.makeText(
- application,
- "Model lib ${modelConfig.modelLib} is not supported.",
- Toast.LENGTH_SHORT
- ).show()
+ issueAlert("Model lib ${modelConfig.modelLib} is not supported.")
}
return false
}
@@ -169,11 +191,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
}
if (localIdSet.contains(modelConfig.localId)) {
tempFile.delete()
- Toast.makeText(
- application,
- "${modelConfig.localId} has been used, please consider another local ID",
- Toast.LENGTH_SHORT
- ).show()
+ issueAlert("${modelConfig.localId} has been used, please consider another local ID")
return@launch
}
if (!isModelConfigAllowed(modelConfig)) {
@@ -188,21 +206,13 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
addModelConfig(modelConfig, modelUrl, isBuiltin)
} catch (e: Exception) {
viewModelScope.launch {
- Toast.makeText(
- application,
- "Add model failed: ${e.localizedMessage}",
- Toast.LENGTH_SHORT
- ).show()
+ issueAlert("Add model failed: ${e.localizedMessage}")
}
}
}
} catch (e: Exception) {
viewModelScope.launch {
- Toast.makeText(
- application,
- "Download model config failed: ${e.localizedMessage}",
- Toast.LENGTH_SHORT
- ).show()
+ issueAlert("Download model config failed: ${e.localizedMessage}")
}
}
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt
index ee2833fca0..87fba77a05 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt
@@ -20,6 +20,7 @@ import androidx.compose.material.icons.outlined.Delete
import androidx.compose.material.icons.outlined.Download
import androidx.compose.material.icons.outlined.Pause
import androidx.compose.material.icons.outlined.Schedule
+import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Divider
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon
@@ -94,9 +95,14 @@ fun StartView(
}
}
if (isAddingModel) {
- Text(
- text = "Add Model Variant", modifier = Modifier.padding(top = 10.dp)
- )
+ Text(text = "Supported Base Model Libs", modifier = Modifier.padding(top = 10.dp))
+ for (lib in appViewModel.supportedModelLibs()) {
+ Text(
+ text = lib,
+ style = MaterialTheme.typography.bodyMedium
+ )
+ }
+ Text(text = "Add Model Variant", modifier = Modifier.padding(top = 10.dp))
LazyColumn() {
items(
items = appViewModel.modelSampleList
@@ -148,10 +154,36 @@ fun StartView(
}
}
}
-
+ if (appViewModel.isShowingAlert()) {
+ AlertDialog(
+ onDismissRequest = { appViewModel.dismissAlert() },
+ onConfirmation = { appViewModel.copyError() },
+ error = appViewModel.errorMessage()
+ )
+ }
}
}
+@ExperimentalMaterial3Api
+@Composable
+fun AlertDialog(
+ onDismissRequest: () -> Unit,
+ onConfirmation: () -> Unit,
+ error: String,
+) {
+ AlertDialog(
+ title = { Text(text = "Error") },
+ text = { Text(text = error) },
+ onDismissRequest = { onDismissRequest() },
+ confirmButton = {
+ TextButton(onClick = { onConfirmation() }) { Text("Copy") }
+ },
+ dismissButton = {
+ TextButton(onClick = { onDismissRequest() }) { Text("Dismiss") }
+ }
+ )
+}
+
@Composable
fun ModelView(
navController: NavController,
diff --git a/android/prepare_libs.sh b/android/prepare_libs.sh
index 72457954c0..938ffd5cd8 100755
--- a/android/prepare_libs.sh
+++ b/android/prepare_libs.sh
@@ -9,7 +9,10 @@ python prepare_model_lib.py
cd build
touch config.cmake
-echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake
+if [ "${TVM_HOME-0}" != "0" ]; then
+ echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake
+fi
+
cmake .. \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \
diff --git a/android/src/cpp/tvm_runtime.h b/android/src/cpp/tvm_runtime.h
index 5a1267119d..2caaaaeb1a 100644
--- a/android/src/cpp/tvm_runtime.h
+++ b/android/src/cpp/tvm_runtime.h
@@ -1,13 +1,12 @@
#define DMLC_USE_LOGGING_LIBRARY
#define TVM_USE_LIBBACKTRACE 0
+#include
#include
#include
#include
#include
-#include
-
static_assert(TVM_LOG_CUSTOMIZE == 1, "TVM_LOG_CUSTOMIZE must be 1");
namespace tvm {
diff --git a/ci/bash.sh b/ci/bash.sh
new file mode 100755
index 0000000000..d54eae48ef
--- /dev/null
+++ b/ci/bash.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+
+#
+# Start a bash, mount /workspace to be current directory.
+#
+# Usage: docker/bash.sh
+# Starts an interactive session
+#
+# Usage2: docker/bash.sh [COMMAND]
+# Execute command in the docker image, non-interactive
+#
+if [ "$#" -lt 1 ]; then
+ echo "Usage: docker/bash.sh [--no-gpu] [COMMAND]"
+ exit -1
+fi
+
+if [ "$1" == "--no-gpu" ]; then
+ ENABLE_NV_DOCKER=0
+ shift
+else
+ ENABLE_NV_DOCKER=1
+fi
+
+DOCKER_IMAGE_NAME=("$1")
+
+
+if [ "$#" -eq 1 ]; then
+ COMMAND="bash"
+ if [[ $(uname) == "Darwin" ]]; then
+ # Docker's host networking driver isn't supported on macOS.
+ # Use default bridge network and expose port for jupyter notebook.
+ DOCKER_EXTRA_PARAMS=("-it -p 8888:8888")
+ else
+ DOCKER_EXTRA_PARAMS=("-it --net=host")
+ fi
+else
+ shift 1
+ COMMAND=("$@")
+fi
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+WORKSPACE="$(pwd)"
+
+# Use nvidia-docker if the container is GPU.
+if [[ ! -z $CUDA_VISIBLE_DEVICES ]]; then
+ CUDA_ENV="-e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
+else
+ CUDA_ENV=""
+fi
+
+# If this is an wheel test command then pass the env var to docker.
+if [[ ! -z $WHEEL_TEST ]]; then
+ WHEEL_TEST="-e WHEEL_TEST=${WHEEL_TEST}"
+fi
+
+if [[ "${DOCKER_IMAGE_NAME}" == *"cu"* ]]; then
+ if [ "$ENABLE_NV_DOCKER" -eq 1 ]; then
+ if ! type "nvidia-docker" 1> /dev/null 2> /dev/null
+ then
+ DOCKER_BINARY="docker"
+ CUDA_ENV=" --gpus all "${CUDA_ENV}
+ else
+ DOCKER_BINARY="nvidia-docker"
+ fi
+ else
+ DOCKER_BINARY="docker"
+ fi
+else
+ DOCKER_BINARY="docker"
+fi
+
+# Print arguments.
+echo "WORKSPACE: ${WORKSPACE}"
+echo "DOCKER CONTAINER NAME: ${DOCKER_IMAGE_NAME}"
+echo ""
+
+echo "Running '${COMMAND[@]}' inside ${DOCKER_IMAGE_NAME}..."
+
+# By default we cleanup - remove the container once it finish running (--rm)
+# and share the PID namespace (--pid=host) so the process inside does not have
+# pid 1 and SIGKILL is propagated to the process inside (jenkins can kill it).
+
+${DOCKER_BINARY} run --rm --pid=host\
+ -v ${WORKSPACE}:/workspace \
+ -v ${SCRIPT_DIR}:/docker \
+ -w /workspace \
+ ${CUDA_ENV} \
+ ${WHEEL_TEST} \
+ ${DOCKER_EXTRA_PARAMS[@]} \
+ ${DOCKER_IMAGE_NAME} \
+ ${COMMAND[@]}
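A few hedged usage examples for the wrapper above (the image tag is the one pinned in the lint workflow; any other image name works the same way):

```bash
# Interactive shell inside the image, with the current directory mounted at /workspace.
./ci/bash.sh mlcaidev/ci-cpu:caab922

# Non-interactive: run a single command and exit.
./ci/bash.sh mlcaidev/ci-cpu:caab922 "conda env export --name ci-lint"

# Skip the nvidia-docker / --gpus detection that is triggered for image names
# matching *cu* (which, incidentally, includes this ci-cpu image).
./ci/bash.sh --no-gpu mlcaidev/ci-cpu:caab922 bash ./ci/task/pylint.sh
```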
diff --git a/ci/task/black.sh b/ci/task/black.sh
new file mode 100755
index 0000000000..dcc4b42555
--- /dev/null
+++ b/ci/task/black.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+set -eo pipefail
+
+source ~/.bashrc
+micromamba activate ci-lint
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
+
+set -x
+
+black --check --workers $NUM_THREADS \
+ ./python/ \
+ ./tests/python
diff --git a/ci/task/clang-format.sh b/ci/task/clang-format.sh
new file mode 100755
index 0000000000..54780ec4f9
--- /dev/null
+++ b/ci/task/clang-format.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+set -eo pipefail
+
+source ~/.bashrc
+micromamba activate ci-lint
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
+
+set -x
+git config --global --add safe.directory '*'
+
+INPLACE_FORMAT=${INPLACE_FORMAT:=false}
+LINT_ALL_FILES=true
+REVISION=$(git rev-list --max-parents=0 HEAD)
+
+while (($#)); do
+ case "$1" in
+ -i)
+ INPLACE_FORMAT=true
+ shift 1
+ ;;
+ --rev)
+ LINT_ALL_FILES=false
+ REVISION=$2
+ shift 2
+ ;;
+ *)
+ echo "Usage: clang-format.sh [-i] [--rev ]"
+ echo ""
+ echo "Run clang-format on files that changed since or on all files in the repo"
+ echo "Examples:"
+ echo "- Compare last one commit: clang-format.sh --rev HEAD~1"
+ echo "- Compare against upstream/main: clang-format.sh --rev upstream/main"
+ echo "The -i will format files in-place instead of checking them."
+ exit 1
+ ;;
+ esac
+done
+
+cleanup() {
+ if [ -f /tmp/$$.clang-format.txt ]; then
+ echo ""
+ echo "---------clang-format log----------"
+ cat /tmp/$$.clang-format.txt
+ fi
+ rm -rf /tmp/$$.clang-format.txt
+}
+trap cleanup 0
+
+if [[ "$INPLACE_FORMAT" == "true" ]]; then
+ echo "Running inplace git-clang-format against $REVISION"
+ git-clang-format --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION"
+ exit 0
+fi
+
+if [[ "$LINT_ALL_FILES" == "true" ]]; then
+ echo "Running git-clang-format against all C++ files"
+ git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" 1>/tmp/$$.clang-format.txt
+else
+ echo "Running git-clang-format against $REVISION"
+ git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" 1>/tmp/$$.clang-format.txt
+fi
+
+if grep --quiet -E "diff" </tmp/$$.clang-format.txt; then
+    echo "clang-format lint error found. Consider running clang-format on these files to fix them."
+    exit 1
+fi
diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc
--- a/cpp/cli_main.cc
+++ b/cpp/cli_main.cc
-ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id) {
+ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id,
+ const std::string& user_lib_path) {
// Step 1. Find config path
std::filesystem::path config_path;
if (auto path = TryInferMLCChatConfig(local_id)) {
@@ -368,26 +370,36 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l
}
std::cout << "Use model weights: " << params_json << std::endl;
// Step 3. Find model lib path
- std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib");
- std::string lib_name = lib_local_id + "-" + device_name;
std::filesystem::path lib_path;
- if (auto path = FindFile({lib_local_id,
- "dist/prebuilt/lib", // Using prebuilt workflow
- "dist/" + local_id, "dist/prebuilt/" + lib_local_id},
- {
- lib_name + GetArchSuffix(),
- lib_name,
- },
- GetLibSuffixes())) {
- lib_path = path.value();
+ if (!user_lib_path.empty()) {
+ lib_path = user_lib_path;
+ if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) {
+ LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n";
+ exit(1);
+ }
} else {
- LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n"
- << "We searched over the following possible paths: \n"
- << "- " + lib_local_id << "\n"
- << "- dist/prebuilt/lib \n"
- << "- dist/" + local_id << "\n"
- << "- dist/prebuilt/" + lib_local_id;
- exit(1);
+ std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib");
+ std::string lib_name = lib_local_id + "-" + device_name;
+ if (auto path = FindFile({lib_local_id,
+ "dist/prebuilt/lib", // Using prebuilt workflow
+ "dist/" + local_id, "dist/prebuilt/" + lib_local_id},
+ {
+ lib_name + GetArchSuffix(),
+ lib_name,
+ },
+ GetLibSuffixes())) {
+ lib_path = path.value();
+ } else {
+ LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n"
+ << "We searched over the following possible paths: \n"
+ << "- " + lib_local_id << "\n"
+ << "- dist/prebuilt/lib \n"
+ << "- dist/" + local_id << "\n"
+ << "- dist/prebuilt/" + lib_local_id << "\n"
+ << "If you would like to directly specify the full model library path, you may "
+ << "consider passing in the `--model-lib-path` argument.\n";
+ exit(1);
+ }
}
std::cout << "Use model library: " << lib_path << std::endl;
return ModelPaths{config_path, params_json, lib_path};
@@ -427,8 +439,8 @@ void Converse(ChatModule* chat, const std::string& input, int stream_interval,
* \param stream_interval The interval that should be used for streaming the response.
*/
void Chat(ChatModule* chat, const std::string& device_name, std::string local_id,
- int stream_interval = 2) {
- ModelPaths model = ModelPaths::Find(device_name, local_id);
+ std::string lib_path, int stream_interval = 2) {
+ ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path);
PrintSpecialCommands();
chat->Reload(model);
chat->ProcessSystemPrompts();
@@ -456,7 +468,7 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
if (new_local_id.empty()) {
new_local_id = local_id;
}
- model = ModelPaths::Find(device_name, new_local_id);
+ model = ModelPaths::Find(device_name, new_local_id, lib_path);
chat->Reload(model);
local_id = new_local_id;
} else if (input.substr(0, 5) == "/help") {
@@ -468,9 +480,19 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
}
int main(int argc, char* argv[]) {
- argparse::ArgumentParser args("mlc_chat");
-
- args.add_argument("--model");
+ argparse::ArgumentParser args("mlc_chat_cli");
+
+ args.add_description(
+ "MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n"
+ "Note: the --model argument is required. It can either be the model name with its "
+ "quantization scheme or a full path to the model folder. In the former case, the "
+ "provided name will be used to search for the model folder over possible paths. "
+ "--model-lib-path argument is optional. If unspecified, the --model argument will be used "
+ "to search for the library file over possible paths.");
+
+ args.add_argument("--model").help("[required] the model to use");
+ args.add_argument("--model-lib-path")
+ .help("[optional] the full path to the model library file to use");
args.add_argument("--device").default_value("auto");
args.add_argument("--evaluate").default_value(false).implicit_value(true);
args.add_argument("--eval-prompt-len").default_value(128).scan<'i', int>();
@@ -485,6 +507,10 @@ int main(int argc, char* argv[]) {
}
std::string local_id = args.get("--model");
+ std::string lib_path;
+ if (args.present("--model-lib-path")) {
+ lib_path = args.get("--model-lib-path");
+ }
auto [device_name, device_id] = DetectDevice(args.get("--device"));
try {
@@ -494,14 +520,14 @@ int main(int argc, char* argv[]) {
// that are not supposed to be used in chat app setting
int prompt_len = args.get("--eval-prompt-len");
int gen_len = args.get("--eval-gen-len");
- ModelPaths model = ModelPaths::Find(device_name, local_id);
+ ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path);
tvm::runtime::Module chat_mod = mlc::llm::CreateChatModule(GetDevice(device_name, device_id));
std::string model_path = model.config.parent_path().string();
tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model.lib.string());
chat_mod.GetFunction("reload")(lib, tvm::String(model_path));
chat_mod.GetFunction("evaluate")(prompt_len, gen_len);
} else {
- Chat(&chat, device_name, local_id);
+ Chat(&chat, device_name, local_id, lib_path);
}
} catch (const std::runtime_error& err) {
std::cerr << err.what() << std::endl;
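With the change above, `mlc_chat_cli` gains an optional `--model-lib-path` that bypasses the library search. A hedged example invocation (the model name comes from the prebuilt list used elsewhere in this patch; the library file name is a placeholder, not something this patch ships):

```bash
# Search dist/ and dist/prebuilt/ for the library, as before.
mlc_chat_cli --model Llama-2-7b-chat-hf-q4f16_1 --device auto

# Point directly at a library file instead of searching (placeholder path).
mlc_chat_cli --model Llama-2-7b-chat-hf-q4f16_1 \
    --model-lib-path dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-vulkan.so
```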
diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc
index 69f84b2421..ae91bf2070 100644
--- a/cpp/conv_templates.cc
+++ b/cpp/conv_templates.cc
@@ -473,6 +473,25 @@ Conversation VanillaLM() {
return conv;
}
+Conversation StableLM3B() {
+ Conversation conv;
+ conv.name = "stablelm-3b";
+ conv.system = "";
+ conv.roles = {"Prompt", "LM"};
+ conv.messages = {};
+ conv.separator_style = SeparatorStyle::kLM;
+ conv.offset = 0;
+ conv.seps = {""};
+ conv.role_msg_sep = "";
+ conv.role_empty_sep = "";
+ // TODO(mlc-team): add eos to mlc-chat-config
+ // and remove eos from stop token setting.
+ // so the same template works for more tokenizers
+ conv.stop_tokens = {0};
+ conv.add_bos = true;
+ return conv;
+}
+
Conversation GPTBigCode() {
Conversation conv;
conv.name = "gpt_bigcode";
@@ -580,6 +599,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"minigpt", MiniGPT},
{"moss", MOSS},
{"LM", VanillaLM},
+ {"stablelm-3b", StableLM3B},
{"gpt_bigcode", GPTBigCode},
{"wizardlm_7b", WizardLM7B},
{"wizard_coder_or_math", WizardCoderOrMATH},
diff --git a/cpp/conversation.h b/cpp/conversation.h
index 82332aede6..6211c24c25 100644
--- a/cpp/conversation.h
+++ b/cpp/conversation.h
@@ -283,7 +283,8 @@ class Conversation {
/* place_in_prompt= */ place_in_prompt);
} else {
ICHECK(this->separator_style == SeparatorStyle::kLM ||
- this->separator_style == SeparatorStyle::kCodeCompletion) << "Unsupported separator_style";
+ this->separator_style == SeparatorStyle::kCodeCompletion)
+ << "Unsupported separator_style";
// special handle LM, LM mode have no memory
// and only returns last one
if (this->messages.size() >= 2) {
diff --git a/cpp/image_embed.h b/cpp/image_embed.h
index 87b862242c..e0e21da686 100644
--- a/cpp/image_embed.h
+++ b/cpp/image_embed.h
@@ -6,17 +6,7 @@
#include
#include
-#ifndef MLC_LLM_DLL
-#ifdef _WIN32
-#ifdef MLC_LLM_EXPORTS
-#define MLC_LLM_DLL __declspec(dllexport)
-#else
-#define MLC_LLM_DLL __declspec(dllimport)
-#endif
-#else
-#define MLC_LLM_DLL __attribute__((visibility("default")))
-#endif
-#endif
+#include "base.h"
namespace mlc {
namespace llm {
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index f3fdcc3a36..35a8d1f41e 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -32,6 +32,9 @@
#include
#include "conversation.h"
+#include "random.h"
+#include "support.h"
+#include "tokenizers.h"
namespace mlc {
namespace llm {
@@ -39,62 +42,6 @@ namespace llm {
using tvm::Device;
using namespace tvm::runtime;
namespace {
-//----------------------------
-// Tokenizers
-//----------------------------
-using tokenizers::Tokenizer;
-
-std::string LoadBytesFromFile(const std::string& path) {
- std::ifstream fs(path, std::ios::in | std::ios::binary);
- ICHECK(!fs.fail()) << "Cannot open " << path;
- std::string data;
- fs.seekg(0, std::ios::end);
- size_t size = static_cast(fs.tellg());
- fs.seekg(0, std::ios::beg);
- data.resize(size);
- fs.read(data.data(), size);
- return data;
-}
-
-std::unique_ptr TokenizerFromPath(const std::string& _path) {
- std::filesystem::path path(_path);
- std::filesystem::path sentencepiece;
- std::filesystem::path huggingface;
- std::filesystem::path rwkvworld;
- CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
- if (std::filesystem::is_directory(path)) {
- sentencepiece = path / "tokenizer.model";
- huggingface = path / "tokenizer.json";
- rwkvworld = path / "tokenizer_model";
- // Check ByteLevelBPE
- {
- std::filesystem::path merges_path = path / "merges.txt";
- std::filesystem::path vocab_path = path / "vocab.json";
- std::filesystem::path added_tokens_path = path / "added_tokens.json";
- if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) &&
- std::filesystem::exists(added_tokens_path)) {
- std::string vocab = LoadBytesFromFile(vocab_path.string());
- std::string merges = LoadBytesFromFile(merges_path.string());
- std::string added_tokens = LoadBytesFromFile(added_tokens_path.string());
- return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens);
- }
- }
- } else {
- sentencepiece = path.parent_path() / "tokenizer.model";
- huggingface = path.parent_path() / "tokenizer.json";
- rwkvworld = path.parent_path() / "tokenizer_model";
- }
- if (std::filesystem::exists(sentencepiece)) {
- return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
- }
- if (std::filesystem::exists(huggingface)) {
- return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
- }
- if (std::filesystem::exists(rwkvworld)) {
- return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
- }
- LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
-}
//------------------------------
// support functions
@@ -315,23 +262,6 @@ struct FunctionTable {
PackedFunc fkvcache_array_popn_;
};
-class RandomGenerator {
- private:
- std::mt19937 gen;
- std::uniform_real_distribution<> dis;
-
- RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {}
-
- public:
- static RandomGenerator& GetInstance(int seed = std::random_device{}()) {
- static RandomGenerator instance(seed);
- return instance;
- }
-
- double GetRandomNumber() { return dis(gen); }
-
- void SetSeed(int seed) { gen.seed(seed); }
-};
} // namespace
//------------------------------
@@ -507,7 +437,7 @@ class LLMChat {
/*!
* \brief Reload model, tokenizers and configurations from the specified model path.
- * \param executable The module to reload.
+   * \param reload_lib The module to reload; it can either be a path to the library or a TVM Module.
* \param model_path The path to search for models.
* \param app_config_json The JSON string used to partially override the configuration loaded from
* disk, default to empty string.
@@ -599,7 +529,20 @@ class LLMChat {
* \brief Get input tokens based on history
* \param place_in_prompt The place of the input message in the prompt.
*/
-  std::vector<int32_t> GetInputTokens(PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) {
+  std::vector<int32_t> GetInputTokens(PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
+ picojson::object generation_config = picojson::object()) {
+ // prepare generation settings
+ // the generation_config will not override the original config
+  // since it is only used for this generation
+ int64_t gen_mean_gen_len;
+ if (generation_config.count("mean_gen_len")) {
+    CHECK(generation_config["mean_gen_len"].is<int64_t>());
+    gen_mean_gen_len = generation_config["mean_gen_len"].get<int64_t>();
+ } else {
+ gen_mean_gen_len = this->mean_gen_len_;
+ }
+
+ // work on input tokens
+  std::vector<int32_t> tokens;
+  std::vector<std::string> prompts;
@@ -619,7 +562,7 @@ class LLMChat {
std::string all_prompt = GetConcatPrompt(prompts, 0, 0);
std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
tokens.insert(tokens.end(), encoded.begin(), encoded.end());
- if (this->total_seq_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) {
+ if (this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) {
return tokens;
}
// need shift window and re-encode
@@ -656,11 +599,11 @@ class LLMChat {
if (tokens.size() >= this->max_window_size_) {
LOG(WARNING)
<< "The prompt tokens are more than `max_window_size`, the input will be truncated.";
- ICHECK_GT(this->max_window_size_, this->mean_gen_len_);
+ ICHECK_GT(this->max_window_size_, gen_mean_gen_len);
      std::vector<int32_t> truncated_tokens(
- tokens.end() - (this->max_window_size_ - this->mean_gen_len_), tokens.end());
+ tokens.end() - (this->max_window_size_ - gen_mean_gen_len), tokens.end());
return truncated_tokens;
- } else if (tokens.size() + this->mean_gen_len_ >= this->max_window_size_) {
+ } else if (tokens.size() + gen_mean_gen_len >= this->max_window_size_) {
LOG(WARNING)
<< "The prompt tokens are too long and the generated text may be incomplete, due to "
"limited `max_window_size`. ";
@@ -695,8 +638,10 @@ class LLMChat {
return view;
}
-  std::vector<int32_t> PrepareBeforeEmbedding(std::string inp, bool append_conversation = true,
- PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) {
+  std::vector<int32_t> PrepareBeforeEmbedding(
+ std::string inp, bool append_conversation = true,
+ PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
+ picojson::object generation_config = picojson::object()) {
if (conversation_.separator_style == SeparatorStyle::kLM ||
conversation_.separator_style == SeparatorStyle::kCodeCompletion) {
this->ResetChat();
@@ -705,7 +650,7 @@ class LLMChat {
this->ResetRuntimeStats();
}
output_ids_.clear();
- appeared_token_ids_.clear();
+ appeared_token_freq_.clear();
output_message_.clear();
stop_triggered_ = false;
if (append_conversation) {
@@ -713,7 +658,7 @@ class LLMChat {
conversation_.AppendReplyHeader(conversation_.roles[1]);
}
- return this->GetInputTokens(place_in_prompt);
+ return this->GetInputTokens(place_in_prompt, generation_config);
}
/*!
@@ -724,9 +669,14 @@ class LLMChat {
* \return the embedding of the tokenized prompt.
*/
ObjectRef EmbedStep(std::string inp, bool append_conversation = true,
- PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) {
+ PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
+ String generation_config_str = "") {
+ // process generation settings
+ picojson::object generation_config =
+ this->LoadGenerationConfigFromString(generation_config_str);
+
    std::vector<int32_t> prompt_tokens =
- PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt);
+ PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config);
int64_t token_len = static_cast<int64_t>(prompt_tokens.size());
if (token_len == 0) {
return NDArray::Empty({}, DataType::Float(32), device_);
@@ -755,7 +705,8 @@ class LLMChat {
* \param embedding The embedding to prefill with.
* \param decode_next_token Whether to decode next token.
*/
- void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true) {
+ void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true,
+ String generation_config_str = "") {
if (ft_.use_disco) {
LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
throw;
@@ -774,13 +725,16 @@ class LLMChat {
return;
}
- int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);
+ picojson::object generation_config =
+ this->LoadGenerationConfigFromString(generation_config_str);
+
+ int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);
auto tend = std::chrono::high_resolution_clock::now();
this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
- this->ProcessNextToken(next_token);
+ this->ProcessNextToken(next_token, generation_config);
}
/*!
@@ -791,20 +745,25 @@ class LLMChat {
* \param place_in_prompt The place of the input message in the prompt.
*/
void PrefillStep(std::string inp, bool append_conversation = true, bool decode_next_token = true,
- PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) {
+ PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
+ String generation_config_str = "") {
if (ft_.embed_func_.defined() && ft_.prefill_with_embed_func_.defined()) {
// Temporarily placed inside `PrefillStep` for compatibility in transition.
// Will be separated out in the future.
if (ft_.use_disco) {
LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
}
-      NDArray embedding = Downcast<NDArray>(EmbedStep(inp, append_conversation, place_in_prompt));
- PrefillWithEmbedStep(embedding, decode_next_token);
+      NDArray embedding = Downcast<NDArray>(
+ EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str));
+ PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str);
return;
}
+ picojson::object generation_config =
+ this->LoadGenerationConfigFromString(generation_config_str);
+
std::vector<int32_t> prompt_tokens =
- this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt);
+ this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config);
int64_t token_len = static_cast<int64_t>(prompt_tokens.size());
if (token_len == 0) return;
if (ft_.use_disco) {
@@ -824,16 +783,19 @@ class LLMChat {
return;
}
- int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);
+ int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);
auto tend = std::chrono::high_resolution_clock::now();
this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
- this->ProcessNextToken(next_token);
+ this->ProcessNextToken(next_token, generation_config);
}
- void DecodeStep() {
+ void DecodeStep(String generation_config_str = "") {
+ picojson::object generation_config =
+ this->LoadGenerationConfigFromString(generation_config_str);
+
ICHECK(!output_ids_.empty());
int32_t last_token = output_ids_.back();
tvm::runtime::NDArray input_data = GetInputTokenNDArray({last_token});
@@ -843,13 +805,13 @@ class LLMChat {
NDArray logits_on_device = this->ForwardTokens({last_token}, total_seq_len_ + 1);
total_seq_len_ += 1;
- int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);
+ int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);
auto tend = std::chrono::high_resolution_clock::now();
this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->decode_total_tokens += 1;
- this->ProcessNextToken(next_token);
+ this->ProcessNextToken(next_token, generation_config);
}
bool Stopped() { return stop_triggered_; }
@@ -921,7 +883,7 @@ class LLMChat {
{
auto tstart = std::chrono::high_resolution_clock::now();
logits_on_device = this->ForwardTokens(tokens, tokens.size());
- tokens.push_back(this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_));
+ tokens.push_back(this->SampleTokenFromLogits(logits_on_device));
auto tend = std::chrono::high_resolution_clock::now();
this->prefill_total_time = static_cast<double>((tend - tstart).count()) / 1e9;
@@ -933,7 +895,7 @@ class LLMChat {
auto tstart = std::chrono::high_resolution_clock::now();
for (int64_t len = 1; len < generate_len; ++len) {
logits_on_device = this->ForwardTokens({tokens.back()}, tokens.size());
- tokens.push_back(this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_));
+ tokens.push_back(this->SampleTokenFromLogits(logits_on_device));
}
auto tend = std::chrono::high_resolution_clock::now();
@@ -950,6 +912,8 @@ class LLMChat {
picojson::object config;
config["temperature"] = picojson::value(this->temperature_);
config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
+ config["presence_penalty"] = picojson::value(this->presence_penalty_);
+ config["frequency_penalty"] = picojson::value(this->frequency_penalty_);
config["top_p"] = picojson::value(this->top_p_);
config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
config["max_gen_len"] = picojson::value(this->max_gen_len_);
@@ -957,29 +921,109 @@ class LLMChat {
config["conv_config"] = this->conversation_.SerializeToJSON();
return picojson::value(config);
}
+
+ picojson::object LoadGenerationConfigFromString(const std::string& generation_config_str) {
+ picojson::object generation_config = picojson::object();
+ if (!generation_config_str.empty()) {
+ picojson::value generation_config_json;
+ picojson::parse(generation_config_json, generation_config_str);
+      generation_config = generation_config_json.get<picojson::object>();
+ }
+ return generation_config;
+ }
+
+ void ReadGenerationConfig(picojson::object generation_config, double* gen_temperature,
+ NDArray* gen_temperature_arr, double* gen_repetition_penalty,
+ double* gen_presence_penalty, double* gen_frequency_penalty,
+ double* gen_top_p) {
+ if (generation_config.count("temperature")) {
+      CHECK(generation_config["temperature"].is<double>());
+      *gen_temperature = generation_config["temperature"].get<double>();
+
+ *gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_);
+      float temperature_cast = static_cast<float>(*gen_temperature);
+ gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float));
+ } else {
+ *gen_temperature = this->temperature_;
+ *gen_temperature_arr = this->temperature_arr_;
+ }
+ if (generation_config.count("repetition_penalty")) {
+      CHECK(generation_config["repetition_penalty"].is<double>());
+      CHECK(generation_config["repetition_penalty"].get<double>() > 0)
+          << "Repetition penalty must be a positive number!";
+      *gen_repetition_penalty = generation_config["repetition_penalty"].get<double>();
+ } else {
+ *gen_repetition_penalty = this->repetition_penalty_;
+ }
+ if (generation_config.count("presence_penalty")) {
+      CHECK(generation_config["presence_penalty"].is<double>());
+      CHECK(abs(generation_config["presence_penalty"].get<double>()) <= 2)
+          << "Presence penalty must be in the range -2 to 2!";
+      *gen_presence_penalty = generation_config["presence_penalty"].get<double>();
+ } else {
+ *gen_presence_penalty = this->presence_penalty_;
+ }
+ if (generation_config.count("frequency_penalty")) {
+      CHECK(generation_config["frequency_penalty"].is<double>());
+      CHECK(abs(generation_config["frequency_penalty"].get<double>()) <= 2)
+          << "Frequency penalty must be in the range -2 to 2!";
+      *gen_frequency_penalty = generation_config["frequency_penalty"].get<double>();
+ } else {
+ *gen_frequency_penalty = this->frequency_penalty_;
+ }
+ if (generation_config.count("top_p")) {
+      CHECK(generation_config["top_p"].is<double>());
+      *gen_top_p = generation_config["top_p"].get<double>();
+ } else {
+ *gen_top_p = this->top_p_;
+ }
+ }
+
/*!
* \brief Sample output token from logits on device
*/
- int32_t SampleTokenFromLogits(NDArray logits_on_device, float temperature, float top_p) {
- if (repetition_penalty_ == 1.0f) {
- if (temperature_ < 1e-6f) {
- this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
- } else {
- this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, temperature_));
+ int32_t SampleTokenFromLogits(NDArray logits_on_device,
+ picojson::object generation_config = picojson::object()) {
+ // prepare generation settings
+ // the generation_config will not override the original config
+    // since it is only used for this generation
+ double gen_temperature;
+ double gen_repetition_penalty;
+ double gen_presence_penalty;
+ double gen_frequency_penalty;
+ double gen_top_p;
+ this->ReadGenerationConfig(generation_config, &gen_temperature, &this->temperature_arr_,
+ &gen_repetition_penalty, &gen_presence_penalty,
+ &gen_frequency_penalty, &gen_top_p);
+
+ // update logits
+ if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) {
+ this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
+      this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
+ if (gen_temperature >= 1e-6f) {
+ this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
}
- } else {
+ } else if (gen_repetition_penalty != 1.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
- this->ApplyRepetitionPenaltyOnCPU();
- if (temperature_ >= 1e-6f) {
- this->ApplySoftmaxWithTemperatureOnCPU();
+ this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty);
+ if (gen_temperature >= 1e-6f) {
+ this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
+ }
+ } else {
+ if (gen_temperature < 1e-6f) {
+ this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
+ } else {
+ this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
}
}
+
+ // perform sampling
auto tstart = std::chrono::high_resolution_clock::now();
int next_token;
- if (temperature_ < 1e-6f) {
- next_token = this->SampleFromLogitsOnCPU();
+ if (gen_temperature < 1e-6f) {
+ next_token = this->SampleFromLogitsOnCPU(gen_temperature, gen_top_p);
} else {
- next_token = this->SampleFromProbOnCPU();
+ next_token = this->SampleFromProbOnCPU(gen_top_p);
}
auto tend = std::chrono::high_resolution_clock::now();
this->sample_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
@@ -990,7 +1034,38 @@ class LLMChat {
* \brief Add a generated token and check for stop condition.
* \param next_token The next token.
*/
- void ProcessNextToken(int32_t next_token) {
+ void ProcessNextToken(int32_t next_token,
+ picojson::object generation_config = picojson::object()) {
+ // prepare generation settings
+ // the generation_config will not override the original config
+    // since it is only used for this generation
+ int64_t gen_max_gen_len;
+ if (generation_config.count("max_gen_len")) {
+      CHECK(generation_config["max_gen_len"].is<int64_t>());
+      gen_max_gen_len = generation_config["max_gen_len"].get<int64_t>();
+ } else {
+ gen_max_gen_len = this->max_gen_len_;
+ }
+
+    std::vector<std::string> gen_stop_strs;
+ gen_stop_strs.push_back(conversation_.stop_str);
+
+ if (generation_config.count("stop")) {
+    if (!generation_config["stop"].is<picojson::null>()) {
+      CHECK(generation_config["stop"].is<std::string>() ||
+            generation_config["stop"].is<picojson::array>());
+      if (generation_config["stop"].is<std::string>()) {
+        gen_stop_strs.push_back(generation_config["stop"].get<std::string>());
+      } else {
+        picojson::array gen_stop_strs_arr = generation_config["stop"].get<picojson::array>();
+        for (const picojson::value& v : gen_stop_strs_arr) {
+          CHECK(v.is<std::string>());
+          gen_stop_strs.push_back(v.get<std::string>());
+ }
+ }
+ }
+ }
+
ICHECK(!stop_triggered_) << "Cannot call process when it is stopped";
stop_triggered_ =
@@ -999,32 +1074,39 @@ class LLMChat {
if (!stop_triggered_) {
output_ids_.push_back(next_token);
- appeared_token_ids_.insert(next_token);
+ if (appeared_token_freq_.find(next_token) != appeared_token_freq_.end()) {
+ appeared_token_freq_[next_token] += 1;
+ } else {
+ appeared_token_freq_[next_token] = 1;
+ }
}
output_message_ = tokenizer_->Decode(output_ids_);
- if (!conversation_.stop_str.empty()) {
- size_t stop_pos = output_message_.rfind(conversation_.stop_str);
- if (stop_pos != std::string::npos) {
- stop_triggered_ = true;
- if (ft_.support_backtracking_kv_) {
- // back tracking, find the first set of token that is smaller
- // than the length
- size_t backoff = 0;
- for (; backoff < output_ids_.size(); ++backoff) {
- output_ids_.pop_back();
- output_message_ = tokenizer_->Decode(output_ids_);
- if (output_message_.length() <= stop_pos) break;
- }
- // resize kv to remove the context
- ft_.fkvcache_array_popn_(kv_cache_, backoff);
- total_seq_len_ -= backoff;
+ size_t stop_pos = std::string::npos;
+ for (const std::string& stop_str : gen_stop_strs) {
+ if (!stop_str.empty()) {
+ stop_pos = std::min(stop_pos, output_message_.rfind(stop_str));
+ }
+ }
+
+ if (stop_pos != std::string::npos) {
+ stop_triggered_ = true;
+ if (ft_.support_backtracking_kv_) {
+ // back tracking, find the first set of token that is smaller
+ // than the length
+ size_t backoff = 0;
+ for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) {
+ output_ids_.pop_back();
+ output_message_ = tokenizer_->Decode(output_ids_);
}
+ // resize kv to remove the context
+ ft_.fkvcache_array_popn_(kv_cache_, backoff);
+ total_seq_len_ -= backoff;
}
}
-    if (static_cast<int64_t>(output_ids_.size()) >= max_gen_len_) {
+    if (static_cast<int64_t>(output_ids_.size()) >= gen_max_gen_len) {
stop_triggered_ = true;
} else if (total_seq_len_ >= max_window_size_) {
stop_triggered_ = true;
@@ -1077,32 +1159,42 @@ class LLMChat {
return Downcast<NDArray>(ret[0]);
}
- NDArray Softmax(NDArray input, float temperature) {
+ NDArray Softmax(NDArray input, NDArray temperature_arr) {
NDArray ret;
- ret = ft_.softmax_func_(input, temperature_arr_);
+ ret = ft_.softmax_func_(input, temperature_arr);
return ret;
}
- void ApplyRepetitionPenaltyOnCPU() {
+ void ApplyRepetitionPenaltyOnCPU(float repetition_penalty) {
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
- for (const int32_t& token_id : this->appeared_token_ids_) {
- if (logits_raw_data[token_id] <= 0) {
- logits_raw_data[token_id] *= this->repetition_penalty_;
+ for (const auto& token_freq : this->appeared_token_freq_) {
+ if (logits_raw_data[token_freq.first] <= 0) {
+ logits_raw_data[token_freq.first] *= repetition_penalty;
} else { // logits > 0
- logits_raw_data[token_id] /= this->repetition_penalty_;
+ logits_raw_data[token_freq.first] /= repetition_penalty;
}
}
}
- void ApplySoftmaxWithTemperatureOnCPU() {
+ void ApplyPresenceAndFrequencyPenaltyOnCPU(float presence_penalty, float frequency_penalty) {
+ CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
+ CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+    float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
+ for (const auto& token_freq : this->appeared_token_freq_) {
+ logits_raw_data[token_freq.first] -=
+ (token_freq.second * frequency_penalty + presence_penalty);
+ }
+ }
+
+ void ApplySoftmaxWithTemperatureOnCPU(float temperature) {
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
int vocab_size = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1];
float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
float m = std::numeric_limits<float>::min();
- float inv_temp = 1.0f / this->temperature_;
+ float inv_temp = 1.0f / temperature;
double d = 0.0f;
for (int i = 0; i < vocab_size; ++i) {
float x = logits_raw_data[i] * inv_temp;
@@ -1137,18 +1229,18 @@ class LLMChat {
// Utils
static double GetRandomNumber() { return RandomGenerator::GetInstance().GetRandomNumber(); }
- int32_t SampleFromLogitsOnCPU() {
+ int32_t SampleFromLogitsOnCPU(float temperature, float top_p) {
ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined";
ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D";
ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch";
- return fsample_topp_from_logits_(logits_on_cpu_, temperature_, top_p_, GetRandomNumber());
+ return fsample_topp_from_logits_(logits_on_cpu_, temperature, top_p, GetRandomNumber());
}
- int32_t SampleFromProbOnCPU() {
+ int32_t SampleFromProbOnCPU(float top_p) {
ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined";
ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D";
ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch";
- return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber());
+ return fsample_topp_from_prob_(logits_on_cpu_, top_p, GetRandomNumber());
}
//----------------------------
@@ -1185,12 +1277,16 @@ class LLMChat {
NDArray temperature_arr_;
// repetition penalty
double repetition_penalty_{1.0};
+ // presence penalty
+ double presence_penalty_{0.0};
+ // frequency penalty
+ double frequency_penalty_{0.0};
// top_p
double top_p_{0.95};
// output ids till now (refresh after encoding step)
std::vector output_ids_;
- // appeared token ids till now (refresh after encoding step)
-  std::unordered_set<int32_t> appeared_token_ids_;
+ // frequency of appeared token ids till now (refresh after encoding step)
+  std::unordered_map<int32_t, int64_t> appeared_token_freq_;
// output message till now (refresh after encoding step)
std::string output_message_;
// Whether encounter stop str
@@ -1279,7 +1375,7 @@ class LLMChatModule : public ModuleNode {
});
} else if (name == "prefill") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
- ICHECK(1 <= args.size() && args.size() <= 3);
+ ICHECK(1 <= args.size() && args.size() <= 4);
if (args.size() == 1) {
// args: inp (with decode_next_token = true, place_in_prompt = kAll)
GetChat()->PrefillStep(args[0]);
@@ -1290,11 +1386,15 @@ class LLMChatModule : public ModuleNode {
// args: inp, decode_next_token, place_in_prompt
PlaceInPrompt place_in_prompt = static_cast<PlaceInPrompt>(static_cast<int>(args[2]));
GetChat()->PrefillStep(args[0], true, args[1], place_in_prompt);
+ } else if (args.size() == 4) {
+ // args: inp, decode_next_token, place_in_prompt, generation_config_str
+          PlaceInPrompt place_in_prompt = static_cast<PlaceInPrompt>(static_cast<int>(args[2]));
+ GetChat()->PrefillStep(args[0], true, args[1], place_in_prompt, args[3]);
}
});
} else if (name == "embed") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
- ICHECK(1 <= args.size() && args.size() <= 2);
+ ICHECK(1 <= args.size() && args.size() <= 3);
if (args.size() == 1) {
// args: inp (with place_in_prompt = kAll)
*rv = GetChat()->EmbedStep(args[0]);
@@ -1302,22 +1402,36 @@ class LLMChatModule : public ModuleNode {
// args: inp, place_in_prompt
PlaceInPrompt place_in_prompt = static_cast<PlaceInPrompt>(static_cast<int>(args[1]));
*rv = GetChat()->EmbedStep(args[0], true, place_in_prompt);
+ } else if (args.size() == 3) {
+ // args: inp, place_in_prompt, generation_config_str
+          PlaceInPrompt place_in_prompt = static_cast<PlaceInPrompt>(static_cast<int>(args[1]));
+ *rv = GetChat()->EmbedStep(args[0], true, place_in_prompt, args[2]);
}
});
} else if (name == "prefill_with_embed") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
- ICHECK(1 <= args.size() && args.size() <= 2);
+ ICHECK(1 <= args.size() && args.size() <= 3);
if (args.size() == 1) {
// args: embedding (with decode_next_token = true)
GetChat()->PrefillWithEmbedStep(args[0]);
} else if (args.size() == 2) {
// args: embedding, decode_next_token
GetChat()->PrefillWithEmbedStep(args[0], args[1]);
+ } else if (args.size() == 3) {
+ // args: embedding, decode_next_token, generation_config_str
+ GetChat()->PrefillWithEmbedStep(args[0], args[1], args[2]);
}
});
} else if (name == "decode") {
- return PackedFunc(
- [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->DecodeStep(); });
+ return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
+ ICHECK(0 <= args.size() && args.size() <= 1);
+ if (args.size() == 0) {
+ GetChat()->DecodeStep();
+ } else if (args.size() == 1) {
+ // args: generation_config_str
+ GetChat()->DecodeStep(args[0]);
+ }
+ });
} else if (name == "reset_chat") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
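The `generation_config_str` threaded through `prefill`, `embed`, `prefill_with_embed`, and `decode` above is a JSON object whose keys mirror the fields read in `GetInputTokens`, `ReadGenerationConfig`, and `ProcessNextToken`; unspecified keys fall back to the values loaded from the chat config on disk. A minimal sketch of such a per-call override (key names are taken from the code above; the values are purely illustrative):

```bash
# Passed verbatim as the trailing string argument of prefill/decode.
GENERATION_CONFIG='{
  "temperature": 0.7,
  "top_p": 0.95,
  "repetition_penalty": 1.0,
  "presence_penalty": 0.5,
  "frequency_penalty": 0.5,
  "mean_gen_len": 128,
  "max_gen_len": 256,
  "stop": ["</s>"]
}'
```

When either penalty is non-zero, the sampler subtracts `token_freq * frequency_penalty + presence_penalty` from the logit of every token that has already appeared; otherwise it falls back to the multiplicative `repetition_penalty` path.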
diff --git a/cpp/llm_chat.h b/cpp/llm_chat.h
index 1839e8cb4d..39408d1685 100644
--- a/cpp/llm_chat.h
+++ b/cpp/llm_chat.h
@@ -6,17 +6,7 @@
#include
#include
-#ifndef MLC_LLM_DLL
-#ifdef _WIN32
-#ifdef MLC_LLM_EXPORTS
-#define MLC_LLM_DLL __declspec(dllexport)
-#else
-#define MLC_LLM_DLL __declspec(dllimport)
-#endif
-#else
-#define MLC_LLM_DLL __attribute__((visibility("default")))
-#endif
-#endif
+#include "base.h"
namespace mlc {
namespace llm {
diff --git a/cpp/random.h b/cpp/random.h
new file mode 100644
index 0000000000..e6331a9699
--- /dev/null
+++ b/cpp/random.h
@@ -0,0 +1,37 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file random.h
+ * \brief Header of random number generator.
+ */
+
+#ifndef MLC_LLM_RANDOM_H_
+#define MLC_LLM_RANDOM_H_
+
+#include <random>
+
+namespace mlc {
+namespace llm {
+
+// Random number generator
+class RandomGenerator {
+ private:
+ std::mt19937 gen;
+ std::uniform_real_distribution<> dis;
+
+ RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {}
+
+ public:
+ static RandomGenerator& GetInstance(int seed = std::random_device{}()) {
+ static RandomGenerator instance(seed);
+ return instance;
+ }
+
+ double GetRandomNumber() { return dis(gen); }
+
+ void SetSeed(int seed) { gen.seed(seed); }
+};
+
+} // namespace llm
+} // namespace mlc
+
+#endif // MLC_LLM_RANDOM_H_
diff --git a/cpp/support.h b/cpp/support.h
new file mode 100644
index 0000000000..20eadbbd0a
--- /dev/null
+++ b/cpp/support.h
@@ -0,0 +1,31 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file support.h
+ * \brief Header of utilities.
+ */
+
+#ifndef MLC_LLM_COMMON_H_
+#define MLC_LLM_COMMON_H_
+
+#include
+#include
+
+namespace mlc {
+namespace llm {
+
+inline std::string LoadBytesFromFile(const std::string& path) {
+ std::ifstream fs(path, std::ios::in | std::ios::binary);
+ ICHECK(!fs.fail()) << "Cannot open " << path;
+ std::string data;
+ fs.seekg(0, std::ios::end);
+  size_t size = static_cast<size_t>(fs.tellg());
+ fs.seekg(0, std::ios::beg);
+ data.resize(size);
+ fs.read(data.data(), size);
+ return data;
+}
+
+} // namespace llm
+} // namespace mlc
+
+#endif // MLC_LLM_COMMON_H_
diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc
new file mode 100644
index 0000000000..8d38dd9572
--- /dev/null
+++ b/cpp/tokenizers.cc
@@ -0,0 +1,61 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file tokenizer.cc
+ */
+
+#include "tokenizers.h"
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include "support.h"
+
+namespace mlc {
+namespace llm {
+
+std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
+ std::filesystem::path path(_path);
+ std::filesystem::path sentencepiece;
+ std::filesystem::path huggingface;
+ std::filesystem::path rwkvworld;
+ CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
+ if (std::filesystem::is_directory(path)) {
+ sentencepiece = path / "tokenizer.model";
+ huggingface = path / "tokenizer.json";
+ rwkvworld = path / "tokenizer_model";
+ // Check ByteLevelBPE
+ {
+ std::filesystem::path merges_path = path / "merges.txt";
+ std::filesystem::path vocab_path = path / "vocab.json";
+ std::filesystem::path added_tokens_path = path / "added_tokens.json";
+ if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) &&
+ std::filesystem::exists(added_tokens_path)) {
+ std::string vocab = LoadBytesFromFile(vocab_path.string());
+ std::string merges = LoadBytesFromFile(merges_path.string());
+ std::string added_tokens = LoadBytesFromFile(added_tokens_path.string());
+ return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens);
+ }
+ }
+ } else {
+ sentencepiece = path.parent_path() / "tokenizer.model";
+ huggingface = path.parent_path() / "tokenizer.json";
+ rwkvworld = path.parent_path() / "tokenizer_model";
+ }
+ if (std::filesystem::exists(sentencepiece)) {
+ return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
+ }
+ if (std::filesystem::exists(huggingface)) {
+ return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
+ }
+ if (std::filesystem::exists(rwkvworld)) {
+ return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
+ }
+ LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
+}
+
+} // namespace llm
+} // namespace mlc
diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h
new file mode 100644
index 0000000000..f44f828e97
--- /dev/null
+++ b/cpp/tokenizers.h
@@ -0,0 +1,24 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file tokenizers.h
+ * \brief Header of tokenizer related functions.
+ */
+
+#ifndef MLC_LLM_TOKENIZER_H_
+#define MLC_LLM_TOKENIZER_H_
+
+#include <tokenizers_cpp.h>
+
+#include "base.h"
+
+namespace mlc {
+namespace llm {
+
+using tokenizers::Tokenizer;
+
+MLC_LLM_DLL std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path);
+
+} // namespace llm
+} // namespace mlc
+
+#endif // MLC_LLM_TOKENIZER_H_
diff --git a/docs/community/faq.rst b/docs/community/faq.rst
index 45d73b4904..3913dd9639 100644
--- a/docs/community/faq.rst
+++ b/docs/community/faq.rst
@@ -5,7 +5,7 @@ Frequently Asked Questions
This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free to suggest new entries!
-... How can I customize the temperature, repetition penalty of models?
+... How can I customize the temperature and repetition penalty of models?
Please check our :doc:`/get_started/mlc_chat_config` tutorial.
... What's the quantization algorithm MLC-LLM using?
@@ -13,5 +13,4 @@ This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free
... Why do I encounter an error ``free(): invalid pointer, Aborted (core dumped)`` at the end of model compilation?
This happens if you compiled TVM-Unity from source and didn't hide LLVM symbols in cmake configurations.
- Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols,
- or use our pre-builc MLC-AI pip wheels from `MLC Packages `__.
+ Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols, or use our pre-built MLC-LLM :doc:`pip wheels <../install/mlc_llm>`.
diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst
index 7f671f614b..38a03e463e 100644
--- a/docs/community/guideline.rst
+++ b/docs/community/guideline.rst
@@ -42,11 +42,11 @@ Ready to contribute to MLC-LLM? Awesome! We are excited to see you are ready to
The standard way to make changes to MLC-LLM code base is through creating a `pull-request `__,
and we will review your code and merge it to the code base when it is ready.
-The first step to become a developer is to `fork `__ the repository to your own
+The first step to becoming a developer is to `fork `__ the repository to your own
github account, you will notice a repository under ``https://github.com/username/mlc-llm`` where ``username`` is your github user name.
You can clone your fork to your local machine and commit changes, or edit the contents of your fork (in the case you are just fixing typos)
-on github directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository.
+on GitHub directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository.
.. _contribute-new-models:
@@ -55,7 +55,7 @@ Contribute New Models to MLC-LLM
* If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following :ref:`distribute-compiled-models` tutorial. Once you have done that, you can create a pull request to add an entry in the :doc:`/prebuilt_models` page. Additionally, you have the option to `create a speed report issue `__ to track the speed and memory consumption of your model. You don't need to test it on all devices; let the community collaborate on building it together!
-* If you add a new model variant to MLC-LLM by following our :doc:`/tutorials/bring-your-own-models` tutorial.
+* If you add a new model variant to MLC-LLM by following our :doc:`/tutorials/customize/define_new_models` tutorial.
Please create a pull request to add your model architecture (currently model architectures are placed under
`relax_models `__ folder).
@@ -86,14 +86,14 @@ Fo your convenience, you can use `clang-format `__ to acknowledge contributors,
-please let us know if you contribute to the project and your name is not included in the list.
+please let us know if you have contributed to the project and your name is not included in the list.
diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst
index e559c8fc27..98c7f2d156 100644
--- a/docs/compilation/compile_models.rst
+++ b/docs/compilation/compile_models.rst
@@ -4,14 +4,14 @@ Compile Models via MLC
======================
This page describes how to compile a model with MLC LLM. Model compilation takes model inputs, produces quantized model weights,
-and optimized model lib for a given platform. It enables users to bring their own new model weights, try different quantization modes,
+and an optimized model library for a given platform. It enables users to bring their own new model weights, try different quantization modes,
and customize the overall model optimization flow.
.. note::
Before you proceed, please make sure that you have :ref:`install-tvm-unity` correctly installed on your machine.
TVM-Unity is the necessary foundation for us to compile models with MLC LLM.
If you want to build webgpu, please also complete :ref:`install-web-build`.
- Please also follow the instruction in :ref:`deploy-cli` to obtain the CLI app that can be used to chat with the compiled model.
+ Please also follow the instructions in :ref:`deploy-cli` to obtain the CLI app that can be used to chat with the compiled model.
Finally, we strongly recommend you read :ref:`project-overview` first to get familiarized with the high-level terminologies.
@@ -25,12 +25,12 @@ Install MLC-LLM Package
Work with Source Code
^^^^^^^^^^^^^^^^^^^^^
-The easiest way is to use MLC-LLM is to clone the repository, and compile models under the root directory of the repository.
+The easiest way to use MLC-LLM is to clone the repository, and compile models under the root directory of the repository.
.. code:: bash
# clone the repository
- git clone git@github.com:mlc-ai/mlc-llm.git --recursive
+ git clone https://github.com/mlc-ai/mlc-llm.git --recursive
# enter to root directory of the repo
cd mlc-llm
# install mlc-llm
@@ -106,7 +106,7 @@ your personal computer.
xcrun: error: unable to find utility "metallib", not a developer tool or in PATH
, please check and make sure you have Command Line Tools for Xcode installed correctly.
- You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed the model compiling.
+ You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed with the model compiling.
.. group-tab:: Android
@@ -172,7 +172,7 @@ We can check the output with the commands below:
tokenizer_config.json
We now chat with the model using the command line interface (CLI) app.
- Follow the build from source instruction
+ Follow the build-from-source instructions
.. code:: shell
@@ -271,7 +271,7 @@ We can check the output with the commands below:
tokenizer_config.json
The model lib ``dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm``
- can be uploaded to internet. You can pass a ``model_lib_map`` field to WebLLM app config to use this library.
+ can be uploaded to the internet. You can pass a ``model_lib_map`` field to WebLLM app config to use this library.
Each compilation target produces a specific model library for the given platform. The model weight is shared across
@@ -311,7 +311,7 @@ In other cases you need to specify the model via ``--model``.
- ``dist/models/MODEL_NAME_OR_PATH`` (e.g., ``--model Llama-2-7b-chat-hf``),
- ``MODEL_NAME_OR_PATH`` (e.g., ``--model /my-model/Llama-2-7b-chat-hf``).
- When running the compile command using ``--model``, please make sure you have placed the model to compile under ``dist/models/`` or other location on the disk.
+ When running the compile command using ``--model``, please make sure you have placed the model to compile under ``dist/models/`` or another location on the disk.
--hf-path HUGGINGFACE_NAME The name of the model's Hugging Face repository.
We will download the model to ``dist/models/HUGGINGFACE_NAME`` and load the model from this directory.
@@ -336,11 +336,11 @@ The following arguments are optional:
we will use the maximum sequence length from the ``config.json`` in the model directory.
--reuse-lib LIB_NAME Specifies the previously generated library to reuse.
This is useful when building the same model architecture with different weights.
- You can refer to the :ref:`model distribution ` page for detail of this argument.
+ You can refer to the :ref:`model distribution ` page for details of this argument.
--use-cache When ``--use-cache=0`` is specified,
the model compilation will not use cached file from previous builds,
and will compile the model from the very start.
- Using cache can help reduce the time needed to compile.
+ Using a cache can help reduce the time needed to compile.
--debug-dump Specifies whether to dump debugging files during compilation.
--use-safetensors Specifies whether to use ``.safetensors`` instead of the default ``.bin`` when loading in model weights.
@@ -354,7 +354,7 @@ This section lists compile commands for more models that you can try out.
.. tab:: Model: Llama-2-7B
Please `request for access `_ to the Llama-2 weights from Meta first.
- After granted the access, please create directory ``dist/models`` and download the model to the directory.
+ After being granted access, please create the directory ``dist/models`` and download the model into it.
For example, you can run the following code:
.. code:: shell
diff --git a/docs/compilation/distribute_compiled_models.rst b/docs/compilation/distribute_compiled_models.rst
index b6f31f9386..69dc0e847d 100644
--- a/docs/compilation/distribute_compiled_models.rst
+++ b/docs/compilation/distribute_compiled_models.rst
@@ -67,7 +67,7 @@ You can **optionally** customize the chat config file
``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1/params/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions).
You can also simply use the default configuration and skip this step.
-For demonstration purpose, we update ``mean_gen_len`` to 32 and ``max_gen_len`` to 64.
+For demonstration purposes, we update ``mean_gen_len`` to 32 and ``max_gen_len`` to 64.
We also update ``conv_template`` to ``"LM"`` because the model is instruction-tuned.
@@ -160,8 +160,8 @@ Download the Distributed Models and Run in iOS App
--------------------------------------------------
For iOS app, model libraries are statically packed into the app at the time of app building.
-Therefore, the iOS app supports running any models whose model libraries are integrated into the app.
-You can check the :ref:`list of supported model libraries `.
+Therefore, the iOS app supports running any model whose model libraries are integrated into the app.
+You can check the :ref:`list of supported model libraries `.
To download and run the compiled RedPajama-3B instruct model on iPhone, we need to reuse the integrated ``RedPajama-INCITE-Chat-3B-v1-q4f16_1`` model library.
Please revisit :ref:`distribute-model-step3-specify-model-lib` and make sure the ``model_lib`` field of `mlc-chat-config.json` is set to ``RedPajama-INCITE-Chat-3B-v1-q4f16_1``.
@@ -198,7 +198,7 @@ Now we can download the model weights in iOS app and run the model by following
.. tab:: Step 4
- When the download is finished, click into the model and enjoy.
+ When the download is finished, click on the model and enjoy.
.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-distribute-4.jpeg
:align: center
diff --git a/docs/compilation/get-vicuna-weight.rst b/docs/compilation/get-vicuna-weight.rst
index 0cc42380e9..2ea4ba5d97 100644
--- a/docs/compilation/get-vicuna-weight.rst
+++ b/docs/compilation/get-vicuna-weight.rst
@@ -5,7 +5,7 @@ Getting Vicuna Weights
:local:
:depth: 2
-`Vicuna `_ is a open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data.
+`Vicuna `_ is an open-source chatbot trained by fine-tuning `LLaMA `_ on `ShareGPT `_ data.
Please note that the official Vicuna weights are delta weights applied to the LLaMA weights in order to comply with the LLaMA license. Users are responsible for applying these delta weights themselves.
@@ -14,7 +14,7 @@ In this tutorial, we will show how to apply the delta weights to LLaMA weights t
Install FastChat
----------------
-FastChat offers convenient utility functions for applying delta to LLaMA weights. You can easily install it using pip.
+FastChat offers convenient utility functions for applying the delta to LLaMA weights. You can easily install it using pip.
.. code-block:: bash
@@ -38,14 +38,14 @@ Then download the weights (both the LLaMA weight and Vicuna delta weight):
git clone https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
-There is a name mis-alignment issue in the LLaMA weights and Vicuna delta weights.
+There is a name misalignment issue in the LLaMA weights and Vicuna delta weights.
Please follow these steps to modify the content of the "config.json" file:
.. code-block:: bash
sed -i 's/LLaMAForCausalLM/LlamaForCausalLM/g' llama-7b-hf/config.json
-Then use ``fschat`` to apply delta to LLaMA weights
+Then use ``fschat`` to apply the delta to LLaMA weights
.. code-block:: bash
diff --git a/docs/compilation/python.rst b/docs/compilation/python.rst
index 99486a751b..98e4f934e7 100644
--- a/docs/compilation/python.rst
+++ b/docs/compilation/python.rst
@@ -5,8 +5,8 @@ Python API for Model Compilation
:local:
:depth: 2
-We expose Python API for compiling/building model in the package :py:mod:`mlc_llm`, so
-that users may build model in any directory in their program (i.e. not just
+We expose a Python API for compiling/building models in the package :py:mod:`mlc_llm`, so
+that users may build a model in any directory in their program (i.e. not just
within the mlc-llm repo).
Install MLC-LLM as a Package
@@ -44,7 +44,7 @@ After installing the package, you can build the model using :meth:`mlc_llm.build
which takes in an instance of :class:`BuildArgs` (a dataclass that represents
the arguments for building a model).
-For detailed instruction with code, please refer to `the python notebook
+For detailed instructions with code, please refer to `the Python notebook
`_
(executable in Colab), where we walk you through compiling Llama-2 with :py:mod:`mlc_llm`
in Python.
@@ -56,7 +56,7 @@ API Reference
In order to use the python API :meth:`mlc_llm.build_model`, users need to create
an instance of the dataclass :class:`BuildArgs`. The corresponding arguments in
-command line shown in :ref:`compile-command-specification` are automatically
+the command line shown in :ref:`compile-command-specification` are automatically
converted from the definition of :class:`BuildArgs` and are equivalent.
Then with an instantiated :class:`BuildArgs`, users can call the build API
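+
+As a minimal sketch (the ``model`` and ``quantization`` values below are placeholders;
+consult the :class:`BuildArgs` reference for the full set of fields):
+
+.. code:: python
+
+   from mlc_llm import BuildArgs, build_model
+
+   # Describe the build: which model directory under dist/models to compile,
+   # and which quantization mode to apply.
+   args = BuildArgs(model="Llama-2-7b-chat-hf", quantization="q4f16_1")
+
+   # Run the compilation; artifacts are written under dist/ by default.
+   build_model(args)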
diff --git a/docs/conf.py b/docs/conf.py
index ee42500d51..0f7ed19014 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -7,6 +7,8 @@
# -- General configuration ------------------------------------------------
sys.path.insert(0, os.path.abspath("../python"))
+sys.path.insert(0, os.path.abspath("../"))
+autodoc_mock_imports = ["torch"]
# do not load mlc-llm.so in docs
os.environ["SKIP_LOADING_MLCLLM_SO"] = "1"
@@ -29,9 +31,7 @@
"sphinx_reredirects",
]
-redirects = {
- "get_started/try_out": "../index.html#getting-started"
-}
+redirects = {"get_started/try_out": "../index.html#getting-started"}
source_suffix = [".rst"]
diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst
index c26e9f3445..0c2ed8535f 100644
--- a/docs/deploy/android.rst
+++ b/docs/deploy/android.rst
@@ -98,11 +98,11 @@ To deploy models on Android with reasonable performance, one has to cross-compil
--model ./dist/models/$MODEL_NAME \
--quantization $QUANTIZATION
-This generates directory ``./dist/$MODEL_NAME-$QUANTIZATION`` which contains the necessary components to run the model, as explained below.
+This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION`` which contains the necessary components to run the model, as explained below.
**Expected output format**. By default models are placed under ``./dist/${MODEL_NAME}-${QUANTIZATION}``, and the result consists of 3 major components:
-- Runtime configuration: It configures conversation templates including system prompts, repetition repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` under ``params/`` along side with tokenizer configurations.
+- Runtime configuration: It configures conversation templates including system prompts, repetition penalty, sampling settings such as temperature and top-p probability, maximum sequence length, etc. It is usually named ``mlc-chat-config.json`` and placed under ``params/`` alongside the tokenizer configurations.
- Model lib: The compiled library that uses mobile GPU. It is usually named as ``${MODEL_NAME}-${QUANTIZATION}-android.tar``, for example, ``Llama-2-7b-chat-hf-q4f16_0-android.tar``.
- Model weights: the model weights are sharded as ``params_shard_*.bin`` under ``params/`` and the metadata is stored in ``ndarray-cache.json``.
@@ -144,16 +144,16 @@ The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime
**Build the Android app**. Open folder ``./android/MLCChat`` as an Android Studio Project. Connect your Android device to your machine. In the menu bar of Android Studio, click "Build → Make Project". Once the build is finished, click "Run → Run 'app'" and you will see the app launched on your phone.
.. note::
- ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at accelerated speed.
+ ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed.
Incorporate Model Weights
-------------------------
Instructions have been provided to build an Android App with MLC LLM in previous sections, but it requires run-time weight downloading from HuggingFace, as configured in `app-config.json` in previous steps under `model_url`. However, it could be desirable to bundle weights together into the app to avoid downloading over the network. In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development.
-**Generating APK**. Enter Android Studio, click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/MLCChat/app/release/app-release.apk``.
+**Generating APK**. Enter Android Studio, and click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/MLCChat/app/release/app-release.apk``.
-**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device:
+**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In the SDK manager, install `Android SDK Platform-Tools `_. Add the platform-tools path to the environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device:
.. code-block:: bash
diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst
index 79501b113d..2f0951686d 100644
--- a/docs/deploy/cli.rst
+++ b/docs/deploy/cli.rst
@@ -3,7 +3,7 @@
CLI and C++ API
===============
-MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. You may install it from the prebuilt package we provide, or compile it from source.
+MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. You may install it from the prebuilt package we provide, or compile it from source.
.. contents:: Table of Contents
:local:
@@ -25,16 +25,16 @@ To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source
After installation, activating ``mlc-chat-venv`` environment in Conda will give the ``mlc_chat_cli`` command available.
.. note::
- The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from source.
+ The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from source.
.. _mlcchat_build_from_source:
Option 2. Build MLC Runtime from Source
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-We also provid options to build mlc runtime libraries and ``mlc_chat_cli`` from source.
+We also provide options to build mlc runtime libraries and ``mlc_chat_cli`` from source.
This step is useful when you want to directly obtain a version of mlc runtime library
-and the cli. Please click the details below to see the instruction.
+and the CLI. Please click the details below to see the instructions.
.. collapse:: Details
@@ -63,7 +63,7 @@ and the cli. Please click the details below to see the instruction.
conda activate mlc-chat-venv
.. note::
- :doc:`TVM Unity ` compiler is not a dependency to MLCChat CLI. Only its runtime is required, which is automatically included in `3rdparty/tvm `_.
+ :doc:`TVM Unity ` compiler is not a dependency of MLCChat CLI. Only its runtime is required, which is automatically included in `3rdparty/tvm `_.
**Step 2. Configure and build.** A standard git-based workflow is recommended to download MLC LLM, after which you can specify build requirements with our lightweight config generation tool:
@@ -96,7 +96,7 @@ and the cli. Please click the details below to see the instruction.
Run Models through MLCChat CLI
------------------------------
-Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model on command line.
+Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model on the command line.
**Ensure Model Exists.** As the input to ``mlc_chat_cli``, it is always good to double check if the compiled model exists.
@@ -111,6 +111,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o
- Model lib should be placed at ``./dist/prebuilt/lib/$(local_id)-$(arch).$(suffix)``.
- Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(local_id)/``.
+ .. note::
+    Please make sure that you have the same directory structure as above, because the CLI tool
+    relies on it to automatically search for the model lib and weights. If you would like to
+    provide a full model lib path directly to override the auto-search, you can pass a
+    ``--model-lib-path`` argument to the CLI.
+
.. collapse:: Example
.. code:: shell
@@ -134,6 +140,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o
- Model libraries should be placed at ``./dist/$(local_id)/$(local_id)-$(arch).$(suffix)``.
- Model weights and chat config are located under ``./dist/$(local_id)/params/``.
+ .. note::
+    Please make sure that you have the same directory structure as above, because the CLI tool
+    relies on it to automatically search for the model lib and weights. If you would like to
+    provide a full model lib path directly to override the auto-search, you can pass a
+    ``--model-lib-path`` argument to the CLI.
+
.. collapse:: Example
.. code:: shell
diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst
index 597f594bfb..b6e8e7b55a 100644
--- a/docs/deploy/ios.rst
+++ b/docs/deploy/ios.rst
@@ -7,9 +7,9 @@ iOS App and Swift API
:local:
:depth: 2
-The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from source.
+The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from source.
If you are an iOS user looking to try out the models, the pre-built package is recommended. If you are a
-developer seeking to integrate new features into the package, building the iOS package from source is required.
+developer seeking to integrate new features into the package, building the iOS package from source is required.
Use Pre-built iOS App
---------------------
@@ -23,7 +23,7 @@ The MLC Chat app is now available in App Store at no cost. You can download and
Build iOS App from Source
-------------------------
-This section shows how we can build the app from source.
+This section shows how we can build the app from source.
Step 1. Install Build Dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -134,7 +134,7 @@ Ensure that all the necessary dependencies and configurations are
correctly set up in the Xcode project.
Once you have made the necessary changes, build the iOS app using Xcode.
-If you have an Apple Silicon Mac, you can select target "My Mac (designed for ipad)"
+If you have an Apple Silicon Mac, you can select target "My Mac (designed for iPad)"
to run on your Mac. You can also directly run it on your iPad or iPhone.
.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/xcode-build.jpg
@@ -163,7 +163,7 @@ controls the list of model URLs and model libs to be packaged into the app.
Additionally, the app prepackages the models under ``./ios/dist``.
This built-in list can be controlled by editing ``prepare_params.sh``.
-You can package new prebuilt models or compiled models by changing the above fields and then repeat the steps above.
+You can package new prebuilt models or compiled models by changing the above fields and then repeating the steps above.
Build Apps with MLC Swift API
@@ -193,8 +193,8 @@ your own app. The package is located under `ios/MLCSwift`.
-ltokenizers_c
-You can then can import the `MLCSwift` package in your app.
-The following code shows an illustrative example about how to use the chat module.
+You can then import the `MLCSwift` package into your app.
+The following code shows an illustrative example of how to use the chat module.
.. code:: swift
@@ -221,7 +221,7 @@ The following code shows an illustrative example about how to use the chat modul
Because the chat module makes heavy use of GPU and thread-local
resources, it needs to run on a dedicated background thread.
Therefore, **avoid using** `DispatchQueue`, which can cause context switching to
- different threads and segfaults due to thread-safety issue.
+ different threads and segfaults due to thread-safety issues.
Use the `ThreadWorker` class to launch all the jobs related
to the chat module. You can check out the source code of
the MLCChat app for a complete example.
diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst
index cdbd4cc79e..08dd2cde26 100644
--- a/docs/deploy/javascript.rst
+++ b/docs/deploy/javascript.rst
@@ -6,9 +6,9 @@ WebLLM and Javascript API
:depth: 2
WebLLM is a MLC chat web runtime (`WebLLM `_)
-that allows you to build chat applications directly in browser.
+that allows you to build chat applications directly in the browser.
-Try out Prebuilt Webpage
+Try out the Prebuilt Webpage
------------------------
To get started, you can try out `WebLLM prebuilt webpage `__.
diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst
index ccdfec743d..3df5a08241 100644
--- a/docs/deploy/python.rst
+++ b/docs/deploy/python.rst
@@ -11,8 +11,7 @@ We also provide a web demo based on `gradio `_ as an exampl
Python API
----------
-The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels and you can install it by
-following the instructions in ``_.
+The Python API is a part of the MLC-Chat package, for which we have prepared pre-built pip wheels; see the :doc:`installation page <../install/mlc_llm>`.
Verify Installation
^^^^^^^^^^^^^^^^^^^
@@ -29,7 +28,7 @@ that supports other GPU runtime than the prebuilt version. Please refer our :ref
Get Started
^^^^^^^^^^^
After confirming that the package ``mlc_chat`` is installed, we can follow the steps
-below to chat with a MLC-compiled model in Python.
+below to chat with an MLC-compiled model in Python.
First, let us make sure that the MLC-compiled ``model`` we want to chat with already exists.
@@ -52,6 +51,11 @@ If you do not have the MLC-compiled ``model`` ready:
- Model lib should be placed at ``./dist/prebuilt/lib/$(model)-$(arch).$(suffix)``.
- Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(model)/``.
+ .. note::
+    Please make sure that you have the same directory structure as above, because the Python API
+    relies on it to automatically search for the model lib and weights. If you would like to
+    provide a full model lib path directly to override the auto-search, you can specify ``ChatModule.model_lib_path``.
+
.. collapse:: Example
.. code:: shell
@@ -75,6 +79,11 @@ If you do not have the MLC-compiled ``model`` ready:
- Model libraries should be placed at ``./dist/$(model)/$(model)-$(arch).$(suffix)``.
- Model weights and chat config are located under ``./dist/$(model)/params/``.
+ .. note::
+    Please make sure that you have the same directory structure as above, because the Python API
+    relies on it to automatically search for the model lib and weights. If you would like to
+    provide a full model lib path directly to override the auto-search, you can specify ``ChatModule.model_lib_path``.
+
.. collapse:: Example
.. code:: shell
@@ -90,7 +99,7 @@ If you do not have the MLC-compiled ``model`` ready:
params_shard_*.bin
...
-After making sure that the files exist, using the conda environment you used
+After making sure that the files exist, in the conda environment you used
to install ``mlc_chat``, from the ``mlc-llm`` directory, you can create a Python
file ``sample_mlc_chat.py`` and paste the following lines:
@@ -158,7 +167,7 @@ You can also checkout the :doc:`/prebuilt_models` page to run other models.
|
.. note::
- You could also specify the address of ``model`` and ``lib_path`` explicitly. If
+    You could also specify ``model`` and ``model_lib_path`` explicitly. If
you only specify ``model`` as ``model_name`` and ``quantize_mode``, we will
do a search for you. See more in the documentation of :meth:`mlc_chat.ChatModule.__init__`.
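+
+    A minimal sketch of the explicit form (the paths below are placeholders for
+    wherever your compiled weights and model library actually live):
+
+    .. code:: python
+
+       from mlc_chat import ChatModule
+
+       # Point directly at the weights/chat-config directory and the compiled
+       # model library, bypassing the automatic directory search.
+       cm = ChatModule(
+           model="dist/Llama-2-7b-chat-hf-q4f16_1/params",
+           model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
+       )
+       print(cm.generate("What is the meaning of life?"))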
@@ -244,7 +253,7 @@ We provide an example below.
fields specified.
It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template
- specified by the field ``conv_template`` in chat configuration. Learn more about it in
+ specified by the field ``conv_template`` in the chat configuration. Learn more about it in
:ref:`Configure MLCChat in JSON`.
Raw Text Generation in Python
@@ -263,7 +272,7 @@ We provide an example below.
# Use a `ConvConfig` to define the generation settings
# Since the "LM" template only supports raw text generation,
- # system prompts will not be executed even if provided
+    # system prompts will not be executed even if provided
conv_config = ConvConfig(stop_tokens=[2,], add_bos=True, stop_str="[INST]")
# Note that `conv_config` is an optional subfield of `chat_config`
@@ -297,6 +306,38 @@ We provide an example below.
Additionally, system prompts will not be run when instantiating a `mlc_chat.ChatModule`,
unless explicitly given inside the prompt.
+Stream Iterator in Python
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Stream Iterator gives users the option to stream generated text back to the caller
+instead of streaming it to stdout, which can be necessary when building services on top of MLC Chat.
+
+We provide an example below.
+
+.. code:: python
+
+ from mlc_chat import ChatModule
+ from mlc_chat.callback import StreamIterator
+
+ # Create a ChatModule instance
+ cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
+
+ # Stream to an Iterator
+ from threading import Thread
+
+ stream = StreamIterator(callback_interval=2)
+ generation_thread = Thread(
+ target=cm.generate,
+ kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream},
+ )
+ generation_thread.start()
+
+ output = ""
+ for delta_message in stream:
+ output += delta_message
+
+ generation_thread.join()
+
API Reference
-------------
@@ -314,10 +355,19 @@ The :class:`mlc_chat.ChatModule` class provides the following methods:
.. automethod:: __init__
+.. autoclass:: ChatConfig
+ :members:
+
+.. autoclass:: ConvConfig
+ :members:
+
+.. autoclass:: GenerationConfig
+ :members:
+
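+A per-call generation config can be passed to text generation. The snippet below is a
+minimal sketch: it assumes ``GenerationConfig`` is importable from ``mlc_chat``, that
+``generate`` accepts a ``generation_config`` argument, and the field names used
+(``temperature``, ``top_p``, ``max_gen_len``) should be checked against the
+:class:`GenerationConfig` reference above.
+
+.. code:: python
+
+   from mlc_chat import ChatModule, GenerationConfig
+
+   cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
+
+   # Override sampling settings for this call only; the on-disk chat config is untouched.
+   gen_config = GenerationConfig(temperature=0.7, top_p=0.95, max_gen_len=128)
+   print(cm.generate("Tell me a short story.", generation_config=gen_config))
+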
Gradio Frontend
---------------
-The gradio frontend provides a web interface for the MLC-Chat model, which allows user to interact with the model in a more user-friendly way and switch between different models to compare performance.
+The gradio frontend provides a web interface for the MLC-Chat model, which allows users to interact with the model in a more user-friendly way and switch between different models to compare performance.
To use gradio frontend, you need to install gradio first:
.. code-block:: bash
@@ -335,7 +385,7 @@ Then you can run the following code to start the interface:
--port The port number to run gradio. The default value is ``7860``.
--share Whether to create a publicly shareable link for the interface.
-After setting up properly, you are expected to see the following interface in your browser:
+After setting it up properly, you are expected to see the following interface in your browser:
.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/gradio-interface.png
:width: 100%
diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst
index d5955190e9..8451624fdb 100644
--- a/docs/deploy/rest.rst
+++ b/docs/deploy/rest.rst
@@ -6,13 +6,12 @@ Rest API
:depth: 2
We provide `REST API `_
-for user to interact with MLC-Chat in their own programs.
+for users to interact with MLC-Chat in their own programs.
Install MLC-Chat Package
------------------------
-The REST API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels and you can install it by
-following the instructions in ``_.
+The REST API is a part of the MLC-Chat package, for which we have prepared pre-built :doc:`pip wheels <../install/mlc_llm>`.
Verify Installation
^^^^^^^^^^^^^^^^^^^
@@ -34,7 +33,7 @@ of mlc chat runtime. You only need to do this if you choose not to use the prebu
First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`).
You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-chat-nightly`.
-Then please follow the instruction in :ref:`mlcchat_build_from_source` to build the necessary libraries.
+Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries.
You can now use ``mlc_chat`` package by including the `python` directory to ``PYTHONPATH`` environment variable.
@@ -89,6 +88,9 @@ The REST API provides the following endpoints:
Get the latest runtime stats (encode/decode speed).
+.. http:get:: /verbose_stats
+
+ Get the verbose runtime stats (encode/decode speed, total runtime).
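+
+   As a quick sketch of calling it from Python (the host and port are assumptions and
+   should match wherever you launched the REST server):
+
+   .. code:: python
+
+      import requests
+
+      # Assumes the REST server is running locally on port 8000.
+      resp = requests.get("http://127.0.0.1:8000/verbose_stats")
+      print(resp.text)
+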
Use REST API in your own program
--------------------------------
diff --git a/docs/index.rst b/docs/index.rst
index 89be4d4161..345b5d9603 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -11,13 +11,13 @@ Getting Started
---------------
To begin with, try out MLC LLM support for int4-quantized Llama2 7B.
-It is recommended to have at least 4.5GB of free VRAM to run it.
+It is recommended to have at least 6GB of free VRAM to run it.
.. tabs::
.. tab:: Python
- **Install MLC Chat**. `MLC Chat `_ is available via pip.
+ **Install MLC Chat Python**. :doc:`MLC LLM ` is available via pip.
It is always recommended to install it in an isolated conda virtual environment.
**Download pre-quantized weights**. The commands below download the int4-quantized Llama2-7B from HuggingFace:
@@ -47,6 +47,8 @@ It is recommended to have at least 4.5GB of free VRAM to run it.
**Colab walkthrough.** A Jupyter notebook on `Colab `_
is provided with detailed walkthrough of the Python API.
+ **Documentation and tutorial.** Python API reference and its tutorials are `available online `_.
+
.. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-api.jpg
:width: 600
:align: center
@@ -209,6 +211,7 @@ It is recommended to have at least 4.5GB of free VRAM to run it.
:hidden:
install/tvm.rst
+ install/mlc_llm.rst
install/conda.rst
install/gpu.rst
install/emcc.rst
diff --git a/docs/install/gpu.rst b/docs/install/gpu.rst
index 48ac7a5e1f..608c238265 100644
--- a/docs/install/gpu.rst
+++ b/docs/install/gpu.rst
@@ -105,7 +105,7 @@ After installation, you can run ``vulkaninfo`` in command line and see if you ca
Vulkan SDK
----------
-Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our `pre-built wheels `__ already ships with Vulkan SDK.
+Vulkan SDK is required for compiling models to the Vulkan backend. To build the TVM Unity compiler from source, you will need to install the Vulkan SDK as a dependency, but our :doc:`pre-built wheels <../install/mlc_llm>` already ship with the Vulkan SDK.
Check Vulkan SDK installation guide according to your platform:
diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst
new file mode 100644
index 0000000000..13fc373dbf
--- /dev/null
+++ b/docs/install/mlc_llm.rst
@@ -0,0 +1,133 @@
+.. _install-mlc-packages:
+
+Install MLC LLM Python Package
+==============================
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+
+The MLC LLM Python package can be installed directly from a prebuilt developer package, or built from source.
+
+Option 1. Prebuilt Package
+--------------------------
+
+We provide nightly prebuilt pip wheels for MLC-LLM.
+Select your operating system/compute platform and run the command in your terminal:
+
+.. note::
+ ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts.
+
+.. tabs::
+
+ .. tab:: Linux
+
+ .. tabs::
+
+ .. tab:: CPU
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly
+
+ .. tab:: CUDA 11.7
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117
+
+ .. tab:: CUDA 11.8
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118
+
+ .. tab:: CUDA 12.1
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121
+
+ .. tab:: CUDA 12.2
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122
+
+ .. tab:: ROCm 5.6
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56
+
+ .. tab:: ROCm 5.7
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57
+
+ .. tab:: Vulkan
+
+ Supported in all Linux packages.
+
+ .. note::
+
+ If encountering issues with GLIBC not found, please install the latest glibc in conda:
+
+ .. code-block:: bash
+
+ conda install -c conda-forge libgcc-ng
+
+ .. tab:: macOS
+
+ .. tabs::
+
+ .. tab:: CPU + Metal
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly
+
+ .. note::
+
+ Always check if conda is installed properly in macOS using the command below:
+
+ .. code-block:: bash
+
+ conda info | grep platform
+
+ It should return "osx-64" for a Mac with an Intel chip, and "osx-arm64" for a Mac with an Apple chip.
+
+ .. tab:: Windows
+
+ .. tabs::
+
+ .. tab:: CPU + Vulkan
+
+ .. code-block:: bash
+
+ conda activate your-environment
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly
+
+ .. note::
+ If encountering the error below:
+
+ .. code-block:: bash
+
+ FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax.
+
+ It is likely that `zstd`, a dependency of LLVM, is missing. Please `download `__ the 64-bit version of the precompiled binary, rename it to `zstd.dll`, and copy it to the same folder as `tvm.dll`.
+
+
+Option 2. Build from Source
+---------------------------
+
+Upcoming.
diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst
index 97ec1c9e40..c2b7998ada 100644
--- a/docs/install/tvm.rst
+++ b/docs/install/tvm.rst
@@ -37,49 +37,49 @@ A nightly prebuilt Python package of Apache TVM Unity is provided.
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
.. tab:: CUDA 11.7
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu117
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu117
.. tab:: CUDA 11.8
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu118
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu118
.. tab:: CUDA 12.1
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu121
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121
.. tab:: CUDA 12.2
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu122
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122
.. tab:: ROCm 5.6
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-rocm56
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56
.. tab:: ROCm 5.7
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-rocm57
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57
.. tab:: Vulkan
@@ -97,19 +97,12 @@ A nightly prebuilt Python package of Apache TVM Unity is provided.
.. tabs::
- .. tab:: CPU
-
- .. code-block:: bash
-
- conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly
-
- .. tab:: Metal
+ .. tab:: CPU + Metal
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
.. note::
@@ -125,16 +118,12 @@ A nightly prebuilt Python package of Apache TVM Unity is provided.
.. tabs::
- .. tab:: CPU
+ .. tab:: CPU + Vulkan
.. code-block:: bash
conda activate your-environment
- python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly
-
- .. tab:: Vulkan
-
- Supported in all Windows packages.
+ python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
.. note::
If encountering the error below:
@@ -143,7 +132,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided.
FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax.
- It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`.
+ It is likely that `zstd`, a dependency of LLVM, is missing. Please `download `__ the precompiled binary, rename it to `zstd.dll`, and copy it to the same folder as `tvm.dll`. Hint: to locate the "tvm.dll" file in Conda, navigate to your user home directory (e.g., "/users/xxxx"), search for "tvm.dll", and find the folder whose path contains the name of the current environment, such as "mlc-chat-venv". Once located, copy "zstd.dll" to that folder.
.. _tvm-unity-build-from-source:
diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst
index 99a86780db..8a9f74253b 100644
--- a/docs/prebuilt_models.rst
+++ b/docs/prebuilt_models.rst
@@ -7,168 +7,164 @@ Model Prebuilts
:depth: 3
:local:
-MLC-LLM is a universal solution for deploying different language models. Any language models that can be described in `TVM Relax `__ (a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the help of :doc:`TVM Unity `.
+.. _model-prebuilts-overview:
-The community has already supported several LLM architectures (LLaMA, GPT-NeoX, etc.) and have prebuilt some models (Vicuna, RedPajama, etc.) which you can use off the shelf.
-With the goal of democratizing the deployment of LLMs, we eagerly anticipate further contributions from the community to expand the range of supported model architectures.
+Overview
+--------
-This page contains the list of prebuilt models for our CLI (command line interface) app, iOS and Android apps.
-The models have undergone extensive testing on various devices, and their performance has been optimized by developers with the help of TVM.
+MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__
+(a general representation for neural networks that can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the
+help of :doc:`TVM Unity `.
-.. _prebuilt-models-cli:
+There are two ways to run a model on MLC-LLM:
-Prebuilt Models for CLI
------------------------
+1. Compile your own models following :doc:`the model compilation page `.
+2. Use off-the-shelf prebuilt models following this page.
-.. list-table::
- :widths: 15 15 15 15
- :header-rows: 1
+This page focuses on the second option:
- * - Model code
- - Original Model
- - Quantization Mode
- - Hugging Face repo
- * - `Llama-2-{7, 13, 70}b-chat-hf-q4f16_1`
- - `Llama-2 `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - * `7B link `__
- * `13B link `__
- * `70B link `__
- * - `vicuna-v1-7b-q3f16_0`
- - `Vicuna `__
- - * Weight storage data type: int3
- * Running data type: float16
- * Symmetric quantization
- - `link `__
- * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1`
- - `RedPajama `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - `link `__
- * - `rwkv-raven-{1b5, 3b, 7b}-q8f16_0`
- - `RWKV `__
- - * Weight storage data type: uint8
- * Running data type: float16
- * Symmetric quantization
- - * `1b5 link `__
- * `3b link `__
- * `7b link `__
- * - `WizardLM-13B-V1.2-{q4f16_1, q4f32_1}`
- - `WizardLM `__
- - * Weight storage data type: int4
- * Running data type: float{16, 32}
- * Symmetric quantization
- - * `q4f16_1 link `__
- * `q4f32_1 link `__
- * - `WizardCoder-15B-V1.0-{q4f16_1, q4f32_1}`
- - `WizardCoder `__
- - * Weight storage data type: int4
- * Running data type: float{16, 32}
- * Symmetric quantization
- - * `q4f16_1 link `__
- * `q4f32_1 link `__
- * - `WizardMath-{7, 13, 70}B-V1.0-q4f16_1`
- - `WizardMath `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - * `7B link `__
- * `13B link `__
- * `70B link `__
- * - `llama2-7b-chat-uncensored-{q4f16_1, q4f32_1}`
- - `georgesung `__
- - * Weight storage data type: int4
- * Running data type: float{16, 32}
- * Symmetric quantization
- - * `q4f16_1 link `__
- * `q4f32_1 link `__
- * - `Llama2-Chinese-7b-Chat-{q4f16_1, q4f32_1}`
- - `FlagAlpha `__
- - * Weight storage data type: int4
- * Running data type: float{16, 32}
- * Symmetric quantization
- - * `q4f16_1 link `__
- * `q4f32_1 link `__
- * - `GOAT-7B-Community-{q4f16_1, q4f32_1}`
- - `GOAT-AI `__
- - * Weight storage data type: int4
- * Running data type: float{16, 32}
- * Symmetric quantization
- - * `q4f16_1 link `__
- * `q4f32_1 link `__
- * - `OpenOrca-Platypus2-13B-q4f16_1`
- - `Llama-2 `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - `link `__
-
-To download and run one model with CLI, follow the instructions below:
-
-.. code:: shell
-
- # Create conda environment and install CLI if you have not installed.
- conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly
- conda activate mlc-chat-venv
- conda install git git-lfs
- git lfs install
-
- # Download prebuilt model binary libraries from GitHub if you have not downloaded.
- mkdir -p dist/prebuilt
- git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib
-
- # Download prebuilt model weights and run CLI.
- cd dist/prebuilt
- git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code]
- cd ../..
- mlc_chat_cli --model [model-code]
-
- # e.g.,
- # cd dist/prebuilt
- # git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0
- # cd ../..
- # mlc_chat_cli --model rwkv-raven-7b-q8f16_0
-
-
-.. _prebuilt-models-ios:
-
-Prebuilt Models for iOS
------------------------
-
-.. list-table:: Prebuilt models for iOS
- :widths: 15 15 15 15
- :header-rows: 1
+- Documenting :ref:`how to use prebuilts ` for various platforms, and
+- Tracking the current :ref:`prebuilt models we provide `.
+
+Prerequisite: Model Libraries and Compiled Weights
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to run a specific model on MLC-LLM, you need:
+
+**1. A model library:** a binary file containing the end-to-end functionality to run inference on a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). See the full list of all precompiled model libraries `here `__.
+
+**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model (e.g. https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1). See the full list of all precompiled weights `here `__.
+
+.. _using-model-prebuilts:
+
+Using Prebuilt Models for Different Platforms
+---------------------------------------------
+
+We quickly go over how to use prebuilt models for each platform. You can find detailed instructions on each platform's corresponding page.
+
+.. _using-prebuilt-models-cli:
+
+
+Prebuilt Models on CLI / Python
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For more, please see :doc:`the CLI page ` and :doc:`the Python page `.
+
+.. collapse:: Click to show details
+
+ First create the conda environment if you have not done so.
+
+ .. code:: shell
+
+ conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly
+ conda activate mlc-chat-venv
+ conda install git git-lfs
+ git lfs install
+
+ Download the prebuilt model libraries from GitHub.
+
+ .. code:: shell
+
+ mkdir -p dist/prebuilt
+ git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib
+
+ Download the prebuilt model weights from Hugging Face for the model variant you want.
+
+ .. code:: shell
+
+ # Say we want to run rwkv-raven-7b-q8f16_0
+ cd dist/prebuilt
+ git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0
+ cd ../..
+
+ # The format being:
+ # cd dist/prebuilt
+ # git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code]
+ # cd ../..
+ # mlc_chat_cli --model [model-code]
+
+ Run the model with CLI:
- * - Model code
- - Model Series
- - Quantization Mode
- - Hugging Face repo
- * - `Llama-2-7b-q3f16_1`
- - `Llama `__
- - * Weight storage data type: int3
- * Running data type: float16
- * Symmetric quantization
- - `link `__
- * - `vicuna-v1-7b-q3f16_0`
- - `Vicuna `__
- - * Weight storage data type: int3
- * Running data type: float16
- * Symmetric quantization
- - `link `__
- * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1`
- - `RedPajama `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - `link `__
-
-The `downloadable iOS app `_ has builtin RedPajama-3B model support.
-To add a model to the iOS app, follow the steps below:
-
-.. collapse:: Click to show instructions
+ .. code:: shell
+
+ # For CLI
+ mlc_chat_cli --model rwkv-raven-7b-q8f16_0
+
+ To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI).
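+
+   As a minimal sketch with the Python API, the same model code can be passed directly
+   (this assumes the ``dist/prebuilt`` layout created above, which the auto-search relies on):
+
+   .. code:: python
+
+      from mlc_chat import ChatModule
+
+      # Reuses the prebuilt weights and libraries downloaded into dist/prebuilt.
+      cm = ChatModule(model="rwkv-raven-7b-q8f16_0")
+      print(cm.generate("Hello!"))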
+
+
+.. for a blank line
+
+|
+
+.. _using-prebuilt-models-ios:
+
+Prebuilt Models on iOS
+^^^^^^^^^^^^^^^^^^^^^^
+
+For more, please see :doc:`the iOS page `.
+
+.. collapse:: Click to show details
+
+ The `iOS app `_ has built-in RedPajama-3B and Llama-2-7b support.
+
+ All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have:
+
+ .. list-table:: Prebuilt model libraries integrated in the iOS app
+ :widths: 15 15 15
+ :header-rows: 1
+
+ * - Model library name
+ - Model Family
+ - Quantization Mode
+ * - `Llama-2-7b-chat-hf-q3f16_1`
+ - LLaMA
+ - * Weight storage data type: int3
+ * Running data type: float16
+ * Symmetric quantization
+ * - `vicuna-v1-7b-q3f16_0`
+ - LLaMA
+ - * Weight storage data type: int3
+ * Running data type: float16
+ * Symmetric quantization
+ * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1`
+ - GPT-NeoX
+ - * Weight storage data type: int4
+ * Running data type: float16
+ * Symmetric quantization
+
+ As for prebuilt model weights, the ones we have integrated into the app are listed below:
+
+ .. list-table:: Tested prebuilt model weights for iOS
+ :widths: 15 15 15 15
+ :header-rows: 1
+
+ * - Model code
+ - Model Series
+ - Quantization Mode
+ - Hugging Face repo
+ * - `Llama-2-7b-q3f16_1`
+ - `Llama `__
+ - * Weight storage data type: int3
+ * Running data type: float16
+ * Symmetric quantization
+ - `link `__
+ * - `vicuna-v1-7b-q3f16_0`
+ - `Vicuna `__
+ - * Weight storage data type: int3
+ * Running data type: float16
+ * Symmetric quantization
+ - `link `__
+ * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1`
+ - `RedPajama `__
+ - * Weight storage data type: int4
+ * Running data type: float16
+ * Symmetric quantization
+ - `link `__
+
+ To run a model variant you compiled on your own, you can directly reuse the above integrated prebuilt model libraries, as long as the model shares the architecture and is compiled with the same quantization mode. For example, if you compile `OpenLLaMA-7B `_ with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. You can then upload the compiled weights to Hugging Face so that you can download the weights in the app as shown below (for more on uploading to Hugging Face, please check the :doc:`model distribution page `).
+
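+ If you choose Hugging Face, one possible upload flow is sketched below using the ``huggingface_hub`` package; the repo name and local path are placeholders, and the model distribution page referenced above remains the authoritative guide.
+
+ .. code:: python
+
+    # Hypothetical sketch: publish locally compiled weights (e.g., OpenLLaMA
+    # compiled with q3f16_0) to a Hugging Face repo so the iOS app can download
+    # them while reusing the vicuna-v1-7b-q3f16_0 model library.
+    # Assumes you are already authenticated (e.g., via `huggingface-cli login`).
+    from huggingface_hub import HfApi
+
+    api = HfApi()
+    repo_id = "my-username/mlc-chat-open-llama-7b-q3f16_0"  # placeholder repo name
+    api.create_repo(repo_id=repo_id, exist_ok=True)
+    api.upload_folder(
+        folder_path="dist/open-llama-7b-q3f16_0/params",  # placeholder path to compiled weights
+        repo_id=repo_id,
+    )
+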
+ To add a model to the iOS app, follow the steps below:
+
.. tabs::
@@ -210,126 +206,623 @@ To add a model to the iOS app, follow the steps below:
|
-The iOS app has integrated with the following model libraries, which can be directly reused when you want to run a model you compiled in iOS, as long as the model is in the supported model family and is compiled with supported quantization mode.
-For example, if you compile `OpenLLaMA-7B `_ with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. Please check the :doc:`model distribution page ` for detailed instructions.
-
-.. list-table:: Prebuilt model libraries which are integrated in the iOS app
- :widths: 15 15 15
- :header-rows: 1
-
- * - Model library name
- - Model Family
- - Quantization Mode
- * - `Llama-2-7b-chat-hf-q3f16_1`
- - LLaMA
- - * Weight storage data type: int3
- * Running data type: float16
- * Symmetric quantization
- * - `vicuna-v1-7b-q3f16_0`
- - LLaMA
- - * Weight storage data type: int3
- * Running data type: float16
- * Symmetric quantization
- * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1`
- - GPT-NeoX
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
-
-
.. _prebuilt-models-android:
-Prebuilt Models for Android
----------------------------
-
-.. list-table:: Prebuilt models for Android
- :widths: 15 15 15 15
- :header-rows: 1
+Prebuilt Models on Android
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For more, please see :doc:`the Android page `.
+
+.. collapse:: Click to show details
+
+ The APK for the demo Android app includes the following models. To add more, check out the Android page.
+
+ .. list-table:: Prebuilt Models for Android
+ :widths: 15 15 15 15
+ :header-rows: 1
+
+ * - Model code
+ - Model Series
+ - Quantization Mode
+ - Hugging Face repo
+ * - `Llama-2-7b-q4f16_1`
+ - `Llama `__
+ - * Weight storage data type: int4
+ * Running data type: float16
+ * Symmetric quantization
+ - `link `__
+ * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1`
+ - `RedPajama `__
+ - * Weight storage data type: int4
+ * Running data type: float16
+ * Symmetric quantization
+ - `link `__
+
+.. for a blank line
- * - Model code
- - Model Series
- - Quantization Mode
- - Hugging Face repo
- * - `vicuna-v1-7b-q4f16_1`
- - `Vicuna `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - `link `__
- * - `RedPajama-INCITE-Chat-3B-v1-q4f16_0`
- - `RedPajama `__
- - * Weight storage data type: int4
- * Running data type: float16
- * Symmetric quantization
- - `link `__
+|
-------------------
+.. _supported-model-architectures:
-You can check `MLC-LLM pull requests `__ to track the ongoing efforts of new models. We encourage users to upload their compiled models to Hugging Face and share with the community.
+Level 1: Supported Model Architectures (The All-In-One Table)
+-------------------------------------------------------------
-.. _supported-model-architectures:
+For each model architecture (e.g., Llama), there are multiple variants (e.g., CodeLlama, WizardLM). The variants share the same code for inference and differ only in their weights. In other words, running CodeLlama and WizardLM uses the same model library file (specified in the Level 2 tables) but different precompiled weights (specified in the Level 3 tables). Note that we have not provided prebuilt weights for all model variants.
-Supported Model Architectures
------------------------------
+Each entry below hyperlinks to the corresponding Level 2 and Level 3 tables.
MLC-LLM supports the following model architectures:
.. list-table:: Supported Model Architectures
- :widths: 15 15 15 15
+ :widths: 10 10 15 15
:header-rows: 1
- * - Category Code
- - Series
- - Model Definition
- - Variants
- * - ``llama``
- - `LLaMa `__
- - `Relax Code `__
- - * `Llama-2 `__
- * `Alpaca `__
- * `Vicuna `__
+ * - Model Architecture
+ - Support
+ - Available MLC Prebuilts
+ - Unavailable in MLC Prebuilts
+ * - `LLaMA `__
+ - * :ref:`Prebuilt Model Library `
+ * `MLC Implementation `__
+ - * :ref:`Llama-2 `
+ * :ref:`Code Llama `
+ * :ref:`Vicuna `
+ * :ref:`WizardLM `
+ * :ref:`WizardMath `
+ * :ref:`OpenOrca Platypus2 `
+ * :ref:`FlagAlpha Llama-2 Chinese `
+ * :ref:`georgesung Llama-2 Uncensored `
+ - * `Alpaca `__
* `Guanaco `__
* `OpenLLaMA `__
* `Gorilla `__
- * `WizardLM `__
* `YuLan-Chat `__
- * `WizardMath `__
- * `FlagAlpha Llama-2 Chinese `__
- * - ``gpt-neox``
- - `GPT-NeoX `__
- - `Relax Code `__
- - * `RedPajama `__
- * `Dolly `__
+ * `WizardCoder (new) `__
+ * - `GPT-NeoX `__
+ - * :ref:`Prebuilt Model Library `
+ * `MLC Implementation `__
+ - * :ref:`RedPajama `
+ - * `Dolly `__
* `Pythia `__
* `StableCode `__
- * - ``gptj``
- - `GPT-J `__
- - `Relax Code `__
+ * - `GPT-J `__
+ - * Prebuilt not compiled yet
+ * `MLC Implementation `__
+ -
- * `MOSS `__
- * - ``rwkv``
- - `RWKV `__
- - `Relax Code `__
- - * `RWKV-raven `__
- * - ``minigpt``
- - `MiniGPT `__
- - `Relax Code `__
- -
- * - ``gpt_bigcode``
- - `GPTBigCode `__
- - `Relax Code `__
+ * - `RWKV `__
+ - * :ref:`Prebuilt Model Library `
+ * `MLC Implementation `__
+ - * :ref:`RWKV-raven `
+ -
+ * - `MiniGPT `__
+ - * Prebuilt not compiled yet
+ * `MLC Implementation `__
+ -
+ - * `MiniGPT-4 `__
+ * - `GPTBigCode `__
+ - * :ref:`Prebuilt Model Library `
+ * `MLC Implementation `__
+ - * :ref:`WizardCoder (old) `
- * `StarCoder `__
- * `WizardCoder `__
* `SantaCoder `__
- * - ``chatglm``
- - `ChatGLM `__
- - `Relax Code `__
+ * - `ChatGLM `__
+ - * Prebuilt not compiled yet
+ * `MLC Implementation `__
+ -
- * `ChatGLM2 `__
* `CodeGeeX2 `__
+ * - `StableLM `__
+ - * Prebuilt not compiled yet
+ * `MLC Implementation