
Cover with more tests
MartinBraquet committed Sep 26, 2024
1 parent 3f31e4a commit d7ec593
Showing 8 changed files with 140 additions and 89 deletions.
1 change: 0 additions & 1 deletion llm/configurator.py
@@ -5,7 +5,6 @@
from pathlib import Path

from llm import BASE_DIR
from llm.utils import DataclassUtils


# @dataclass
57 changes: 29 additions & 28 deletions llm/model.py
@@ -197,7 +197,8 @@ def get_num_params(self, non_embedding=True):
n_params -= self.transformer.wpe.weight.numel()
return n_params

def _init_weights(self, module):
@staticmethod
def _init_weights(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
@@ -325,24 +326,24 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):

return optimizer

def estimate_mfu(self, fwdbwd_per_iter, dt):
""" estimate model flops utilization (MFU) """
# first estimate the number of flops we do per iteration.
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
gpu_info = get_gpu_info()
if gpu_info is None:
return
N = self.get_num_params()
cfg = self.config
L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
flo_per_token = 6 * N + 12 * L * H * Q * T
flo_per_fwdbwd = flo_per_token * T
flo_per_iter = flo_per_fwdbwd * fwdbwd_per_iter
# express our flops throughput as ratio of A100 bfloat16 peak flops
flops_achieved = flo_per_iter / dt # per second
flops_promised = gpu_info['flops']
mfu = flops_achieved / flops_promised
return mfu
# def estimate_mfu(self, fwdbwd_per_iter, dt):
# """ estimate model flops utilization (MFU) """
# # first estimate the number of flops we do per iteration.
# # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
# gpu_info = get_gpu_info()
# if gpu_info is None:
# return
# N = self.get_num_params()
# cfg = self.config
# L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
# flo_per_token = 6 * N + 12 * L * H * Q * T
# flo_per_fwdbwd = flo_per_token * T
# flo_per_iter = flo_per_fwdbwd * fwdbwd_per_iter
# # express our flops throughput as ratio of A100 bfloat16 peak flops
# flops_achieved = flo_per_iter / dt # per second
# flops_promised = gpu_info['flops']
# mfu = flops_achieved / flops_promised
# return mfu

@torch.no_grad()
def generate(self, idx, max_tokens, temperature=1.0, top_k=None):
Expand Down Expand Up @@ -372,12 +373,12 @@ def generate(self, idx, max_tokens, temperature=1.0, top_k=None):
return idx


@lru_cache
def get_gpu_info():
if ... == 'A100 bfloat16':
flops = 312e12
else:
return
return dict(
flops=flops
)
# @lru_cache
# def get_gpu_info():
# if ... == 'A100 bfloat16':
# flops = 312e12
# else:
# return
# return dict(
# flops=flops
# )
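
For context, here is a minimal sketch of the MFU estimate that the now-commented-out `estimate_mfu`/`get_gpu_info` pair implemented, following PaLM Appendix B (roughly 6N + 12·L·H·Q·T flops per token, with 312e12 as A100 bfloat16 peak flops). All names below are illustrative, not part of the repository's API; only the arithmetic mirrors the removed code.

```python
# Illustrative sketch of the MFU calculation commented out above (PaLM Appendix B).
# Hypothetical function name and signature; the arithmetic follows the original code.
def estimate_mfu_sketch(n_params, n_layer, n_head, n_embd, block_size,
                        fwdbwd_per_iter, dt, peak_flops=312e12):
    """Return achieved flops as a fraction of peak (A100 bfloat16 peak assumed)."""
    head_dim = n_embd // n_head
    seq_len = block_size
    flops_per_token = 6 * n_params + 12 * n_layer * n_head * head_dim * seq_len
    flops_per_fwdbwd = flops_per_token * seq_len          # one forward + backward pass
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
    flops_achieved = flops_per_iter / dt                  # dt is seconds per iteration
    return flops_achieved / peak_flops
```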
1 change: 1 addition & 0 deletions llm/tests/prompt.txt
@@ -0,0 +1 @@
Hello darling
7 changes: 6 additions & 1 deletion llm/tests/test_sample.py
@@ -33,7 +33,7 @@ def test_sample(self):
text = sampler.generate_text(prompt='Love is the answer to', max_tokens=20)
self.assertEqual('Love is the answer toN4fpPLbNK\\9A3Necys\n"', text)

def test_sample_from_file(self):
def test_config_file(self):
config_file = DIR / 'config.json'
text = Sampler(config_file=config_file, model_path=model_path).generate_text(max_tokens=20)
self.assertEqual('\nN4fpPLbNK\\9A3Necys\n"', text)
@@ -42,3 +42,8 @@ def test_sample_from_online(self):
sampler = Sampler(init_from='online', model_path='gpt2')
text = sampler.generate_text(prompt='The sun', max_tokens=10)
self.assertEqual('The sun has been shining for weeks, but it has set', text)

def test_file_prompt(self):
sampler = Sampler(model_path=model_path)
text = sampler.generate_text(prompt='FILE:prompt.txt', max_tokens=10)
self.assertEqual(23, len(text))
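
The new `test_file_prompt` case passes a prompt of the form `FILE:prompt.txt`, and the companion `llm/tests/prompt.txt` contains `Hello darling`. Presumably the `Sampler` resolves such prompts by reading the referenced file; a minimal sketch of that convention, with hypothetical names (the repository's actual implementation may differ):

```python
# Hypothetical sketch of the FILE: prompt convention exercised by test_file_prompt.
from pathlib import Path

def resolve_prompt(prompt: str, base_dir: Path = Path('.')) -> str:
    """If the prompt starts with 'FILE:', read the prompt text from that file."""
    if prompt.startswith('FILE:'):
        return (base_dir / prompt[len('FILE:'):]).read_text()
    return prompt
```

Under that reading, the asserted length of 23 is consistent with the 13-character prompt plus 10 generated characters from this character-level test model.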
76 changes: 52 additions & 24 deletions llm/tests/test_train.py
@@ -1,3 +1,4 @@
import shutil
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch
@@ -9,70 +10,85 @@

class FakeWandb:
@classmethod
def init(cls, **kwargs):
def init(cls, *args, **kwargs):
pass

@classmethod
def log(cls, **kwargs):
def log(cls, *args, **kwargs):
pass


BASE_CONFIG = dict(
log_interval=9,
eval_interval=2,
eval_iters=2,
batch_size=1,
block_size=4,
n_layer=2,
n_head=2,
n_embd=4,
dropout=0.2,
learning_rate=10e-3,
lr_decay_iters=20,
min_lr=5e-3,
beta2=0.99,
patience=2
)


class TestTrain(TestCase):
@classmethod
def setUpClass(cls):
shutil.rmtree(DIR / 'data', ignore_errors=True)

def test_train(self):
config = dict(
init_from="scratch",
model_path=DIR / 'results' / ".out_test_train",
training_data_path="prince",
torch_compile=False,
log_interval=9,
eval_interval=2,
eval_iters=2,
batch_size=1,
block_size=4,
n_layer=2,
n_head=2,
n_embd=4,
dropout=0.2,
learning_rate=10e-3,
max_iters=20,
lr_decay_iters=20,
min_lr=5e-3,
beta2=0.99,
patience=2
**BASE_CONFIG,
)
trainer = Trainer(**config)
trainer.run()

self.assertFalse(trainer.resume())

def test_train_resume(self):
results_dir = DIR / 'results'
shutil.copytree(
results_dir / 'test1',
results_dir / '.test_train_resume',
dirs_exist_ok=True,
)
trainer = Trainer(
init_from="resume",
model_path=DIR / 'results' / "test1",
model_path=results_dir / ".test_train_resume",
training_data_path="prince",
max_iters=20,
max_iters=30,
)
trainer.run()

@patch('wandb.init', new=FakeWandb.init)
@patch('wandb.log', new=FakeWandb.log)
def test_train_wandb(self):
trainer = Trainer(
init_from="resume",
model_path=DIR / 'results' / "test1",
init_from="scratch",
model_path=DIR / 'results' / ".train_wandb",
training_data_path="prince",
max_iters=20,
**BASE_CONFIG,
wandb_log=True,
)
trainer.run()

def test_train_profile(self):
trainer = Trainer(
init_from="resume",
model_path=DIR / 'results' / "test1",
init_from="scratch",
model_path=DIR / 'results' / ".test_train_profile",
training_data_path="prince",
max_iters=20,
max_iters=1,
**BASE_CONFIG,
profile=True,
)
trainer.run()
@@ -83,9 +99,21 @@ def test_train_finetune(self):
model_path=DIR / 'results' / ".gpt2",
training_data_path="prince",
max_iters=2,
block_size=8,
)
trainer.load_model()
model = trainer.model
size = model.get_num_params()
self.assertEqual(123653376, size)

def test_vocab_size(self):
trainer = Trainer(
init_from="scratch",
model_path=DIR / 'results' / ".test_vocab_size",
training_data_path="prince",
encoding='char',
max_iters=2,
**BASE_CONFIG,
)
trainer.run()

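The 123653376 figure asserted in `test_train_finetune` is consistent with GPT-2 small once position embeddings are excluded, since `get_num_params(non_embedding=True)` subtracts the `wpe` table. A back-of-the-envelope check, assuming the usual GPT-2 small dimensions (these figures are not taken from the diff itself):

```python
# Rough consistency check of the 123653376 assertion, assuming GPT-2 small
# dimensions (block_size=1024, n_embd=768) and a weight-tied lm_head.
total_params = 124_439_808           # commonly reported GPT-2 small total
position_embeddings = 1024 * 768     # wpe table, excluded by get_num_params(non_embedding=True)
assert total_params - position_embeddings == 123_653_376
```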
12 changes: 2 additions & 10 deletions llm/tests/test_utils.py
@@ -2,7 +2,7 @@
from tempfile import TemporaryDirectory
from unittest import TestCase

from llm.utils import get_last_checkpoint, box, to_path
from llm.utils import get_last_checkpoint, box, to_path, unbox

DIR = Path(__file__).parent

@@ -15,14 +15,6 @@ def test_last_checkpoint(self):
get_last_checkpoint(tmpdir)
self.assertEqual(f'no checkpoints found in {tmpdir}', str(e.exception))

Path(tmpdir, 'ckpt_0001.pt').touch()
Path(tmpdir, 'ckpt_init.pt').touch()
Path(tmpdir, 'ckpt_0002.pt').touch()
self.assertEqual('ckpt_0002.pt', get_last_checkpoint(tmpdir))

def test_box(self):
self.assertEqual([1], box(1))
self.assertEqual([1], box([1]))

def test_path(self):
self.assertEqual(Path('a'), to_path('a'))
self.assertEqual(Path('a'), to_path(Path('a')))
30 changes: 17 additions & 13 deletions llm/train.py
@@ -506,27 +506,16 @@ def run(self):
def load_model(self):
checkpoint = None
if self.init_from == 'scratch':
# init a new model from scratch
print("Initializing a new model from scratch")

# attempt to derive vocab_size from the dataset
meta_vocab_size = None
meta_path = self.data_dir / 'meta.pkl'
if os.path.exists(meta_path):
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
stoi, itos = meta.get('stoi'), meta.get('itos')
if stoi is not None and itos is not None:
assert len(stoi) == len(itos), f"{len(stoi)} != {len(itos)}"
meta_vocab_size = len(stoi)
print(f"found vocab_size = {meta_vocab_size}")

# determine the vocab size we'll use for from-scratch training
meta_vocab_size = self.get_vocab_size()
if meta_vocab_size is not None:
self.model_args['vocab_size'] = meta_vocab_size
else:
print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
self.model_args['vocab_size'] = 50304

gpt_conf = GPTConfig(**self.model_args)
self.model = GPT(gpt_conf)
elif self.init_from == 'resume':
@@ -577,6 +566,21 @@ def load_model(self):
self.compile_model()
return optimizer, scaler

def get_vocab_size(self):
"""
Attempt to derive vocab_size from the dataset
"""
meta_path = self.data_dir / 'meta.pkl'
if os.path.exists(meta_path):
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
stoi, itos = meta.get('stoi'), meta.get('itos')
if stoi is not None and itos is not None:
assert len(stoi) == len(itos), f"{len(stoi)} != {len(itos)}"
meta_vocab_size = len(stoi)
print(f"found vocab_size = {meta_vocab_size}")
return meta_vocab_size

def _profile_ctx(self):
"""
Useful docs on pytorch profiler:
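
The new `get_vocab_size` helper reads a `meta.pkl` next to the training data and expects matching `stoi`/`itos` mappings. A sketch of how such a file might be produced for a character-level dataset (hypothetical helper; the repository's data preparation code may differ):

```python
# Hypothetical example of writing the meta.pkl that get_vocab_size reads.
import pickle
from pathlib import Path

def write_meta(data_dir: Path, text: str) -> int:
    """Build char-level stoi/itos mappings and pickle them as meta.pkl."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    with open(data_dir / 'meta.pkl', 'wb') as f:
        pickle.dump({'stoi': stoi, 'itos': itos}, f)
    return len(stoi)  # the vocab_size that get_vocab_size will report
```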
45 changes: 33 additions & 12 deletions llm/utils.py
@@ -36,34 +36,49 @@ def parse_model_path(model_path):
def unbox(e):
"""
Returns the only element of e if it has only one element, otherwise returns e
>>> unbox(1)
1
>>> unbox([1])
1
>>> unbox([1, 2, 3])
[1, 2, 3]
>>> unbox({'a': 1})
1
>>> unbox({'a': 1, 'b': 2})
{'a': 1, 'b': 2}
"""
if isinstance(e, str):
return e
default = e
if isinstance(e, dict):
e = e.values()
return next(iter(e)) if hasattr(e, '__len__') and len(e) == 1 else e
return next(iter(e)) if hasattr(e, '__len__') and len(e) == 1 else default


def box(e):
"""
Box a single element into a list
:param e:
:return:
>>> box(1)
[1]
>>> box([1])
[1]
"""
if isinstance(e, (list, tuple, set)):
return e
return [e]


@dataclass
class DataclassUtils:

@classmethod
def keys(cls):
return list(cls.__dataclass_fields__)

def dict(self):
return {k: getattr(self, k) for k in self.keys()}
# @dataclass
# class DataclassUtils:
#
# @classmethod
# def keys(cls):
# return list(cls.__dataclass_fields__)
#
# def dict(self):
# return {k: getattr(self, k) for k in self.keys()}


@lru_cache
@@ -82,6 +97,12 @@ def list_to_hash(items):


def to_path(s):
"""
>>> to_path('a')
PosixPath('a')
>>> to_path(Path('a'))
PosixPath('a')
"""
if isinstance(s, str):
return Path(s)
return s
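
The doctests added to `unbox`, `box`, and `to_path` can be exercised with the standard library's doctest module; one possible invocation is shown below (the project's own test runner may collect them differently):

```python
# One way to run the new doctests in llm/utils.py.
import doctest
import llm.utils

results = doctest.testmod(llm.utils, verbose=False)
print(results)  # e.g. TestResults(failed=0, attempted=9)
```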
