erogol/BlaGPT

BlaGPT

Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.

BlaGPT Model

BlaGPT is a flexible Transformer implementation in which the following features can be turned on or off in the config:

Multi-token prediction - link

Weight tying - link

Grouped query attention - link

Capping logits - link

QKV bias - link

Zero-init projection layer - link

Post and pre-RMSNorm - link

Setting the rotary base theta to 1_000_000 (as in Llama 3) - increased the final validation loss - best loss: 3.3324

Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527

KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB

Dilated Attention (LongNet) - link
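
Two of the logit-related options above, soft logit capping and z-loss, are easy to state in a few lines. The following is a minimal pure-Python sketch, not the repository's code: the function names, the cap value of 30.0, and the z-loss coefficient of 1e-4 are illustrative assumptions.

```python
import math

def soft_cap(logits, cap=30.0):
    """Smoothly squash each logit into (-cap, cap) using tanh.

    The cap value is an illustrative assumption, not taken from BlaGPT.
    """
    return [cap * math.tanh(x / cap) for x in logits]

def logsumexp(logits):
    """Numerically stable log(sum(exp(x))) over a list of logits."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))

def z_loss(logits, coeff=1e-4):
    """Auxiliary penalty on the squared log-partition function log Z.

    The coefficient is an illustrative assumption; the penalty is added
    to the usual cross-entropy loss.
    """
    return coeff * logsumexp(logits) ** 2

logits = [100.0, 2.0, -50.0]
capped = soft_cap(logits)
print(capped)          # every value now lies strictly inside (-30, 30)
print(z_loss(logits))  # grows as the logits drift to large magnitudes
```

Small logits pass through soft capping almost unchanged, while large ones saturate near the cap; z-loss instead leaves the logits alone and discourages a large normalizer through the loss.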

Other Models

MegaByte - link - loss: 3.810

FTP (heavily modified) - link - loss: 3.901

Rene - link - loss: 3.340

Rwkv7 - link - loss: 4.450

Zamba2 - link - Zamba2 > Rene > Rwkv7

Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710

Hymba - link - training step time is significantly slower than for the Transformer models. Best validation loss so far: 4.7505

Tokenformer (in BlaGPT model) - link - loss: 3.390

Optimizers

PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results

Ademamix - link - unstable even after trying different learning rates

Adopt - link - loss goes straight to NaN

CAdamW - link - loss: 3.3517

AdamW with independent weight decay - link - loss: 3.320

Adam - loss: 3.3224

AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms

DeMo - link - Saves 7 GB per GPU, loss is higher than baseline, step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms

Adam-Mini - link - loss is higher than Adam and AdamW and, unexpectedly, it is also slower; saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms
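
To make the trade-offs easier to compare, the runs that report loss, peak VRAM, and step time can be expressed relative to the AdamW baseline. This is a small sketch using only the figures listed above; the dictionary layout and key names are assumptions made for illustration.

```python
# Figures copied from the optimizer list; AdamW is used as the baseline.
baseline = {"loss": 3.3310, "peak_vram_mib": 42053, "step_time_ms": 533}
runs = {
    "DeMo": {"loss": 3.4676, "peak_vram_mib": 41534, "step_time_ms": 820},
    "Adam-Mini": {"loss": 3.3324, "peak_vram_mib": 41534, "step_time_ms": 610},
}

summary = {}
for name, run in runs.items():
    summary[name] = {
        # Positive means worse (higher) validation loss than AdamW.
        "loss_delta": run["loss"] - baseline["loss"],
        # Positive means less peak memory than AdamW.
        "vram_saved_mib": baseline["peak_vram_mib"] - run["peak_vram_mib"],
        # Ratio above 1.0 means slower steps than AdamW.
        "slowdown": run["step_time_ms"] / baseline["step_time_ms"],
    }

for name, s in summary.items():
    print(
        f"{name}: loss {s['loss_delta']:+.4f}, "
        f"VRAM saved {s['vram_saved_mib']} MiB, "
        f"step time x{s['slowdown']:.2f}"
    )
```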

Best Model So Far

BlaGPT with the following configuration:

{
    "params": {
      "norm_layer": "rmsnorm",
      "attention": "GQA",
      "activation": "swiglu",
      "tie_embed_weights": true,
      "zero_init_proj_layers": true,
      "use_rotary_emb": true,
      "rmsnorm_before_qk": true
    },
    "config": {
      "block_size": 1024,
      "vocab_size": 50304,
      "n_layer": 12,
      "n_head": 12,
      "n_embd": 768,
      "dropout": 0.0,
      "bias": true,
      "norm_layer": "rmsnorm",
      "attention": "GQA",
      "activation": "swiglu",
      "use_soft_logit_capping": false,
      "n_kv_head": 4,
      "tie_embed_weights": true,
      "zero_init_proj_layers": true,
      "rmsnorm_before_qk": true,
      "use_rotary_emb": true
    },
    "val_loss": 3.2993,
    "memory_usage": 49403
}
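
The attention settings in this config imply a specific head layout. The sketch below works through the arithmetic for `n_embd`, `n_head`, and `n_kv_head`; the contiguous mapping of query heads to key/value heads is the common grouped query attention convention and is assumed here, not confirmed from the BlaGPT source.

```python
# Values taken from the "config" section above.
config = {"n_embd": 768, "n_head": 12, "n_kv_head": 4}

# Each attention head works on an equal slice of the embedding.
head_dim = config["n_embd"] // config["n_head"]

# Number of query heads that share one key/value head.
group_size = config["n_head"] // config["n_kv_head"]

# Assumed convention: consecutive query heads share a key/value head.
kv_head_for_query = [q // group_size for q in range(config["n_head"])]

# Output widths of the query projection and the combined K and V projections.
q_proj_out = config["n_head"] * head_dim
kv_proj_out = 2 * config["n_kv_head"] * head_dim
mha_kv_proj_out = 2 * config["n_head"] * head_dim  # full multi-head attention

print(head_dim)           # 64
print(group_size)         # 3
print(kv_head_for_query)  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(q_proj_out, kv_proj_out, mha_kv_proj_out)  # 768 512 1536
```

With 4 key/value heads instead of 12, the K and V projections and the KV cache are one third the size they would be under standard multi-head attention.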

Adding a New Model

  • Implement the model
  • Return the loss in the forward function
  • Register the model
  • Start training

See one of the implementations for details.
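
The steps above can be sketched roughly as follows. This is a dependency-free illustration, not the repository's actual mechanism: `MODEL_REGISTRY`, `register_model`, and `ToyUnigram` are hypothetical names, and a real model would be a PyTorch module rather than a plain class.

```python
import math

# Hypothetical registry mapping a model name to its class.
MODEL_REGISTRY = {}

def register_model(name):
    """Hypothetical decorator that records a model class under `name`."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

@register_model("toy_unigram")
class ToyUnigram:
    """Toy model that predicts every token from one fixed logit table."""

    def __init__(self, vocab_size):
        self.logits = [0.0] * vocab_size

    def forward(self, idx, targets):
        # `idx` is ignored because this toy model uses no context.
        m = max(self.logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in self.logits))
        # Mean negative log-likelihood of the targets (cross-entropy).
        nll = [log_z - self.logits[t] for t in targets]
        return sum(nll) / len(nll)

# Look the model up by name and compute the loss it returns from forward.
model = MODEL_REGISTRY["toy_unigram"](vocab_size=8)
loss = model.forward(idx=[1, 2, 3], targets=[2, 3, 4])
print(round(loss, 4))  # 2.0794, i.e. log(8) for uniform predictions
```

The point of the pattern is that the training loop only needs a name to build the model and a `forward` call that hands back the loss.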

Training

  • Get the data by running data/fineweb10B_cached.py

  • Start training with:

torchrun --standalone --nproc_per_node=8 train.py --run_name pre_post_norm --model_name blagpt
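
For orientation, here is a minimal sketch of how a training script could consume this command. The `--run_name` and `--model_name` flags come from the command above, and `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` are environment variables that `torchrun` sets for each worker; the function names, defaults, and `required` setting are assumptions, not the contents of `train.py`.

```python
import argparse
import os

def parse_args(argv=None):
    """Parse the two flags used in the command above (defaults are assumed)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="blagpt")
    return parser.parse_args(argv)

def distributed_env():
    """Read the per-process variables torchrun exports.

    Falls back to single-process values when run without torchrun.
    """
    return {
        "rank": int(os.environ.get("RANK", 0)),
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),
    }

args = parse_args(["--run_name", "pre_post_norm", "--model_name", "blagpt"])
env = distributed_env()
print(args.run_name, args.model_name, env)
```

With `--nproc_per_node=8`, each of the 8 processes would see the same arguments but a different rank.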

Acknowledgements

The initial code is based on:

Nano GPT - link

Modded NanoGPT - link
