-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Your Name
committed
Jan 10, 2025
1 parent
001f6c7
commit f79e9ea
Showing
4 changed files
with
125 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name_model = "10B" | ||
project = "debug_I2_zero_band" | ||
|
||
[train] | ||
micro_bs = 1 | ||
ac_ckpt = true | ||
|
||
[optim] | ||
sched_type = "wsd-sqrt" | ||
batch_size = 128  # ~1M tokens per batch (128 x seq_length 8192) | ||
warmup_steps = 0 | ||
total_steps = 1 | ||
lr = 7.5e-5 | ||
|
||
adam_betas1 = 0.9 | ||
adam_betas2 = 0.95 | ||
weight_decay = 0.1 | ||
|
||
z_loss = true | ||
|
||
[data] | ||
seq_length = 8192 | ||
num_workers = 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
|
||
|
||
from typing import Literal | ||
|
||
from pydantic import model_validator | ||
from pydantic_config import BaseConfig | ||
|
||
from zeroband.checkpoint import CkptConfig | ||
from zeroband.data import DataConfig | ||
from zeroband.diloco import DilocoConfig | ||
from zeroband.models.llama.model import AttnFnType | ||
|
||
|
||
class OptimConfig(BaseConfig):
    """Optimizer, learning-rate schedule and z-loss settings.

    Every field has a default, so the whole section may be omitted from a
    run configuration.
    """

    # Optimizer hyper-parameters.
    lr: float = 4e-4
    weight_decay: float = 0.1
    adam_betas1: float = 0.9
    adam_betas2: float = 0.95

    # Learning-rate schedule; step counts are expressed in optimizer steps
    # (presumably -- confirm against the scheduler that consumes them).
    sched_type: Literal["cosine", "linear", "wsd-sqrt"] = "cosine"
    warmup_steps: int = 1000
    stable_steps: int = 80_000
    total_steps: int = 88_000
    batch_size: int = 512

    # z-loss is off by default; the weight only matters when it is enabled
    # (assumption -- verify in the training loop).
    z_loss: bool = False
    z_loss_weight: float = 2e-4
|
||
|
||
class MemoryProfilerConfig(BaseConfig):
    """Settings for the optional memory profiler.

    ``snapshot_dir`` has no default and must be supplied whenever this
    section is present.
    """

    # Presumably the number of steps between snapshots -- TODO confirm.
    freq: int = 10
    # Where snapshots are written (required).
    snapshot_dir: str
|
||
|
||
class TrainConfig(BaseConfig):
    """Training-loop settings.

    ``micro_bs`` is the only required field; everything else has a default.
    """

    micro_bs: int
    torch_compile: bool = True
    # Accepts either a flag or an integer; how the integer is interpreted is
    # decided by the consumer of this config (not visible here).
    ac_ckpt: bool | int = False
    # Replaces the old "shard grad op" option; True means full shard
    # (per the original author's note -- confirm against the FSDP setup).
    reshard_after_forward: bool = True

    # Should be True on SXM hardware. Kept False by default for backward
    # compatibility.
    reduce_fp32: bool = False

    log_model_hash: bool = False

    # None disables memory profiling (assumption -- verify in the trainer).
    memory_profiler: MemoryProfilerConfig | None = None

    sequence_packing: bool = True

    attn_fn: AttnFnType = "flex"
|
||
|
||
class MonitorConfig(BaseConfig):
    """Settings for the optional monitoring client.

    All fields have defaults; ``base_url`` and ``auth_token`` are unset
    unless provided.
    """

    log_flush_interval: int = 10
    base_url: str | None = None
    auth_token: str | None = None
|
||
|
||
class Config(BaseConfig):
    """Top-level run configuration that aggregates every sub-config.

    ``train`` is the only required section. ``diloco`` and ``monitor`` are
    optional and default to ``None``; ``data``, ``optim`` and ``ckpt`` fall
    back to their own defaults.

    Two cross-field rules are enforced after the model is built:

    * when both a checkpoint interval and a DiLoCo config are set, the
      interval must be a multiple of ``diloco.inner_steps``;
    * ``ckpt.live_recovery_rank_src`` may only be set when ``diloco`` is set.
    """

    # main config
    name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
    type_model: Literal["llama2", "llama3"] = "llama3"

    project: str = "zeroband"
    run_id: str | None = None
    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
    wandb_resume: bool = False

    # sub config
    diloco: DilocoConfig | None = None
    data: DataConfig = DataConfig()
    optim: OptimConfig = OptimConfig()
    train: TrainConfig
    monitor: MonitorConfig | None = None

    ckpt: CkptConfig = CkptConfig()

    @model_validator(mode="after")
    def ckpt_diloco_step(self):
        """Reject checkpoint intervals that do not align with DiLoCo outer steps.

        Raises:
            ValueError: if ``ckpt.interval`` is not a multiple of
                ``diloco.inner_steps``. Raised explicitly rather than via
                ``assert`` so the check still runs under ``python -O``.
        """
        # Nothing to cross-check unless both an interval and diloco are set.
        if self.ckpt is None or self.ckpt.interval is None or self.diloco is None:
            return self
        if self.ckpt.interval % self.diloco.inner_steps != 0:
            raise ValueError(
                "ckpt interval must be a multiple of diloco inner steps "
                "as we only save at the end of an outer step"
            )
        return self

    @model_validator(mode="after")
    def validate_live_recovery_rank_src(self):
        """Reject ``live_recovery_rank_src`` when DiLoCo is not configured.

        Raises:
            ValueError: if ``ckpt.live_recovery_rank_src`` is set while
                ``diloco`` is ``None``.
        """
        if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None:
            raise ValueError("live_recovery_rank_src is only supported with diloco")
        return self
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters