Skip to content

Commit

Permalink
removed ununsed args
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Jan 16, 2025
1 parent 8bec8a8 commit 694e3a1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
5 changes: 0 additions & 5 deletions configs/10B/H100_devel.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,8 @@ sched_type = "wsd-sqrt"
batch_size = 128 #1M tokens bs
warmup_steps = 0
total_steps = 1
lr = 7.5e-5
num_chunks = 8

adam_betas1 = 0.9
adam_betas2 = 0.95
weight_decay = 0.1

z_loss = true

[data]
Expand Down
15 changes: 7 additions & 8 deletions src/zeroband/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ class DataConfig(BaseConfig):
reverse_data_files: bool = False
split_by_data_rank: bool = True


class AdamConfig(BaseConfig):
type: Literal["adam"] = "adam" # the literal is used to distinguish between the different optimizers configuration in the union type
type: Literal["adam"] = (
"adam" # the literal is used to distinguish between the different optimizers configuration in the union type
)
lr: float = 4e-4
weight_decay: float = 0.1
betas1: float = 0.9
Expand All @@ -47,11 +50,6 @@ class SoapConfig(BaseConfig):
class OptimConfig(BaseConfig):
optim: OptimizersConfig = AdamConfig()

lr: float = 4e-4
weight_decay: float = 0.1
adam_betas1: float = 0.9
adam_betas2: float = 0.95

sched_type: Literal["cosine", "linear", "wsd-sqrt"] = "cosine"
warmup_steps: int = 1000
stable_steps: int = 80_000
Expand All @@ -70,6 +68,7 @@ class DilocoConfig(BaseConfig):

retry_all_reduce: int = 3


class MemoryProfilerConfig(BaseConfig):
freq: int = 10
snapshot_dir: str
Expand Down Expand Up @@ -231,7 +230,8 @@ def get_env_config(config: Config | None, item: str | None, default: Any | None

return cfg

def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool:

def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool:
"""
Call get_env_config and convert strings to bools where makes sense.
Expand All @@ -248,4 +248,3 @@ def get_env_config_bool(config: Config | None, item: str | None, default: bool
if isinstance(val, str):
return val.lower() == "true" or val.lower() == "1"
return bool(val)

0 comments on commit 694e3a1

Please sign in to comment.