Skip to content

Commit

Permalink
Isolate model fetching in a separate process (octoml#227)
Browse files Browse the repository at this point in the history
* Isolate model fetching in a separate process

* update target detection logic

* update metal compilation callback
  • Loading branch information
junrushao authored May 24, 2023
1 parent eed5a28 commit 5aef0dd
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 116 deletions.
3 changes: 2 additions & 1 deletion build.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def main():
cache_path = os.path.join(
ARGS.artifact_path, f"mod_cache_before_build_{ARGS.target_kind}.pkl"
)
ARGS.raw_params_path = os.path.join(ARGS.artifact_path, "raw_params")
use_cache = ARGS.use_cache and os.path.isfile(cache_path)
with open(os.path.join(ARGS.model_path, "config.json"), encoding="utf-8") as i_f:
config = json.load(i_f)
Expand Down Expand Up @@ -380,7 +381,7 @@ def main():
if not ARGS.reuse_lib:
build(mod, ARGS)
else:
print("Reuse existing preuilt lib {ARGS.reuse_lib}...")
print("Reuse existing prebuilt lib {ARGS.reuse_lib}...")
dump_default_mlc_chat_config(ARGS)


Expand Down
12 changes: 5 additions & 7 deletions mlc_llm/relax_model/gpt_neox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# pylint: disable=missing-docstring,too-few-public-methods,too-many-instance-attributes,invalid-name,too-many-locals,too-many-arguments
import argparse
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import tvm
Expand All @@ -21,6 +20,7 @@
from tvm.runtime import NDArray
from tvm.script import relax as R

from .. import transformers
from .commons import create_metadata_func
from .modules import (
Embedding,
Expand Down Expand Up @@ -595,8 +595,6 @@ def get_model(
args: argparse.Namespace,
hf_config,
):
from transformers import AutoModelForCausalLM # type: ignore[import]

model = args.model
dtype = args.quantization.model_dtype
ffn_out_dtype = "float32"
Expand All @@ -623,9 +621,10 @@ def get_model(
hidden_size = config.hidden_size
head_dim = hidden_size // num_heads
param_list: List[Tuple[str, NDArray]] = []
hf_model = AutoModelForCausalLM.from_pretrained(args.model_path)
for name, param in hf_model.named_parameters():
param = param.detach().cpu().numpy()
for name, param in transformers.get_model(
args.model_path,
args.raw_params_path,
):
if param.dtype == "float32":
if "layernorm" in name or "layer_norm" in name or "embed_out" in name:
param = param.astype("float32")
Expand Down Expand Up @@ -655,7 +654,6 @@ def get_model(
param_list.append((name.format("v"), v))
else:
param_list.append((name, param))
del hf_model
param_list = [
(
name,
Expand Down
69 changes: 69 additions & 0 deletions mlc_llm/transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# pylint: disable=import-outside-toplevel
def _get_model_worker(_args) -> None:
    """Dump a HuggingFace causal-LM's parameters to ``.npy`` files on disk.

    Written to be used as a ``multiprocessing`` worker (see ``get_model``),
    which is why the inputs arrive packed in a single tuple.

    Parameters
    ----------
    _args : Tuple[str, str]
        ``(model, dump_path)``. ``model`` is forwarded unchanged to
        ``AutoModelForCausalLM.from_pretrained``; ``dump_path`` is the
        directory that receives ``param_{i}.npy`` (one file per parameter, in
        ``named_parameters()`` order) and ``config.json`` (a JSON list of the
        parameter names, where index ``i`` names ``param_{i}.npy``).

    If ``<dump_path>/config.json`` already exists the function returns
    immediately without loading anything.
    """
    import json
    import os

    model: str
    dump_path: str
    model, dump_path = _args
    config_path = os.path.join(dump_path, "config.json")
    if os.path.exists(config_path):
        print("Model weights already exist under:", dump_path)
        return

    # Deferred until after the cache check: a cache hit should not pay for
    # importing these packages, nor fail when `transformers` is not installed.
    import numpy as np
    from transformers import AutoModelForCausalLM  # type: ignore[import]

    print("Extracting weights for model:", model)
    # NOTE(review): trust_remote_code=True lets the model repository execute
    # its own Python code at load time; only use with trusted model sources.
    hf_model = AutoModelForCausalLM.from_pretrained(
        model,
        trust_remote_code=True,
    )
    params = [
        (name, param.detach().cpu().numpy())
        for name, param in hf_model.named_parameters()
    ]
    del hf_model

    os.makedirs(dump_path, exist_ok=True)
    for i, (_, param) in enumerate(params):
        np.save(os.path.join(dump_path, f"param_{i}.npy"), param)

    # config.json is written last on purpose: its existence is what the check
    # at the top treats as "dump already present", so an interrupted dump
    # (arrays written, no config yet) is redone on the next call.
    with open(config_path, "w", encoding="utf-8") as o_f:
        json.dump([name for name, _ in params], o_f)
    print("Model weights dumped to:", dump_path)


def get_model(model: str, dump_path: str):
    """Return a model's parameters as a list of ``(name, numpy array)`` pairs.

    ``_get_model_worker`` is run in a one-process ``multiprocessing.Pool`` so
    that ``transformers`` is imported and the model is loaded in a child
    process rather than in the caller's process. The worker dumps the weights
    under ``dump_path`` (or leaves an existing dump alone); this function then
    reads them back from disk.

    Parameters
    ----------
    model : str
        Model identifier or path, forwarded to the worker.
    dump_path : str
        Directory holding ``config.json`` (JSON list of parameter names) and
        the matching ``param_{i}.npy`` files.

    Returns
    -------
    List[Tuple[str, np.ndarray]]
        One ``(name, array)`` pair per entry of ``config.json``, in that order.
    """
    import json
    import multiprocessing
    import os
    from typing import List, Tuple

    import numpy as np
    from tqdm import tqdm

    # pool.map blocks until the worker finishes and re-raises any exception it
    # hit; the worker returns None, so there is no result worth keeping.
    with multiprocessing.Pool(processes=1) as pool:
        pool.map(_get_model_worker, [(model, dump_path)])

    print("Loading model weights from:", dump_path)
    config_path = os.path.join(dump_path, "config.json")
    with open(config_path, "r", encoding="utf-8") as i_f:
        names = json.load(i_f)

    params: List[Tuple[str, np.ndarray]] = []
    for i, name in tqdm(enumerate(names), total=len(names)):
        # Index i in config.json corresponds to param_{i}.npy.
        param_path = os.path.join(dump_path, f"param_{i}.npy")
        params.append((name, np.load(param_path)))
    print("Loading done")
    return params
Loading

0 comments on commit 5aef0dd

Please sign in to comment.