diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 32f9cdd0..357c2d61 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -203,7 +203,7 @@ def train(config: Config): logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger metric_logger = logger_cls( project=config.project, - config={"config": config.model_dump(), "world_info": world_info.json()}, + logger_config={"config": config.model_dump(), "world_info": world_info.json()}, resume=config.wandb_resume, ) else: diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index 0a47dc3f..85847925 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -2,11 +2,9 @@ from typing import Any, Protocol import importlib.util -from zeroband.config import get_env_config - class MetricLogger(Protocol): - def __init__(self, project, config): ... + def __init__(self, project, logger_config): ... def log(self, metrics: dict[str, Any]): ... @@ -14,16 +12,16 @@ def finish(self): ... class WandbMetricLogger(MetricLogger): - def __init__(self, project, config, resume: bool): + def __init__(self, project, logger_config, resume: bool): if importlib.util.find_spec("wandb") is None: raise ImportError("wandb is not installed. Please install it to use WandbMonitor.") import wandb - run_name = get_env_config(config, "run_name") + run_name = logger_config["config"]["run_name"] wandb.init( - project=project, config=config, name=run_name, resume="auto" if resume else None + project=project, config=logger_config, name=run_name, resume="auto" if resume else None ) # make wandb reuse the same run id if possible def log(self, metrics: dict[str, Any]): @@ -38,9 +36,9 @@ def finish(self): class DummyMetricLogger(MetricLogger): - def __init__(self, project, config, *args, **kwargs): + def __init__(self, project, logger_config, *args, **kwargs): self.project = project - self.config = config + self.logger_config = logger_config open(self.project, "a").close() # Create an empty file to append to self.data = []