Skip to content

Commit

Permalink
fix wandb config run name (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja authored Jan 16, 2025
1 parent 8bec8a8 commit 39b3c5e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def train(config: Config):
logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
metric_logger = logger_cls(
project=config.project,
config={"config": config.model_dump(), "world_info": world_info.json()},
logger_config={"config": config.model_dump(), "world_info": world_info.json()},
resume=config.wandb_resume,
)
else:
Expand Down
14 changes: 6 additions & 8 deletions src/zeroband/utils/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,26 @@
from typing import Any, Protocol
import importlib.util

from zeroband.config import get_env_config


class MetricLogger(Protocol):
def __init__(self, project, config): ...
def __init__(self, project, logger_config): ...

def log(self, metrics: dict[str, Any]): ...

def finish(self): ...


class WandbMetricLogger(MetricLogger):
def __init__(self, project, config, resume: bool):
def __init__(self, project, logger_config, resume: bool):
if importlib.util.find_spec("wandb") is None:
raise ImportError("wandb is not installed. Please install it to use WandbMonitor.")

import wandb

run_name = get_env_config(config, "run_name")
run_name = logger_config["config"]["run_name"]

wandb.init(
project=project, config=config, name=run_name, resume="auto" if resume else None
project=project, config=logger_config, name=run_name, resume="auto" if resume else None
) # make wandb reuse the same run id if possible

def log(self, metrics: dict[str, Any]):
Expand All @@ -38,9 +36,9 @@ def finish(self):


class DummyMetricLogger(MetricLogger):
def __init__(self, project, config, *args, **kwargs):
def __init__(self, project, logger_config, *args, **kwargs):
self.project = project
self.config = config
self.logger_config = logger_config
open(self.project, "a").close() # Create an empty file to append to

self.data = []
Expand Down

0 comments on commit 39b3c5e

Please sign in to comment.