
Commit

Split out config.py
Your Name committed Jan 10, 2025
1 parent 001f6c7 commit f79e9ea
Showing 4 changed files with 125 additions and 86 deletions.
23 changes: 23 additions & 0 deletions configs/10B/H100_devel.toml
@@ -0,0 +1,23 @@
name_model = "10B"
project = "debug_I2_zero_band"

[train]
micro_bs = 1
ac_ckpt = true

[optim]
sched_type = "wsd-sqrt"
batch_size = 128 # ~1M tokens per batch (128 x 8192)
warmup_steps = 0
total_steps = 1
lr = 7.5e-5

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[data]
seq_length = 8192
num_workers = 4
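
Note: for reference only (not part of the commit), a minimal sketch of how a TOML like this could be validated against the new Config model added below in src/zeroband/config.py. In the repo itself the config is built from the command line via pydantic_config.parse_argv; the tomllib route and the exact nested-validation behavior are assumptions here.

import tomllib  # Python >= 3.11

from zeroband.config import Config

with open("configs/10B/H100_devel.toml", "rb") as f:
    raw = tomllib.load(f)

# The [train], [optim] and [data] tables are validated into TrainConfig,
# OptimConfig and DataConfig; anything unset falls back to the defaults
# declared in config.py.
config = Config(**raw)
print(config.optim.sched_type, config.optim.batch_size)  # wsd-sqrt 128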
90 changes: 90 additions & 0 deletions src/zeroband/config.py
@@ -0,0 +1,90 @@


from typing import Literal

from pydantic import model_validator
from pydantic_config import BaseConfig

from zeroband.checkpoint import CkptConfig
from zeroband.data import DataConfig
from zeroband.diloco import DilocoConfig
from zeroband.models.llama.model import AttnFnType


class OptimConfig(BaseConfig):
lr: float = 4e-4
weight_decay: float = 0.1
adam_betas1: float = 0.9
adam_betas2: float = 0.95

sched_type: Literal["cosine", "linear", "wsd-sqrt"] = "cosine"
warmup_steps: int = 1000
stable_steps: int = 80_000
total_steps: int = 88_000
batch_size: int = 512

z_loss: bool = False
z_loss_weight: float = 2e-4


class MemoryProfilerConfig(BaseConfig):
freq: int = 10
snapshot_dir: str


class TrainConfig(BaseConfig):
micro_bs: int
torch_compile: bool = True
ac_ckpt: bool | int = False
reshard_after_forward: bool = True # False keeps the old shard-grad-op behavior; True means full shard

reduce_fp32: bool = False # should be True on SXM; kept False by default for backward compatibility

log_model_hash: bool = False

memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True

attn_fn: AttnFnType = "flex"


class MonitorConfig(BaseConfig):
log_flush_interval: int = 10
base_url: str | None = None
auth_token: str | None = None


class Config(BaseConfig):
# main config
name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
type_model: Literal["llama2", "llama3"] = "llama3"

project: str = "zeroband"
run_id: str | None = None
metric_logger_type: Literal["wandb", "dummy"] = "wandb"
wandb_resume: bool = False

# sub config
diloco: DilocoConfig | None = None
data: DataConfig = DataConfig()
optim: OptimConfig = OptimConfig()
train: TrainConfig
monitor: MonitorConfig | None = None

ckpt: CkptConfig = CkptConfig()

@model_validator(mode="after")
def ckpt_diloco_step(self):
if self.ckpt is not None and self.ckpt.interval is not None and self.diloco is not None:
assert (
self.ckpt.interval % self.diloco.inner_steps == 0
), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"
return self

@model_validator(mode="after")
def validate_live_recovery_rank_src(self):
if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None:
raise ValueError("live_recovery_rank_src is only supported with diloco")
return self
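
Note: a hedged sketch (not part of the commit) of what the two validators above enforce. The field names inner_steps, interval and live_recovery_rank_src are taken from the validator bodies, not from the DilocoConfig/CkptConfig definitions, and any other required fields of those sub-configs are omitted, so treat this as illustrative only.

from pydantic import ValidationError

from zeroband.config import Config

ok = {"train": {"micro_bs": 1}, "diloco": {"inner_steps": 100}, "ckpt": {"interval": 500}}
Config(**ok)  # 500 is a multiple of 100, so ckpt_diloco_step passes

bad = {"train": {"micro_bs": 1}, "diloco": {"inner_steps": 100}, "ckpt": {"interval": 250}}
try:
    Config(**bad)
except ValidationError as err:
    # pydantic surfaces the failed assert from ckpt_diloco_step as a validation error
    print(err)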

3 changes: 3 additions & 0 deletions src/zeroband/data.py
@@ -416,6 +416,9 @@ def load_all_datasets(
split_rank = rank
split_world_size = world_size


logger.info(f"Loading Train dataset(s)")

ds = _load_datasets(
dataset_names=data_config.dataset_name_or_paths,
split=split,
95 changes: 9 additions & 86 deletions src/zeroband/train.py
@@ -1,11 +1,9 @@
import os
from typing import Literal
import time
from pydantic import model_validator
from multiprocessing.process import _children

import torch
from pydantic_config import parse_argv, BaseConfig
from pydantic_config import parse_argv
from einops import rearrange
from torch.nn import functional as F

@@ -15,10 +13,11 @@

import torch.distributed as dist
from zeroband import utils
from zeroband.diloco import Diloco, DilocoConfig
from zeroband.diloco import Diloco
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss
from zeroband.models.llama.model import AttnFnType, create_block_mask_from_seqlens
from zeroband.models.llama.model import create_block_mask_from_seqlens
from zeroband.config import Config, MemoryProfilerConfig

from zeroband.utils import (
FakeTokenizer,
@@ -28,95 +27,16 @@
get_tensor_list_signature,
)
from zeroband.utils.activation_ckpt import apply_ac_ckpt
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger
from zeroband.utils.monitor import HttpMonitor
from zeroband.models.llama import get_model
from zeroband.utils.profiler import MemoryProfiler
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
from zeroband.checkpoint import CkptConfig, CkptManager, TrainingProgress
from zeroband.checkpoint import CkptManager, TrainingProgress
from zeroband.lr_scheduler import get_scheduler


class OptimConfig(BaseConfig):
lr: float = 4e-4
weight_decay: float = 0.1
adam_betas1: float = 0.9
adam_betas2: float = 0.95

sched_type: Literal["cosine", "linear", "wsd-sqrt"] = "cosine"
warmup_steps: int = 1000
stable_steps: int = 80_000
total_steps: int = 88_000
batch_size: int = 512

z_loss: bool = False
z_loss_weight: float = 2e-4


class MemoryProfilerConfig(BaseConfig):
freq: int = 10
snapshot_dir: str


class TrainConfig(BaseConfig):
micro_bs: int
torch_compile: bool = True
ac_ckpt: bool | int = False
reshard_after_forward: bool = True # False keeps the old shard-grad-op behavior; True means full shard

reduce_fp32: bool = False # should be True on SXM; kept False by default for backward compatibility

log_model_hash: bool = False

memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True

attn_fn: AttnFnType = "flex"


class MonitorConfig(BaseConfig):
log_flush_interval: int = 10
base_url: str | None = None
auth_token: str | None = None


class Config(BaseConfig):
# main config
name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
type_model: Literal["llama2", "llama3"] = "llama3"

project: str = "zeroband"
run_id: str | None = None
metric_logger_type: Literal["wandb", "dummy"] = "wandb"
wandb_resume: bool = False

# sub config
diloco: DilocoConfig | None = None
data: DataConfig = DataConfig()
optim: OptimConfig = OptimConfig()
train: TrainConfig
monitor: MonitorConfig | None = None

ckpt: CkptConfig = CkptConfig()

@model_validator(mode="after")
def ckpt_diloco_step(self):
if self.ckpt is not None and self.ckpt.interval is not None and self.diloco is not None:
assert (
self.ckpt.interval % self.diloco.inner_steps == 0
), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"
return self

@model_validator(mode="after")
def validate_live_recovery_rank_src(self):
if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None:
raise ValueError("live_recovery_rank_src is only supported with diloco")
return self


def log_hash_training_state(
config: Config,
model: torch.nn.Module,
@@ -186,6 +106,7 @@ def train(config: Config):
data_config=config.data,
)

logger.debug("Getting model")
model, model_config = get_model(
config.name_model,
config.type_model,
@@ -194,6 +115,7 @@
attn_fn=config.train.attn_fn,
)

logger.debug(f"Distributing model to {world_info.local_rank}")
model = model.to(world_info.local_rank)
logger.debug("model loaded")

@@ -550,6 +472,7 @@ def train(config: Config):
torch.cuda.set_device(world_info.local_rank)

config = Config(**parse_argv())
config.train.memory_profiler = MemoryProfilerConfig(snapshot_dir="logs/", freq=1)
logger.debug(f"config: {config.model_dump()}")

try:
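
Note: the final hunk adds config.train.memory_profiler = MemoryProfilerConfig(snapshot_dir="logs/", freq=1) right after the config is parsed. The repo's actual MemoryProfiler lives in zeroband.utils.profiler (imported above); the snippet below is only a hypothetical illustration of how such a config could drive PyTorch's built-in CUDA allocator snapshots, not the project's implementation.

import os

import torch


class SketchMemoryProfiler:
    """Hypothetical profiler: dump a CUDA allocator snapshot every `freq` steps."""

    def __init__(self, snapshot_dir: str, freq: int) -> None:
        self.snapshot_dir = snapshot_dir
        self.freq = freq
        self.step = 0
        os.makedirs(snapshot_dir, exist_ok=True)
        # Start recording allocator events (private API, PyTorch >= 2.1).
        torch.cuda.memory._record_memory_history(max_entries=100_000)

    def log(self) -> None:
        self.step += 1
        if self.step % self.freq == 0:
            path = os.path.join(self.snapshot_dir, f"step_{self.step}.pickle")
            # Snapshots can be inspected at pytorch.org/memory_viz.
            torch.cuda.memory._dump_snapshot(path)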
