From 5aef0dd3a31f191726a4f5a18301d871dce74cf7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 24 May 2023 16:31:14 -0700 Subject: [PATCH] Isolate model fetching in a separate process (#227) * Isolate model fetching in a separate process * update target detection logic * update metal compilation callback --- build.py | 3 +- mlc_llm/relax_model/gpt_neox.py | 12 +- mlc_llm/transformers.py | 69 ++++++++ mlc_llm/utils.py | 280 ++++++++++++++++++++------------ 4 files changed, 248 insertions(+), 116 deletions(-) create mode 100644 mlc_llm/transformers.py diff --git a/build.py b/build.py index 0c2d33f128..271d7c9acd 100644 --- a/build.py +++ b/build.py @@ -352,6 +352,7 @@ def main(): cache_path = os.path.join( ARGS.artifact_path, f"mod_cache_before_build_{ARGS.target_kind}.pkl" ) + ARGS.raw_params_path = os.path.join(ARGS.artifact_path, "raw_params") use_cache = ARGS.use_cache and os.path.isfile(cache_path) with open(os.path.join(ARGS.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) @@ -380,7 +381,7 @@ def main(): if not ARGS.reuse_lib: build(mod, ARGS) else: - print("Reuse existing preuilt lib {ARGS.reuse_lib}...") + print("Reuse existing prebuilt lib {ARGS.reuse_lib}...") dump_default_mlc_chat_config(ARGS) diff --git a/mlc_llm/relax_model/gpt_neox.py b/mlc_llm/relax_model/gpt_neox.py index 15d1089cce..329543c490 100644 --- a/mlc_llm/relax_model/gpt_neox.py +++ b/mlc_llm/relax_model/gpt_neox.py @@ -1,7 +1,6 @@ # pylint: disable=missing-docstring,too-few-public-methods,too-many-instance-attributes,invalid-name,too-many-locals,too-many-arguments import argparse import math -from dataclasses import dataclass from typing import List, Optional, Tuple, Union import tvm @@ -21,6 +20,7 @@ from tvm.runtime import NDArray from tvm.script import relax as R +from .. import transformers from .commons import create_metadata_func from .modules import ( Embedding, @@ -595,8 +595,6 @@ def get_model( args: argparse.Namespace, hf_config, ): - from transformers import AutoModelForCausalLM # type: ignore[import] - model = args.model dtype = args.quantization.model_dtype ffn_out_dtype = "float32" @@ -623,9 +621,10 @@ def get_model( hidden_size = config.hidden_size head_dim = hidden_size // num_heads param_list: List[Tuple[str, NDArray]] = [] - hf_model = AutoModelForCausalLM.from_pretrained(args.model_path) - for name, param in hf_model.named_parameters(): - param = param.detach().cpu().numpy() + for name, param in transformers.get_model( + args.model_path, + args.raw_params_path, + ): if param.dtype == "float32": if "layernorm" in name or "layer_norm" in name or "embed_out" in name: param = param.astype("float32") @@ -655,7 +654,6 @@ def get_model( param_list.append((name.format("v"), v)) else: param_list.append((name, param)) - del hf_model param_list = [ ( name, diff --git a/mlc_llm/transformers.py b/mlc_llm/transformers.py new file mode 100644 index 0000000000..983eba5748 --- /dev/null +++ b/mlc_llm/transformers.py @@ -0,0 +1,69 @@ +# pylint: disable=import-outside-toplevel +def _get_model_worker(_args) -> None: + import json + import os + + import numpy as np + from transformers import AutoModelForCausalLM # type: ignore[import] + + model: str + dump_path: str + model, dump_path = _args + config_path = os.path.join(dump_path, "config.json") + if os.path.exists(config_path): + print("Model weights already exist under:", dump_path) + return + + print("Extracting weights for model:", model) + hf_model = AutoModelForCausalLM.from_pretrained( + model, + trust_remote_code=True, + ) + params = [ + ( + name, + param.detach().cpu().numpy(), + ) + for name, param in hf_model.named_parameters() + ] + del hf_model + + os.makedirs(dump_path, exist_ok=True) + for i, (name, param) in enumerate(params): + param_path = os.path.join(dump_path, f"param_{i}.npy") + np.save(param_path, param) + + with open(config_path, "w", encoding="utf-8") as o_f: + json.dump( + [name for name, _ in params], + o_f, + ) + print("Model weights dumped to:", dump_path) + + +def get_model(model: str, dump_path: str): + import json + import multiprocessing + import os + from typing import List, Tuple + + import numpy as np + from tqdm import tqdm + + with multiprocessing.Pool(processes=1) as pool: + result = pool.map( + _get_model_worker, + [ + (model, dump_path), + ], + ) + print("Loading model weights from:", dump_path) + config_path = os.path.join(dump_path, "config.json") + with open(config_path, "r", encoding="utf-8") as i_f: + config = json.load(i_f) + param_dict: List[Tuple[str, np.ndarray]] = [] + for i, name in tqdm(enumerate(config), total=len(config)): + param_path = os.path.join(dump_path, f"param_{i}.npy") + param_dict.append((name, np.load(param_path))) + print("Loading done") + return param_dict diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index ac09f30526..1e155ab738 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -3,7 +3,6 @@ import os import shutil from dataclasses import dataclass -from platform import system from typing import List, Tuple import tvm @@ -126,13 +125,8 @@ def transform_params( transform_func_name = gv.name_hint assert transform_func_name is not None - if tvm.cuda().exist: - target = "cuda" - elif tvm.metal().exist: - target = "metal" - else: - target = "llvm" - target = tvm.target.Target(target) + target = detect_local_target() + print(f"Automatically using target for weight quantization: {target}") device = tvm.device(target.kind.default_keys[0]) @tvm.register_func("get_item", override=True) @@ -257,63 +251,115 @@ def get_database(db_paths: str) -> ms.Database: return db +def _detect_local_metal(): + dev = tvm.metal() + if not dev.exist: + return None + return tvm.target.Target( + { + "kind": "metal", + "max_shared_memory_per_block": 32768, + "max_threads_per_block": dev.max_threads_per_block, + "thread_warp_size": 32, + }, + host=tvm.target.Target( # TODO: assuming ARM mac for now + { + "kind": "llvm", + "mtriple": "arm64-apple-macos", + "mcpu": "apple-latest", + } + ), + ) + + +def _detect_local_cuda(): + dev = tvm.cuda() + if not dev.exist: + return None + return tvm.target.Target( + { + "kind": "cuda", + "max_shared_memory_per_block": dev.max_shared_memory_per_block, + "max_threads_per_block": dev.max_threads_per_block, + "thread_warp_size": dev.warp_size, + "registers_per_block": 65536, + "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""), + } + ) + + +def _detect_local_vulkan(): + dev = tvm.vulkan() + if not dev.exist: + return None + return tvm.target.Target( + { + "kind": "vulkan", + "max_threads_per_block": dev.max_threads_per_block, + "max_shared_memory_per_block": dev.max_shared_memory_per_block, + "thread_warp_size": dev.warp_size, + "supports_float16": 1, + "supports_int16": 1, + "supports_16bit_buffer": 1, + } + ) + + +def detect_local_target(): + dev = tvm.metal() + if dev.exist: + return tvm.target.Target("apple/m1-gpu") + + for method in [ + _detect_local_metal, + _detect_local_cuda, + _detect_local_vulkan, + ]: + target = method() + if target is not None: + return target + + print("Failed to detect local GPU, falling back to CPU as a target") + return tvm.target.Target("llvm") + + def parse_target(args: argparse.Namespace) -> None: if not hasattr(args, "target"): return if args.target == "auto": - if system() == "Darwin": - target = tvm.target.Target("apple/m1-gpu") - elif tvm.cuda().exist: - dev = tvm.cuda() - target = tvm.target.Target( - { - "kind": "cuda", - "max_shared_memory_per_block": dev.max_shared_memory_per_block, - "max_threads_per_block": dev.max_threads_per_block, - "thread_warp_size": dev.warp_size, - "registers_per_block": 65536, - "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""), - } - ), - elif tvm.vulkan().exist: - dev = tvm.vulkan() + target = detect_local_target() + if target.host is None: target = tvm.target.Target( - { - "kind": "vulkan", - "max_threads_per_block": dev.max_threads_per_block, - "max_shared_memory_per_block": dev.max_shared_memory_per_block, - "thread_warp_size": dev.warp_size, - "supports_float16": 1, - "supports_int16": 1, - "supports_16bit_buffer": 1, - } - ), - else: - has_gpu = tvm.cuda().exist + target, + host="llvm", # TODO: detect host CPU + ) + args.target = target + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "metal": + target = _detect_local_metal() + if target is None: + print("Cannot detect local Apple Metal GPU target! Falling back...") target = tvm.target.Target( - "cuda" # TODO: cuda details are required, for example, max shared memory - if has_gpu - else "llvm" + tvm.target.Target( + { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + } + ), + host=tvm.target.Target( # TODO: assuming ARM mac for now + { + "kind": "llvm", + "mtriple": "arm64-apple-macos", + "mcpu": "apple-latest", + } + ), ) - print(f"Automatically configuring target: {target}") - args.target = tvm.target.Target(target, host="llvm") + args.target = target args.target_kind = args.target.kind.default_keys[0] - elif args.target == "webgpu": - args.target = tvm.target.Target( - "webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm" - ) - args.target_kind = "webgpu" - args.lib_format = "wasm" - args.system_lib = True - elif args.target.startswith("iphone"): - from tvm.contrib import tar, xcode # pylint: disable=import-outside-toplevel - - # override - @tvm.register_func("tvm_callback_metal_compile") - def compile_metal(src): - return xcode.compile_metal(src, sdk="iphoneos") - - dylib = args.target == "iphone-dylib" + elif args.target == "metal_x86_64": + from tvm.contrib import xcode # pylint: disable=import-outside-toplevel args.target = tvm.target.Target( tvm.target.Target( @@ -324,12 +370,19 @@ def compile_metal(src): "thread_warp_size": 1, } ), - host="llvm -mtriple=arm64-apple-darwin", + host="llvm -mtriple=x86_64-apple-darwin", ) - args.target_kind = "iphone" - args.export_kwargs = {"fcompile": tar.tar} + args.target_kind = "metal_x86_64" + args.export_kwargs = { + "fcompile": xcode.create_dylib, + "sdk": "macosx", + "arch": "x86_64", + } + args.lib_format = "dylib" + elif args.target in ["iphone", "iphone-dylib", "iphone-tar"]: + from tvm.contrib import tar, xcode # pylint: disable=import-outside-toplevel - if dylib: + if args.target == "iphone-dylib": args.export_kwargs = { "fcompile": xcode.create_dylib, "sdk": "iphoneos", @@ -337,23 +390,65 @@ def compile_metal(src): } args.lib_format = "dylib" else: + args.export_kwargs = {"fcompile": tar.tar} args.lib_format = "tar" args.system_lib = True - system_lib_prefix = f"{args.model}-{args.quantization}_" - args.system_lib_prefix = system_lib_prefix.replace("-", "_") - - elif args.target.startswith("android"): - # android-opencl - from tvm.contrib import cc, ndk + args.system_lib_prefix = f"{args.model}_{args.quantization}_".replace( + "-", "_" + ) - dylib = args.target == "android-dylib" + @tvm.register_func("tvm_callback_metal_compile") + def compile_metal(src, target): + if target.libs: + return xcode.compile_metal(src, sdk=target.libs[0]) + return xcode.compile_metal(src) + target = tvm.target.Target( + tvm.target.Target( + { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "libs": ["iphoneos"], + } + ), + host="llvm -mtriple=arm64-apple-darwin", + ) + args.target = target + args.target_kind = "iphone" + elif args.target == "vulkan": + target = _detect_local_vulkan() + if target is None: + print("Cannot detect local Vulkan GPU target! Falling back...") + target = tvm.target.Target( + tvm.target.Target( + { + "kind": "vulkan", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "supports_float16": 1, + "supports_int16": 1, + "supports_16bit_buffer": 1, + } + ), + host="llvm", + ) + args.target = target + args.target_kind = args.target.kind.default_keys[0] + elif args.target == "webgpu": args.target = tvm.target.Target( - "opencl", - host="llvm -mtriple=aarch64-linux-android", # Only support arm64 for now + "webgpu", + host="llvm -mtriple=wasm32-unknown-unknown-wasm", ) - args.target_kind = "android" - if dylib: + args.target_kind = "webgpu" + args.lib_format = "wasm" + args.system_lib = True + elif args.target in ["android", "android-dylib"]: # android-opencl + from tvm.contrib import cc, ndk + + if args.target == "android-dylib": args.export_kwargs = { "fcompile": ndk.create_shared, } @@ -364,44 +459,11 @@ def compile_metal(src): } args.lib_format = "a" args.system_lib = True - - elif args.target == "vulkan": args.target = tvm.target.Target( - tvm.target.Target( - { - "kind": "vulkan", - "max_threads_per_block": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - "supports_float16": 1, - "supports_int16": 1, - "supports_16bit_buffer": 1, - } - ), - host="llvm", - ) - args.target_kind = args.target.kind.default_keys[0] - elif args.target == "metal_x86_64": - from tvm.contrib import xcode # pylint: disable=import-outside-toplevel - - args.target = tvm.target.Target( - tvm.target.Target( - { - "kind": "metal", - "max_threads_per_block": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - } - ), - host="llvm -mtriple=x86_64-apple-darwin", + "opencl", + host="llvm -mtriple=aarch64-linux-android", # TODO: Only support arm64 for now ) - args.target_kind = "metal_x86_64" - args.export_kwargs = { - "fcompile": xcode.create_dylib, - "sdk": "macosx", - "arch": "x86_64", - } - args.lib_format = "dylib" + args.target_kind = "android" else: args.target = tvm.target.Target(args.target, host="llvm") args.target_kind = args.target.kind.default_keys[0] @@ -420,3 +482,5 @@ def compile_metal(src): } args.target = args.target.with_host("llvm -mtriple=x86_64-w64-windows-gnu") args.lib_format = "dll" + + print(f"Target configured: {args.target}")