Add DeMo and Adam-mini
erogol committed Dec 3, 2024
1 parent e2c2d98 commit ba8e700
Showing 2 changed files with 41 additions and 14 deletions.
32 changes: 19 additions & 13 deletions README.md
@@ -1,6 +1,6 @@
# BlaGPT

A collection of LM architectures, layers, and tricks to easily benchmark on a relatively small dataset. It is created fully for my experiments in my free time with no serious intentions.
Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.

## BlaGPT Model
BlaGPT is a flexible Transformer implementation that lets you toggle the following features from its config (see the illustrative config sketch after the list).
@@ -21,26 +21,28 @@ Post and pre-RMSNorm - [link](https://arxiv.org/pdf/2408.00118)

Setting base theta to 1_000_000 - [llama3](https://github.com/meta-llama/llama3/blob/main/llama/model.py#L49) - increased the final validation loss - best: `3.3324`

Z-loss regularization - [link](https://arxiv.org/pdf/2309.14322) - increased the final validation loss by 0.02 - best: `3.3527`
Z-loss regularization - [link](https://arxiv.org/pdf/2309.14322) - increased the final validation loss by 0.02 - loss: `3.3527`
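
For reference, z-loss penalizes the squared log of the softmax normalizer so the output logits stay well-scaled. A minimal PyTorch sketch of the idea; the function name and coefficient are illustrative, not this repo's implementation:

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(logits, targets, z_coef=1e-4):
    # Standard next-token cross-entropy over the vocabulary.
    ce = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    # z-loss: penalize log Z = logsumexp(logits) so the normalizer stays near 1.
    log_z = torch.logsumexp(logits, dim=-1)
    return ce + z_coef * (log_z**2).mean()
```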

KV-Shifting attention - [link](https://arxiv.org/abs/2411.19574) - best: `3.331` - peak memory consumption: `42858 MiB`
KV-Shifting attention - [link](https://arxiv.org/abs/2411.19574) - seems to improve performance - loss: `3.3310` -> `3.3138` - peak memory consumption: `42858 MiB`
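
How KV-shifting works, as I read the paper: each token's keys and values are mixed with the previous token's keys and values through learnable scalars before attention. A rough sketch; shapes, initialization, and how it hooks into the attention block are my assumptions, not the repo's exact code:

```python
import torch
import torch.nn as nn

class KVShift(nn.Module):
    """Mix each key/value with the previous token's key/value via learned scalars."""

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor([1.0, 0.0]))  # mixing weights for K
        self.beta = nn.Parameter(torch.tensor([1.0, 0.0]))   # mixing weights for V

    @staticmethod
    def _shift(x):
        # Shift the sequence right by one token; the first position gets zeros.
        return torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)

    def forward(self, k, v):
        # k, v: (batch, seq_len, head_dim)
        k = self.alpha[0] * k + self.alpha[1] * self._shift(k)
        v = self.beta[0] * v + self.beta[1] * self._shift(v)
        return k, v
```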

Dilated Attention (LongNet) - [link](https://arxiv.org/pdf/2307.02486)
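
Each of the items above maps to a switch or value in the model config. A hypothetical sketch of what toggling them could look like; the field names are illustrative, not the actual BlaGPT config:

```python
from dataclasses import dataclass

@dataclass
class BlaGPTConfig:
    # Illustrative toggles only; the real config lives in the repo.
    norm_placement: str = "pre"          # pre-, post-, or pre+post RMSNorm
    rope_theta: float = 10_000.0         # e.g. 1_000_000 for the Llama-3 setting
    z_loss_coef: float = 0.0             # > 0 enables z-loss regularization
    use_kv_shifting: bool = False        # KV-shifting attention
    use_dilated_attention: bool = False  # LongNet-style dilated attention

config = BlaGPTConfig(use_kv_shifting=True, z_loss_coef=1e-4)
```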

## Other Models
MegaByte - [link](https://arxiv.org/abs/2305.07185) - best: `3.810`
MegaByte - [link](https://arxiv.org/abs/2305.07185) - loss: `3.810`

FTP (heavily modified) - [link](https://arxiv.org/pdf/2410.18160) - best: `3.901`
FTP (heavily modified) - [link](https://arxiv.org/pdf/2410.18160) - loss: `3.901`

Rene - [link](https://huggingface.co/cartesia-ai/Rene-v0.1-1.3b-pytorch) - best: `3.340`
Rene - [link](https://huggingface.co/cartesia-ai/Rene-v0.1-1.3b-pytorch) - loss: `3.340`

Rwkv7 - [link](https://github.com/BlinkDL/RWKV-LM) - best: `4.450`
Rwkv7 - [link](https://github.com/BlinkDL/RWKV-LM) - loss: `4.450`

Zamba2 - [link](https://huggingface.co/Zyphra/Zamba2-2.7B) - Zamba2 > Rene > Rwkv7

Hourglass Transformer (modified) - [link](https://arxiv.org/abs/2110.13711) - Hourglass > MegaByte > FTP - best: `3.710`
Hourglass Transformer (modified) - [link](https://arxiv.org/abs/2110.13711) - Hourglass > MegaByte > FTP - loss: `3.710`

Hymba - [link](https://arxiv.org/html/2411.13676v1) - train step time is significantly slower than the Transformer models. Best validation loss so far: `4.7505`

Tokenformer (in BlaGPT model) - [link](https://github.com/Haiyang-W/TokenFormer) - best: `3.390`
Tokenformer (in BlaGPT model) - [link](https://github.com/Haiyang-W/TokenFormer) - loss: `3.390`

## Optimizers
PaLMForeachSOAP - [link](https://github.com/ClashLuke/HeavyBall) - almost 2 times slower than Adam but gives the best results
@@ -49,13 +51,17 @@ Ademamix - [link](https://github.com/nanowell/AdEMAMix-Optimizer-Pytorch/blob/ma

Adopt - [link](https://github.com/iShohei220/adopt) - straight up NaN

CAdamW - [link](https://github.com/kyleliang919/C-Optim/blob/main/c_adamw.py) - best: `3.3517`
CAdamW - [link](https://github.com/kyleliang919/C-Optim/blob/main/c_adamw.py) - loss: `3.3517`

AdamW with independent weight decay - [link](https://arxiv.org/pdf/2309.14322) - loss: `3.320`

Adam - loss: `3.3224`

AdamW with independent weight decay - [link](https://arxiv.org/pdf/2309.14322) - best: `3.320`
AdamW - loss: `3.3310`, peak VRAM: `42053 MiB`, step_time: `533ms`

Adam - best: `3.3224`
DeMo - [link](https://arxiv.org/abs/2411.19870) - saves 7 GB per GPU, but the loss is higher than the baseline and the step time is slower than Adam's - loss: `3.4676`, peak VRAM: `41534 MiB`, step_time: `820ms`

AdamW - best: `3.3310`
Adam-Mini - [link]() - loss is higher than Adam and AdamW and, surprisingly, it is also slower; saves a bit of VRAM - loss: `3.3324`, peak VRAM: `41534 MiB`, step_time: `610ms`
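
A note on the independent weight decay entry above: the referenced paper applies weight decay with its own constant instead of scaling it by the learning-rate schedule as standard AdamW does. A sketch of the distinction only, not this repo's optimizer code:

```python
import torch

@torch.no_grad()
def decay_and_step(param, adam_update, lr_t, wd=1e-4, independent=True):
    # Sketch: apply weight decay, then the Adam update direction.
    if independent:
        param.mul_(1.0 - wd)         # decay uses its own fixed constant
    else:
        param.mul_(1.0 - lr_t * wd)  # AdamW: decay scaled by the current lr
    param.add_(adam_update, alpha=-lr_t)
```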

## Best Model So Far
BlaGPT with the following configurations:
23 changes: 22 additions & 1 deletion bla_gpt/optimizers/__init__.py
@@ -29,7 +29,12 @@ def get_optimizer(
            module = importlib.import_module("optimizers.radam")
            optimizer = getattr(module, "RAdam")
        elif optimizer_name.lower() == "palm_soap":
            from heavyball import PaLMForeachSOAP
            try:
                from heavyball import PaLMForeachSOAP
            except ImportError:
                raise ImportError(
                    "To use PaLMForeachSOAP, please install the heavyball package."
                )

            module = PaLMForeachSOAP
        elif optimizer_name.lower() == "ademamix":
@@ -45,6 +50,22 @@
        elif optimizer_name.lower() == "c_adamw":
            module = importlib.import_module("optimizers.c_adamw")
            optimizer = getattr(module, "AdamW")
        elif optimizer_name.lower() == "demo":
            # DeMo is loaded from the local optimizers.demo module.
            module = importlib.import_module("optimizers.demo")
            optimizer = getattr(module, "DeMo")
        elif optimizer_name.lower() == "adam_mini":
            try:
                from adam_mini import Adam_mini
            except ImportError:
                raise ImportError("To use Adam_mini, please install the adam-mini package.")

            optimizer = Adam_mini(
                named_parameters=model.named_parameters(), lr=lr, **optimizer_params
            )
            # Register the model's fused key/value projection and attention output
            # projection names so Adam-mini can group these parameters correctly.
            optimizer.wqk_names.add("kv_proj")
            optimizer.attn_proj_names.add("c_proj")

            return optimizer
        else:
            optimizer = getattr(torch.optim, optimizer_name)
        return optimizer(parameters, lr=lr, **optimizer_params)
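
For context, the new names plug into the same factory call as the existing ones. A hypothetical usage sketch; the full `get_optimizer` signature is not visible in this hunk, so the argument names below are inferred from the snippet and may not match the actual definition:

```python
import torch

from optimizers import get_optimizer

model = torch.nn.Linear(16, 16)  # stand-in model for illustration

optimizer = get_optimizer(
    optimizer_name="adam_mini",             # or "demo", "c_adamw", "palm_soap", ...
    lr=1e-3,
    optimizer_params={"weight_decay": 0.1},
    model=model,
    parameters=model.parameters(),
)
```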
