Replace get_env_config() with resolve_env_vars(). #198

Merged · 14 commits · Jan 17, 2025
scripts/all_reduce.py (+2, -1)

@@ -6,6 +6,7 @@
 from zeroband.collectives import Compression, all_reduce
 from zeroband.utils.world_info import get_world_info
 from zeroband.utils.logging import get_logger
+
 from enum import Enum


@@ -63,6 +64,6 @@ def main(config: Config):
     torch.set_float32_matmul_precision("high")
     init_process_group(backend="gloo")

-    logger = get_logger(config)
+    logger = get_logger()
     main(config)
     destroy_process_group()
scripts/convert_dl_state.py (+2, -0)

@@ -4,6 +4,7 @@
 # python scripts/convert_dl_state.py @configs/10B/H100.toml --input_path /workspace/step_49200/diloco_0/data/_3.pt --output_path ./meow.pt --rank 3 --world_size 8

 import torch
+from zeroband.config import resolve_env_vars
 from zeroband.data import get_dataloader
 from transformers import AutoTokenizer
 from zeroband.train import Config

@@ -133,6 +134,7 @@ def test_dl(config: ExportConfig):
 if __name__ == "__main__":
     logger = get_logger()
     config = ExportConfig(**parse_argv())
+    resolve_env_vars(config)
     logger.debug(f"config: {config.model_dump()}")

     main(config)
scripts/export_dcp.py (+2, -0)

@@ -7,6 +7,7 @@
 from typing import Literal
 import torch.distributed.checkpoint as dcp
 from zeroband.models.llama import get_model
+from zeroband.config import resolve_env_vars
 from zeroband.checkpoint import ModelWrapper
 from zeroband.utils import get_module_signature
 from zeroband.train import Config

@@ -221,6 +222,7 @@ def main(config: ExportConfig):
 if __name__ == "__main__":
     logger = get_logger()
     config = ExportConfig(**parse_argv())
+    resolve_env_vars(config)
     logger.debug(f"config: {config.model_dump()}")

     main(config)
scripts/skip_data.py (+2, -1)

@@ -19,7 +19,7 @@

 from transformers import AutoTokenizer
 from zeroband.checkpoint import CkptManager
-
+from zeroband.config import resolve_env_vars
 from zeroband.train import Config

 from zeroband.data import get_dataloader

@@ -83,5 +83,6 @@ def skip_data(config: Config):
     logger = get_logger()

     config = Config(**parse_argv())
+    resolve_env_vars(config)

     skip_data(config)
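All four scripts now share the same two-step entry point: parse CLI/TOML values into the pydantic config first, then let resolve_env_vars (defined in config.py below) overwrite individual fields from the environment, so ZERO_BAND_* variables take precedence over file values. A toy model of that precedence, using a made-up field rather than the repo's Config:

import os
from pydantic import BaseModel

class Cfg(BaseModel):
    run_name: str = "from-toml"

os.environ["ZERO_BAND_RUN_NAME"] = "from-env"

cfg = Cfg()  # stands in for Config(**parse_argv())
# stands in for resolve_env_vars(cfg): the env value replaces the parsed one
cfg.run_name = os.environ["ZERO_BAND_RUN_NAME"]
print(cfg.run_name)  # from-env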
src/zeroband/config.py (+78, -63)

@@ -1,7 +1,7 @@
 from typing import Any, Literal, TypeAlias
 import os

-from pydantic import model_validator
+from pydantic import create_model, model_validator
 from pydantic_config import BaseConfig

 from zeroband.collectives import Compression

@@ -143,6 +143,8 @@ def validate_remote_data_path(self):
         return self


+ENV_VAR_PREFIX = "ZERO_BAND_"
+
 class Config(BaseConfig):
     # main config
     name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"

@@ -183,68 +185,81 @@ def validate_live_recovery_rank_src(self):
         return self


-def get_env_config(config: Config | None, item: str | None, default: Any | None = None) -> Any:
-    """
-    Get a config value from the environment or the config.
-    item: item is of the form "train.memory_profiler.freq"
-    default: default value if not found
-
-    If either config or item are None, returns default. This is so you can call get_logger() as before.
-
-    Examples:
-    ```
-    # Returns ZERO_BAND_RUN_NAME if set in env.
-    # Otherwise returns config.run_name.
-    get_env_config(config, "run_name")
-    ```
-    ```
-    # Returns ZERO_BAND_TRAIN_MEMORY_PROFILER_FREQ if set in env.
-    # Then returns 10 if train or config.train.memory_profiler are None.
-    # Otherwise, returns the value of config.train.memory_profiler.freq.
-    get_env_config(config, "train.memory_profiler.freq", 10)
-    ```
-    """
-
-    if config is None or item is None:
-        return default
-
-    # Check env
-    env_name = "ZERO_BAND_" + item.upper().replace(".", "_")
-    if env_name in os.environ:
-        return os.environ[env_name]
-
-    # Check config
-    spt = item.split(".")
-    cfg: Any = config
-    for s in spt:
-        print(cfg)
-        print(s)
-        if cfg is None:
-            return default
-        try:
-            cfg = getattr(cfg, s)
-        except AttributeError:
-            # TODO: Fancier error message for debugging
-            raise ValueError(f"Config item {item} not found.")
-
-    return cfg
-
-
-def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool:
-    """
-    Call get_env_config and convert strings to bools where makes sense.
-
-    Throws an exception if the value is not a string and not convertable.
-    """
-
-    val = get_env_config(config, item, default)
-    if val is None and default is not None:
-        return default
-    if val is None:
-        return False
-    if isinstance(val, bool):
-        return val
-    if isinstance(val, str):
-        return val.lower() == "true" or val.lower() == "1"
-    return bool(val)
+def resolve_env_vars(config: Config) -> None:
+    """
+    Resolve environment variables for config fields.
+    Modifies the config in place.
+    Environment variables should be prefixed with ZERO_BAND_.
+    """
+
+    def _resolve_value(env_var: str, field_name: str, config_obj: Any) -> Any:
+        """
+        Resolve a single value from an environment variable.
+        env_var: full environment variable name (e.g. ZERO_BAND_TRAIN_MICRO_BS)
+        field_name: actual field name in the config object (e.g. micro_bs)
+        """
+        value = os.environ.get(env_var)
+        if value is not None:
+            if (field_info := config_obj.__class__.model_fields.get(field_name)) is None:
+                raise AttributeError(f"Config {config_obj} has no attribute {field_name}")
+
+            try:
+                # Create a temporary model with just this field, then validate and rip it out.
+                py_model = create_model('TempModel', __base__ = BaseConfig, **{field_name: (field_info.annotation, ...)})  # type: ignore
+                validated = py_model.model_validate({field_name: value})
+                return getattr(validated, field_name)
+            except Exception as e:
+                raise ValueError(f"Error setting {env_var}={value}: {e}")
+        return None
+
+    def _resolve_nested(prefix: str, config_obj: Any) -> None:
+        if not hasattr(config_obj, 'model_fields'):
+            return
+
+        for field_name, _ in config_obj.__class__.model_fields.items():
+            # Build the full env var name
+            full_env_var = f"{ENV_VAR_PREFIX}{prefix}_{field_name}".upper() if prefix else f"{ENV_VAR_PREFIX}{field_name}".upper()
+
+            # Try to resolve the field directly using the local field name
+            value = _resolve_value(full_env_var, field_name, config_obj)
+            if value is not None:
+                setattr(config_obj, field_name, value)
+
+            # Handle nested configs
+            field_value = getattr(config_obj, field_name)
+            if field_value is not None and hasattr(field_value, 'model_fields'):
+                # Pass the prefix for building env var names, but use local field names for lookup
+                _resolve_nested(f"{prefix}_{field_name}" if prefix else field_name, field_value)

+    def _get_valid_env_vars(prefix: str, config_obj: Any) -> set[str]:
+        """Recursively collect all valid environment variable names."""
+        valid_vars = set()
+        if not hasattr(config_obj, 'model_fields'):
+            return valid_vars
+
+        for field_name, _ in config_obj.__class__.model_fields.items():
+            full_env_var = f"{ENV_VAR_PREFIX}{prefix}_{field_name}".upper() if prefix else f"{ENV_VAR_PREFIX}{field_name}".upper()
+            valid_vars.add(full_env_var)
+
+            field_value = getattr(config_obj, field_name)
+            if field_value is not None and hasattr(field_value, 'model_fields'):
+                nested_prefix = f"{prefix}_{field_name}" if prefix else field_name
+                valid_vars.update(_get_valid_env_vars(nested_prefix, field_value))
+
+        return valid_vars
+
+    # Check for any invalid ZERO_BAND_ environment variables
+    valid_env_vars = _get_valid_env_vars("", config)
+    invalid_vars = []
+    for env_var in os.environ:
+        if env_var.startswith(ENV_VAR_PREFIX) and env_var not in valid_env_vars:
+            invalid_vars.append(env_var)
+
+    if invalid_vars:
+        raise ValueError(
+            f"Found invalid environment variables with {ENV_VAR_PREFIX} prefix: {', '.join(invalid_vars)}\n"
+            "See the full list of valid config variables in src/zeroband/config.py."
+        )
+
+    # Now resolve the valid ones.
+    _resolve_nested("", config)
src/zeroband/data.py (+8, -12)

@@ -18,10 +18,6 @@

 TEST_VOCAB_SIZE = 1024

-# TODO sami: make sure the init of the model is the same on all rank
-
-logger = get_logger(name=__name__)
-

 class FakeTokenizedDataset(IterableDataset):
     """This is a dummy dataset that generates random sequences of length seq_len and vocab_size"""

@@ -160,7 +156,7 @@ def _lazy_init(self):
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is not None:
             if worker_info.num_workers > len(self.arg_files):
-                logger.warning(
+                get_logger().warning(
                     f"dataloader rank {worker_info.id} Number of workers {worker_info.num_workers} is greater than the number of files {len(self.arg_files)}"
                 )
         self.state = PQDatasetState(

@@ -238,7 +234,7 @@ def __init__(self, datasets: List[ParquetDataset], probabilities: Optional[List[
                 self.datasets.append(dataset)
                 self.probabilities.append(prob)
             else:
-                logger.warning(f"Dataset {dataset} is empty. Skipping.")
+                get_logger().warning(f"Dataset {dataset} is empty. Skipping.")

         self.state = InterleaveDatasetState(current_index=0, seed=seed)
         self._init_random_state()

@@ -312,7 +308,7 @@ def _get_datafiles(path: str, name: Optional[str] = None, split: str = "train")
     builder_config = _get_ds_config_dict(path=path, name=name)
     if name is None or len(name) == 0:
         if "default" not in builder_config:
-            logger.warning(f"Default config not found for {path}. Using first config.")
+            get_logger().warning(f"Default config not found for {path}. Using first config.")
             name = next(iter(builder_config.keys()))
         else:
             name = "default"

@@ -338,7 +334,7 @@ def _load_datasets(
     probabilities: Optional[List[float]] = None,
     reverse_data_files: bool = False,
 ) -> InterleaveDataset:
-    logger.debug(dataset_names)
+    get_logger().debug(dataset_names)
     ds_args = []
     for _ds in dataset_names.split(","):
         _ds_name, _, _ds_config = _ds.partition(":")

@@ -356,7 +352,7 @@

# logger.debug(f"Datasets ({split}):\n" + "\n".join(map(_nice_print, ds_args)))
# logger.debug(f"Probabilities: {probabilities}")
logger.debug(f"Loading datasets{' in streaming mode' if streaming else ''}")
get_logger().debug(f"Loading datasets{' in streaming mode' if streaming else ''}")
datasets = []
for ds_arg in ds_args:
# logger.debug(f"Loading dataset: {ds_arg['data_files']}")
Expand All @@ -368,7 +364,7 @@ def _load_datasets(
     else:
         ds = datasets[0]

-    logger.info(f"Loaded datasets ({split})")
+    get_logger().info(f"Loaded datasets ({split})")
     return ds


@@ -401,7 +397,7 @@ def load_all_datasets(
         split_world_size = world_size


-    logger.info("Loading Train dataset(s)")
+    get_logger().info("Loading Train dataset(s)")

     ds = _load_datasets(
         dataset_names=data_config.dataset_name_or_paths,

@@ -413,6 +409,6 @@
         tokenizer=tokenizer,
     )

-    logger.info(f"Train dataset:\n{ds}")
+    get_logger().info(f"Train dataset:\n{ds}")

     return ds
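All of these hunks make the same mechanical change: the module-level logger = get_logger(name=__name__) is gone, and every call site asks for the logger lazily via get_logger(). The module-level call ran at import time, before the config (and now the ZERO_BAND_* variables) had been resolved. A minimal sketch of the lazy-singleton pattern this relies on, mirroring the logger = None global visible in logging.py below but simplified, without the rank handling:

import logging

_logger = None

def get_logger():
    # Constructed on first use -- by then the config has been parsed and
    # resolve_env_vars has run, so the level can be read from it.
    global _logger
    if _logger is None:
        _logger = logging.getLogger("zeroband")
    return _logger

get_logger().warning("created on first call, reused afterwards")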
src/zeroband/train.py (+3, -2)

@@ -10,7 +10,7 @@

from zeroband.checkpoint import CkptManager, TrainingProgress
from zeroband.comms import ElasticDeviceMesh
from zeroband.config import Config
from zeroband.config import Config, resolve_env_vars
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.diloco import Diloco
from zeroband.loss import compute_cross_entropy_loss
Expand Down Expand Up @@ -83,7 +83,7 @@ def train(config: Config):
     assert config.optim.batch_size % world_info.local_world_size == 0
     batch_size = config.optim.batch_size // world_info.local_world_size

-    assert batch_size % config.train.micro_bs == 0
+    assert batch_size % config.train.micro_bs == 0, f'The micro batch size ({config.train.micro_bs}) must divide the number of samples on each GPU ({batch_size}).'
     gradient_accumulation_steps = batch_size // config.train.micro_bs

     if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None:

@@ -496,6 +496,7 @@ def train(config: Config):
     torch.manual_seed(42)

     config = Config(**parse_argv())  # type: ignore
+    resolve_env_vars(config)
     world_info = get_world_info()
     logger = get_logger(config)
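The improved assert message spells out the constraint chain: the global batch is split evenly across local GPUs, and each GPU's share must in turn be a whole number of micro-batches. With assumed example numbers (not from this diff):

# Hypothetical values: 8 GPUs per node, global batch 512, micro batch 16.
local_world_size = 8
global_batch_size = 512

assert global_batch_size % local_world_size == 0
batch_size = global_batch_size // local_world_size   # 64 samples per GPU

micro_bs = 16
assert batch_size % micro_bs == 0, (
    f"The micro batch size ({micro_bs}) must divide "
    f"the number of samples on each GPU ({batch_size})."
)
gradient_accumulation_steps = batch_size // micro_bs  # 4 passes per optimizer step
print(gradient_accumulation_steps)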
src/zeroband/utils/logging.py (+6, -7)

@@ -1,6 +1,6 @@
 import logging

-from zeroband.config import Config, get_env_config, get_env_config_bool
+from zeroband.config import Config
 from zeroband.utils.world_info import get_world_info

 logger = None

@@ -29,6 +29,8 @@ def get_logger(config: Config | None = None, name: str | None = None) -> logging
     if logger is not None:
         return logger

+    assert isinstance(config, Config)
+
     try:
         world_info = get_world_info()
     except KeyError:

@@ -38,14 +40,11 @@
         world_info.local_rank = 0
     logger = logging.getLogger(name or __name__)

-    log_level = get_env_config(config, "log_level", "INFO")
-    assert isinstance(log_level, str)
-
     if world_info.local_rank == 0:
-        logger.setLevel(level=getattr(logging, log_level, logging.INFO))
+        logger.setLevel(level=getattr(logging, config.log_level, logging.INFO))
     else:
-        if get_env_config_bool(config, "log_all_rank", False):
-            logger.setLevel(level=getattr(logging, log_level, logging.INFO))
+        if config.log_all_rank:
+            logger.setLevel(level=getattr(logging, config.log_level, logging.INFO))
         else:
             logger.setLevel(level=logging.CRITICAL)  # Disable logging for non-zero ranks
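With the env helpers gone, the level policy reads straight off the config: rank 0 logs at config.log_level, other ranks are silenced unless config.log_all_rank is set. A standalone sketch of that policy, with config and world info reduced to plain arguments (an assumption for illustration, not the repo's API):

import logging

def set_level(logger: logging.Logger, local_rank: int, log_level: str = "INFO", log_all_rank: bool = False) -> None:
    # Rank 0 (or every rank, if log_all_rank) logs at the configured level;
    # all other ranks are raised to CRITICAL, i.e. effectively muted.
    if local_rank == 0 or log_all_rank:
        logger.setLevel(getattr(logging, log_level, logging.INFO))
    else:
        logger.setLevel(logging.CRITICAL)

logger = logging.getLogger("zeroband")
set_level(logger, local_rank=0, log_level="DEBUG")  # rank 0: DEBUG
set_level(logger, local_rank=3)                     # rank 3: muted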
src/zeroband/utils/metric_logger.py (+1, -3)

@@ -18,10 +18,8 @@ def __init__(self, project, logger_config, resume: bool):

         import wandb

-        run_name = logger_config["config"]["run_name"]
-
         wandb.init(
-            project=project, config=logger_config, name=run_name, resume="auto" if resume else None
+            project=project, config=logger_config, name=logger_config["config"]["run_name"], resume="auto" if resume else None
         )  # make wandb reuse the same run id if possible

     def log(self, metrics: dict[str, Any]):