From a11e3f89ae01c7e9d2001e167a1b6ed6820e4ab9 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Sun, 24 Dec 2023 14:13:30 +0100 Subject: [PATCH] shard llama model after conversion and unshard on loading --- llms/llama/convert.py | 24 ++++++++++++++++++++++-- llms/llama/llama.py | 20 ++++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/llms/llama/convert.py b/llms/llama/convert.py index dae337ee6..6f5285c34 100644 --- a/llms/llama/convert.py +++ b/llms/llama/convert.py @@ -15,7 +15,6 @@ from llama import Llama, ModelArgs, sanitize_config from mlx.utils import tree_flatten, tree_map, tree_unflatten - def llama(model_path): SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"] SHARD_SECOND = ["tok_embeddings", "wo", "w2"] @@ -140,6 +139,22 @@ def quantize(weights, config, args): return quantized_weights, quantized_config +def make_shards(weights: dict, max_file_size_gibibyte: int = 15): + max_file_size_bytes = max_file_size_gibibyte << 30 + shards = [] + shard, shard_size = {}, 0 + for k, v in weights.items(): + # TODO: simplify to v.nbytes as soon as mx.array exposes it + estimated_size = v.size * v.dtype.size if isinstance(v, mx.array) else v.nbytes + if shard_size + estimated_size > max_file_size_bytes: + shards.append(shard) + shard, shard_size = {}, 0 + shard[k] = v + shard_size += estimated_size + shards.append(shard) + return shards + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") parser.add_argument( @@ -200,6 +215,11 @@ def quantize(weights, config, args): str(torch_path / "tokenizer.model"), str(mlx_path / "tokenizer.model"), ) - np.savez(str(mlx_path / "weights.npz"), **weights) + shards = make_shards(weights) + if len(shards) == 1: + np.savez(str(mlx_path / f"weights.npz"), **shards[0]) + else: + for i, shard in enumerate(shards): + np.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard) with open(mlx_path / "config.json", "w") as fid: json.dump(params, fid, indent=4) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index d684ed6d3..97ec41016 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -3,6 +3,7 @@ import argparse import json import time +import glob from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple @@ -330,7 +331,23 @@ def sanitize_config(config, weights): def load_model(model_path): model_path = Path(model_path) - weights = mx.load(str(model_path / "weights.npz")) + + unsharded_weights_path = Path(model_path / "weights.npz") + if unsharded_weights_path.is_file(): + print("[INFO] Loading model from {}.".format(unsharded_weights_path)) + weights = mx.load(str(unsharded_weights_path)) + else: + sharded_weights_glob = str(model_path / "weights.*.npz") + weight_files = glob.glob(sharded_weights_glob) + print("[INFO] Loading model from {}.".format(sharded_weights_glob)) + + if len(weight_files) == 0: + raise FileNotFoundError("No weights found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + with open(model_path / "config.json", "r") as f: config = sanitize_config(json.loads(f.read()), weights) quantization = config.pop("quantization", None) @@ -373,7 +390,6 @@ def load_model(model_path): mx.random.seed(args.seed) - print("[INFO] Loading model from disk.") model, tokenizer = load_model(args.model_path) if args.few_shot: few_shot_generate(args)