From 3bb221f4bcf896451f62c9ec65e56b55f0dbb672 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Wed, 15 Jan 2025 00:04:05 +0000 Subject: [PATCH 01/10] Ignore project log pickle --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 3d42f8a4..38b3607d 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,6 @@ cython_debug/ # Aider .aider* + +# Files created while testing +debug_I2_zero_band From 6241da31201d8d526fc077f1eef500509870286f Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 01:28:20 +0000 Subject: [PATCH 02/10] Remove prints --- src/zeroband/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/zeroband/config.py b/src/zeroband/config.py index ee2d615e..1951955e 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -219,8 +219,6 @@ def get_env_config(config: Config | None, item: str | None, default: Any | None spt = item.split(".") cfg: Any = config for s in spt: - print(cfg) - print(s) if cfg is None: return default try: From d41dea97c576066928a73283195c6fc9c752eea3 Mon Sep 17 00:00:00 2001 From: sami jaghouar Date: Thu, 16 Jan 2025 01:27:43 +0000 Subject: [PATCH 03/10] fix wandb config run name --- src/zeroband/train.py | 2 +- src/zeroband/utils/metric_logger.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 32f9cdd0..357c2d61 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -203,7 +203,7 @@ def train(config: Config): logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger metric_logger = logger_cls( project=config.project, - config={"config": config.model_dump(), "world_info": world_info.json()}, + logger_config={"config": config.model_dump(), "world_info": world_info.json()}, resume=config.wandb_resume, ) else: diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index 0a47dc3f..85847925 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -2,11 +2,9 @@ from typing import Any, Protocol import importlib.util -from zeroband.config import get_env_config - class MetricLogger(Protocol): - def __init__(self, project, config): ... + def __init__(self, project, logger_config): ... def log(self, metrics: dict[str, Any]): ... @@ -14,16 +12,16 @@ def finish(self): ... class WandbMetricLogger(MetricLogger): - def __init__(self, project, config, resume: bool): + def __init__(self, project, logger_config, resume: bool): if importlib.util.find_spec("wandb") is None: raise ImportError("wandb is not installed. Please install it to use WandbMonitor.") import wandb - run_name = get_env_config(config, "run_name") + run_name = logger_config["config"]["run_name"] wandb.init( - project=project, config=config, name=run_name, resume="auto" if resume else None + project=project, config=logger_config, name=run_name, resume="auto" if resume else None ) # make wandb reuse the same run id if possible def log(self, metrics: dict[str, Any]): @@ -38,9 +36,9 @@ def finish(self): class DummyMetricLogger(MetricLogger): - def __init__(self, project, config, *args, **kwargs): + def __init__(self, project, logger_config, *args, **kwargs): self.project = project - self.config = config + self.logger_config = logger_config open(self.project, "a").close() # Create an empty file to append to self.data = [] From e135c932e743e7ecdfa90634d50d280f772be0e5 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 05:22:04 +0000 Subject: [PATCH 04/10] Replace get_env_config() with resolve_env_vars(). --- src/zeroband/config.py | 139 +++++++++++++++------------- src/zeroband/data.py | 20 ++-- src/zeroband/train.py | 11 ++- src/zeroband/utils/logging.py | 11 +-- src/zeroband/utils/metric_logger.py | 4 +- 5 files changed, 97 insertions(+), 88 deletions(-) diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 1951955e..d22173b0 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -1,7 +1,7 @@ -from typing import Any, Literal, TypeAlias +from typing import Any, Type, Literal, TypeAlias import os -from pydantic import model_validator +from pydantic import create_model, model_validator from pydantic_config import BaseConfig from zeroband.collectives import Compression @@ -184,66 +184,81 @@ def validate_live_recovery_rank_src(self): return self -def get_env_config(config: Config | None, item: str | None, default: Any | None = None) -> Any: +def resolve_env_vars(config: Config) -> None: """ - Get a config value from the environment or the config. - item: item is of the form "train.memory_profiler.freq" - default: default value if not found - - If either config or item are None, returns default. This is so you can call get_logger() as before. - - Examples: - ``` - # Returns ZERO_BAND_RUN_NAME if set in env. - # Otherwise returns config.run_name. - get_env_config(config, "run_name") - ``` - ``` - # Returns ZERO_BAND_TRAIN_MEMORY_PROFILER_FREQ if set in env. - # Then returns 10 if train or config.train.memory_profiler are None. - # Otherwise, returns the value of config.train.memory_profiler.freq. - get_env_config(config, "train.memory_profiler.freq", 10) - ``` - - """ - - if config is None or item is None: - return default - - # Check env - env_name = "ZERO_BAND_" + item.upper().replace(".", "_") - if env_name in os.environ: - return os.environ[env_name] - - # Check config - spt = item.split(".") - cfg: Any = config - for s in spt: - if cfg is None: - return default - try: - cfg = getattr(cfg, s) - except AttributeError: - # TODO: Fancier error message for debugging - raise ValueError(f"Config item {item} not found.") - - return cfg - -def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool: + Resolve environment variables for config fields. + Modifies the config in place. + Environment variables should be prefixed with ZERO_BAND_. """ - Call get_env_config and convert strings to bools where makes sense. - - Throws an exception if the value is not a string and not convertable. - """ - - val = get_env_config(config, item, default) - if val is None and default is not None: - return default - if val is None: - return False - if isinstance(val, bool): - return val - if isinstance(val, str): - return val.lower() == "true" or val.lower() == "1" - return bool(val) + def _resolve_value(env_var: str, field_name: str, config_obj: Any) -> Any: + """ + Resolve a single value from an environment variable + env_var: full environment variable name (e.g. ZERO_BAND_TRAIN_MICRO_BS) + field_name: actual field name in the config object (e.g. micro_bs) + """ + value = os.environ.get(env_var) + if value is not None: + if (field_info := config_obj.__class__.model_fields.get(field_name)) is None: + raise AttributeError(f"Config {config_obj} has no attribute {field_name}") + + try: + # Create a temporary model with just this field, then validate and rip it out. + py_model = create_model('TempModel', __base__ = BaseConfig, **{field_name: (field_info.annotation, ...)}) + validated = py_model.model_validate({field_name: value}) + return getattr(validated, field_name) + except Exception as e: + raise ValueError(f"Error setting {env_var}={value}: {e}") + return None + + def _resolve_nested(prefix: str, config_obj: Any) -> None: + if not hasattr(config_obj, 'model_fields'): + return + + for field_name, _ in config_obj.__class__.model_fields.items(): + # Build the full env var name + full_env_var = f"ZERO_BAND_{prefix}_{field_name}".upper() if prefix else f"ZERO_BAND_{field_name}".upper() + + # Try to resolve the field directly using the local field name + value = _resolve_value(full_env_var, field_name, config_obj) + if value is not None: + setattr(config_obj, field_name, value) + + # Handle nested configs + field_value = getattr(config_obj, field_name) + if field_value is not None and hasattr(field_value, 'model_fields'): + # Pass the prefix for building env var names, but use local field names for lookup + _resolve_nested(f"{prefix}_{field_name}" if prefix else field_name, field_value) + + def _get_valid_env_vars(prefix: str, config_obj: Any) -> set[str]: + """Recursively collect all valid environment variable names""" + valid_vars = set() + if not hasattr(config_obj, 'model_fields'): + return valid_vars + + for field_name, _ in config_obj.__class__.model_fields.items(): + full_env_var = f"ZERO_BAND_{prefix}_{field_name}".upper() if prefix else f"ZERO_BAND_{field_name}".upper() + valid_vars.add(full_env_var) + + field_value = getattr(config_obj, field_name) + if field_value is not None and hasattr(field_value, 'model_fields'): + nested_prefix = f"{prefix}_{field_name}" if prefix else field_name + valid_vars.update(_get_valid_env_vars(nested_prefix, field_value)) + + return valid_vars + + # Check for any invalid ZERO_BAND_ environment variables + valid_env_vars = _get_valid_env_vars("", config) + invalid_vars = [] + for env_var in os.environ: + if env_var.startswith("ZERO_BAND_") and env_var not in valid_env_vars: + invalid_vars.append(env_var) + + if invalid_vars: + raise ValueError( + f"Found invalid environment variables with ZERO_BAND_ prefix: {', '.join(invalid_vars)}\n" + "See the full list of valid config veriables in src/zeroband/config.py." + ) + + # Now resolve the valid ones. + _resolve_nested("", config) diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 2fe73806..32d47c76 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -18,10 +18,6 @@ TEST_VOCAB_SIZE = 1024 -# TODO sami: make sure the init of the model is the same on all rank - -logger = get_logger(name=__name__) - class FakeTokenizedDataset(IterableDataset): """This is a dummy dataset that generates random sequences of length seq_len and vocab_size""" @@ -160,7 +156,7 @@ def _lazy_init(self): worker_info = torch.utils.data.get_worker_info() if worker_info is not None: if worker_info.num_workers > len(self.arg_files): - logger.warning( + get_logger().warning( f"dataloader rank {worker_info.id} Number of workers {worker_info.num_workers} is greater than the number of files {len(self.arg_files)}" ) self.state = PQDatasetState( @@ -238,7 +234,7 @@ def __init__(self, datasets: List[ParquetDataset], probabilities: Optional[List[ self.datasets.append(dataset) self.probabilities.append(prob) else: - logger.warning(f"Dataset {dataset} is empty. Skipping.") + get_logger().warning(f"Dataset {dataset} is empty. Skipping.") self.state = InterleaveDatasetState(current_index=0, seed=seed) self._init_random_state() @@ -312,7 +308,7 @@ def _get_datafiles(path: str, name: Optional[str] = None, split: str = "train") builder_config = _get_ds_config_dict(path=path, name=name) if name is None or len(name) == 0: if "default" not in builder_config: - logger.warning(f"Default config not found for {path}. Using first config.") + get_logger().warning(f"Default config not found for {path}. Using first config.") name = next(iter(builder_config.keys())) else: name = "default" @@ -338,7 +334,7 @@ def _load_datasets( probabilities: Optional[List[float]] = None, reverse_data_files: bool = False, ) -> InterleaveDataset: - logger.debug(dataset_names) + get_logger().debug(dataset_names) ds_args = [] for _ds in dataset_names.split(","): _ds_name, _, _ds_config = _ds.partition(":") @@ -356,7 +352,7 @@ def _load_datasets( # logger.debug(f"Datasets ({split}):\n" + "\n".join(map(_nice_print, ds_args))) # logger.debug(f"Probabilities: {probabilities}") - logger.debug(f"Loading datasets{' in streaming mode' if streaming else ''}") + get_logger().debug(f"Loading datasets{' in streaming mode' if streaming else ''}") datasets = [] for ds_arg in ds_args: # logger.debug(f"Loading dataset: {ds_arg['data_files']}") @@ -368,7 +364,7 @@ def _load_datasets( else: ds = datasets[0] - logger.info(f"Loaded datasets ({split})") + get_logger().info(f"Loaded datasets ({split})") return ds @@ -401,7 +397,7 @@ def load_all_datasets( split_world_size = world_size - logger.info("Loading Train dataset(s)") + get_logger().info("Loading Train dataset(s)") ds = _load_datasets( dataset_names=data_config.dataset_name_or_paths, @@ -413,6 +409,6 @@ def load_all_datasets( tokenizer=tokenizer, ) - logger.info(f"Train dataset:\n{ds}") + get_logger().info(f"Train dataset:\n{ds}") return ds diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 357c2d61..ded4c065 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -10,7 +10,7 @@ from zeroband.checkpoint import CkptManager, TrainingProgress from zeroband.comms import ElasticDeviceMesh -from zeroband.config import Config +from zeroband.config import Config, resolve_env_vars from zeroband.data import TEST_VOCAB_SIZE, get_dataloader from zeroband.diloco import Diloco from zeroband.loss import compute_cross_entropy_loss @@ -83,7 +83,7 @@ def train(config: Config): assert config.optim.batch_size % world_info.local_world_size == 0 batch_size = config.optim.batch_size // world_info.local_world_size - assert batch_size % config.train.micro_bs == 0 + assert batch_size % config.train.micro_bs == 0, f'The micro batch size ({config.train.micro_bs}) must divide the number of samples on each GPU ({batch_size}).' gradient_accumulation_steps = batch_size // config.train.micro_bs if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None: @@ -496,6 +496,7 @@ def train(config: Config): torch.manual_seed(42) config = Config(**parse_argv()) # type: ignore + resolve_env_vars(config) world_info = get_world_info() logger = get_logger(config) @@ -504,12 +505,12 @@ def train(config: Config): def pretty_dict(d, indent=2): for key, value in d.items(): if isinstance(value, dict): - logger.debug(" " * indent + f"{key}:") + logger.info(" " * indent + f"{key}:") pretty_dict(value, indent + 2) else: - logger.debug(" " * indent + f"{key}: {value}") + logger.info(" " * indent + f"{key}: {value}") - logger.debug("config:") + logger.info("config:") pretty_dict(config.model_dump()) try: diff --git a/src/zeroband/utils/logging.py b/src/zeroband/utils/logging.py index 3ab41d05..3ddf7dc2 100644 --- a/src/zeroband/utils/logging.py +++ b/src/zeroband/utils/logging.py @@ -1,6 +1,6 @@ import logging -from zeroband.config import Config, get_env_config, get_env_config_bool +from zeroband.config import Config from zeroband.utils.world_info import get_world_info logger = None @@ -38,14 +38,13 @@ def get_logger(config: Config | None = None, name: str | None = None) -> logging world_info.local_rank = 0 logger = logging.getLogger(name or __name__) - log_level = get_env_config(config, "log_level", "INFO") - assert isinstance(log_level, str) + assert isinstance(config.log_level, str) if world_info.local_rank == 0: - logger.setLevel(level=getattr(logging, log_level, logging.INFO)) + logger.setLevel(level=getattr(logging, config.log_level, logging.INFO)) else: - if get_env_config_bool(config, "log_all_rank", False): - logger.setLevel(level=getattr(logging, log_level, logging.INFO)) + if config.log_all_rank: + logger.setLevel(level=getattr(logging, config.log_level, logging.INFO)) else: logger.setLevel(level=logging.CRITICAL) # Disable logging for non-zero ranks diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index 85847925..fe3fc7d1 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -18,10 +18,8 @@ def __init__(self, project, logger_config, resume: bool): import wandb - run_name = logger_config["config"]["run_name"] - wandb.init( - project=project, config=logger_config, name=run_name, resume="auto" if resume else None + project=project, config=logger_config, name=logger_config.config.run_name, resume="auto" if resume else None ) # make wandb reuse the same run id if possible def log(self, metrics: dict[str, Any]): From c35fe7a80c88b14a1ded804cabbd5639445a2bac Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 05:31:05 +0000 Subject: [PATCH 05/10] Clean up scripts. --- scripts/all_reduce.py | 4 +++- scripts/convert_dl_state.py | 2 ++ scripts/export_dcp.py | 2 ++ scripts/skip_data.py | 3 ++- src/zeroband/config.py | 2 +- src/zeroband/train.py | 6 +++--- src/zeroband/utils/logging.py | 4 ++-- 7 files changed, 15 insertions(+), 8 deletions(-) diff --git a/scripts/all_reduce.py b/scripts/all_reduce.py index 766a790b..c60fbeba 100644 --- a/scripts/all_reduce.py +++ b/scripts/all_reduce.py @@ -4,8 +4,10 @@ import torch.utils.benchmark as benchmark from zeroband.collectives import Compression, all_reduce +from zeroband.config import resolve_env_vars from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger + from enum import Enum @@ -63,6 +65,6 @@ def main(config: Config): torch.set_float32_matmul_precision("high") init_process_group(backend="gloo") - logger = get_logger(config) + logger = get_logger() main(config) destroy_process_group() diff --git a/scripts/convert_dl_state.py b/scripts/convert_dl_state.py index 3fe8d004..950e03f6 100755 --- a/scripts/convert_dl_state.py +++ b/scripts/convert_dl_state.py @@ -4,6 +4,7 @@ # python scripts/convert_dl_state.py @configs/10B/H100.toml --input_path /workspace/step_49200/diloco_0/data/_3.pt --output_path ./meow.pt --rank 3 --world_size 8 import torch +from zeroband.config import resolve_env_vars from zeroband.data import get_dataloader from transformers import AutoTokenizer from zeroband.train import Config @@ -133,6 +134,7 @@ def test_dl(config: ExportConfig): if __name__ == "__main__": logger = get_logger() config = ExportConfig(**parse_argv()) + resolve_env_vars(config) logger.debug(f"config: {config.model_dump()}") main(config) diff --git a/scripts/export_dcp.py b/scripts/export_dcp.py index cf7460dd..d606bc78 100644 --- a/scripts/export_dcp.py +++ b/scripts/export_dcp.py @@ -7,6 +7,7 @@ from typing import Literal import torch.distributed.checkpoint as dcp from zeroband.models.llama import get_model +from zeroband.config import resolve_env_vars from zeroband.checkpoint import ModelWrapper from zeroband.utils import get_module_signature from zeroband.train import Config @@ -221,6 +222,7 @@ def main(config: ExportConfig): if __name__ == "__main__": logger = get_logger() config = ExportConfig(**parse_argv()) + resolve_env_vars(config) logger.debug(f"config: {config.model_dump()}") main(config) diff --git a/scripts/skip_data.py b/scripts/skip_data.py index 4b32b55e..04b5ed3b 100644 --- a/scripts/skip_data.py +++ b/scripts/skip_data.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer from zeroband.checkpoint import CkptManager - +from zeroband.config import resolve_env_vars from zeroband.train import Config from zeroband.data import get_dataloader @@ -83,5 +83,6 @@ def skip_data(config: Config): logger = get_logger() config = Config(**parse_argv()) + resolve_env_vars(config) skip_data(config) diff --git a/src/zeroband/config.py b/src/zeroband/config.py index d22173b0..f11ee0d8 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -204,7 +204,7 @@ def _resolve_value(env_var: str, field_name: str, config_obj: Any) -> Any: try: # Create a temporary model with just this field, then validate and rip it out. - py_model = create_model('TempModel', __base__ = BaseConfig, **{field_name: (field_info.annotation, ...)}) + py_model = create_model('TempModel', __base__ = BaseConfig, **{field_name: (field_info.annotation, ...)}) # type: ignore validated = py_model.model_validate({field_name: value}) return getattr(validated, field_name) except Exception as e: diff --git a/src/zeroband/train.py b/src/zeroband/train.py index ded4c065..dfbfa5a4 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -505,12 +505,12 @@ def train(config: Config): def pretty_dict(d, indent=2): for key, value in d.items(): if isinstance(value, dict): - logger.info(" " * indent + f"{key}:") + logger.debug(" " * indent + f"{key}:") pretty_dict(value, indent + 2) else: - logger.info(" " * indent + f"{key}: {value}") + logger.debug(" " * indent + f"{key}: {value}") - logger.info("config:") + logger.debug("config:") pretty_dict(config.model_dump()) try: diff --git a/src/zeroband/utils/logging.py b/src/zeroband/utils/logging.py index 3ddf7dc2..5721d0d6 100644 --- a/src/zeroband/utils/logging.py +++ b/src/zeroband/utils/logging.py @@ -29,6 +29,8 @@ def get_logger(config: Config | None = None, name: str | None = None) -> logging if logger is not None: return logger + assert isinstance(config, Config) + try: world_info = get_world_info() except KeyError: @@ -38,8 +40,6 @@ def get_logger(config: Config | None = None, name: str | None = None) -> logging world_info.local_rank = 0 logger = logging.getLogger(name or __name__) - assert isinstance(config.log_level, str) - if world_info.local_rank == 0: logger.setLevel(level=getattr(logging, config.log_level, logging.INFO)) else: From 4791139be958ecd4feb982f1597e17d71fc96a0b Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 05:47:50 +0000 Subject: [PATCH 06/10] Fix logger_config (is a dict) --- src/zeroband/utils/metric_logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index fe3fc7d1..d444712e 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -18,8 +18,9 @@ def __init__(self, project, logger_config, resume: bool): import wandb + print(logger_config["config"]) wandb.init( - project=project, config=logger_config, name=logger_config.config.run_name, resume="auto" if resume else None + project=project, config=logger_config, name=logger_config["config"]["run_name"], resume="auto" if resume else None ) # make wandb reuse the same run id if possible def log(self, metrics: dict[str, Any]): From e00f17b3fb2483d0b8144b1eaebcb5903842dfdd Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 05:49:38 +0000 Subject: [PATCH 07/10] Remove print --- src/zeroband/utils/metric_logger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index d444712e..73befcaf 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -18,7 +18,6 @@ def __init__(self, project, logger_config, resume: bool): import wandb - print(logger_config["config"]) wandb.init( project=project, config=logger_config, name=logger_config["config"]["run_name"], resume="auto" if resume else None ) # make wandb reuse the same run id if possible From 856732a413a8461cf560d1e737ffdd870b96db08 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 15:26:31 -0600 Subject: [PATCH 08/10] Make prefix configurable. --- src/zeroband/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 42ccdc8f..302aa23b 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -143,6 +143,8 @@ def validate_remote_data_path(self): return self +ENV_VAR_PREFIX = "ZERO_BAND_" + class Config(BaseConfig): # main config name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M" @@ -216,7 +218,7 @@ def _resolve_nested(prefix: str, config_obj: Any) -> None: for field_name, _ in config_obj.__class__.model_fields.items(): # Build the full env var name - full_env_var = f"ZERO_BAND_{prefix}_{field_name}".upper() if prefix else f"ZERO_BAND_{field_name}".upper() + full_env_var = f"{ENV_VAR_PREFIX}{prefix}_{field_name}".upper() if prefix else f"{ENV_VAR_PREFIX}{field_name}".upper() # Try to resolve the field directly using the local field name value = _resolve_value(full_env_var, field_name, config_obj) @@ -236,7 +238,7 @@ def _get_valid_env_vars(prefix: str, config_obj: Any) -> set[str]: return valid_vars for field_name, _ in config_obj.__class__.model_fields.items(): - full_env_var = f"ZERO_BAND_{prefix}_{field_name}".upper() if prefix else f"ZERO_BAND_{field_name}".upper() + full_env_var = f"{ENV_VAR_PREFIX}{prefix}_{field_name}".upper() if prefix else f"{ENV_VAR_PREFIX}{field_name}".upper() valid_vars.add(full_env_var) field_value = getattr(config_obj, field_name) @@ -250,12 +252,12 @@ def _get_valid_env_vars(prefix: str, config_obj: Any) -> set[str]: valid_env_vars = _get_valid_env_vars("", config) invalid_vars = [] for env_var in os.environ: - if env_var.startswith("ZERO_BAND_") and env_var not in valid_env_vars: + if env_var.startswith(ENV_VAR_PREFIX) and env_var not in valid_env_vars: invalid_vars.append(env_var) if invalid_vars: raise ValueError( - f"Found invalid environment variables with ZERO_BAND_ prefix: {', '.join(invalid_vars)}\n" + f"Found invalid environment variables with {ENV_VAR_PREFIX} prefix: {', '.join(invalid_vars)}\n" "See the full list of valid config veriables in src/zeroband/config.py." ) From 47f6bcb598dd5c68e91d94dc67ad7a65f35726c9 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 22:24:54 +0000 Subject: [PATCH 09/10] Fix ruff --- src/zeroband/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 302aa23b..ab9a0235 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -1,4 +1,4 @@ -from typing import Any, Type, Literal, TypeAlias +from typing import Any, Literal, TypeAlias import os from pydantic import create_model, model_validator From fbd995d71a718c13ad2f38b6457808cda8b3b6a8 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Thu, 16 Jan 2025 22:28:06 +0000 Subject: [PATCH 10/10] Fix ruff again --- scripts/all_reduce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/all_reduce.py b/scripts/all_reduce.py index c60fbeba..2993b7f6 100644 --- a/scripts/all_reduce.py +++ b/scripts/all_reduce.py @@ -4,7 +4,6 @@ import torch.utils.benchmark as benchmark from zeroband.collectives import Compression, all_reduce -from zeroband.config import resolve_env_vars from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger