Logs best checkpoint metric on W&B
kylebgorman committed May 8, 2024
1 parent 09b478d commit 3054142
Showing 4 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion yoyodyne/defaults.py
@@ -22,7 +22,7 @@
HIDDEN_SIZE = 512
MAX_SOURCE_LENGTH = 128
MAX_TARGET_LENGTH = 128
EVAL_METRICS = {"accuracy"}
EVAL_METRICS = set()

# Training arguments.
BATCH_SIZE = 32
2 changes: 1 addition & 1 deletion yoyodyne/evaluators.py
@@ -326,5 +326,5 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
action=util.UniqueAddAction,
choices=_eval_factory.keys(),
default=defaults.EVAL_METRICS,
help="Which evaluation metrics to use. Default: %(default)s.",
help="Additional metrics to compute. Default: %(default)s.",
)
4 changes: 2 additions & 2 deletions yoyodyne/models/base.py
@@ -4,7 +4,7 @@
"""

import argparse
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, Optional, Set

import pytorch_lightning as pl
import torch
@@ -51,7 +51,7 @@ class BaseEncoderDecoder(pl.LightningModule):
source_encoder_cls: modules.base.BaseModule
# Constructed inside __init__.
dropout_layer: nn.Dropout
-eval_metrics: List[evaluators.Evaluator]
+eval_metrics: Set[evaluators.Evaluator]
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

def __init__(
18 changes: 15 additions & 3 deletions yoyodyne/train.py
@@ -28,8 +28,6 @@ def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List:
trainer_logger = [loggers.CSVLogger(model_dir, name=experiment)]
if log_wandb:
trainer_logger.append(loggers.WandbLogger(project=experiment))
-# Tells PTL to log the best validation accuracy.
-wandb.define_metric("val_accuracy", summary="max")
# Logs the path to local artifacts made by PTL.
wandb.config["local_run_dir"] = trainer_logger[0].log_dir
return trainer_logger
@@ -40,6 +38,7 @@ def _get_callbacks(
checkpoint_metric: str = defaults.CHECKPOINT_METRIC,
patience: Optional[int] = None,
patience_metric: str = defaults.PATIENCE_METRIC,
+log_wandb: bool = False,
) -> List[callbacks.Callback]:
"""Creates the callbacks.
@@ -56,6 +55,7 @@
early stopping.
patience_metric (string, optional): validation metric used to
trigger early stopping.
+log_wandb (bool).
Returns:
List[callbacks.Callback]: callbacks.
@@ -72,6 +72,7 @@
mode=metric.mode,
monitor=metric.monitor,
patience=patience,
+verbose=True,
)
)
# Checkpointing callback. Ensure that this is the last checkpoint,
@@ -85,6 +86,9 @@
save_top_k=num_checkpoints,
)
)
+# Logs the best value for the checkpointing metric.
+if log_wandb:
+    wandb.define_metric(metric.monitor, summary=metric.mode)
return trainer_callbacks
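
For context, the new call above uses W&B's define_metric API: declaring a metric with summary="min" or summary="max" makes the run summary report the best value logged so far rather than the last one, which is why the checkpoint metric's mode can be passed through directly as the summary. A minimal standalone sketch of that mechanism, with an illustrative project name and made-up values not taken from this repository:

    import wandb

    # Hypothetical demo run; the project name and metric values are illustrative.
    run = wandb.init(project="yoyodyne-demo")
    # Summarize val_loss by its minimum, mirroring
    # wandb.define_metric(metric.monitor, summary=metric.mode) above.
    wandb.define_metric("val_loss", summary="min")
    for step, loss in enumerate([0.9, 0.7, 0.8, 0.6]):
        wandb.log({"val_loss": loss}, step=step)
    run.finish()
    # The run summary now reports the best (lowest) val_loss, 0.6.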


@@ -106,6 +110,7 @@ def get_trainer_from_argparse_args(
args.checkpoint_metric,
args.patience,
args.patience_metric,
+args.log_wandb,
),
default_root_dir=args.model_dir,
enable_checkpointing=True,
@@ -167,6 +172,7 @@ def get_model_from_argparse_args(
source_encoder_cls = models.modules.get_encoder_cls(
encoder_arch=args.source_encoder_arch, model_arch=args.arch
)
+# Loads expert if needed.
expert = (
models.expert.get_expert(
datamodule.train_dataloader().dataset,
@@ -198,6 +204,12 @@ def get_model_from_argparse_args(
if not separate_features
else datamodule.index.source_vocab_size
)
+# This makes sure we compute all metrics that'll be needed.
+eval_metrics = args.eval_metric.copy()
+if args.checkpoint_metric != "loss":
+    eval_metrics.add(args.checkpoint_metric)
+if args.patience_metric != "loss":
+    eval_metrics.add(args.patience_metric)
# Please pass all arguments by keyword and keep in lexicographic order.
return model_cls(
arch=args.arch,
@@ -211,7 +223,7 @@
embedding_size=args.embedding_size,
encoder_layers=args.encoder_layers,
end_idx=datamodule.index.end_idx,
-eval_metrics=args.eval_metric,
+eval_metrics=eval_metrics,
expert=expert,
features_encoder_cls=features_encoder_cls,
features_vocab_size=features_vocab_size,
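The metric bookkeeping added to get_model_from_argparse_args above can also be read in isolation: whichever metric drives checkpointing or early stopping has to be computed at validation time, so it is folded into the user-requested eval metrics unless it is the loss, which is always computed. A small self-contained sketch of that logic, using hypothetical names not present in the repository:

    def merge_eval_metrics(
        requested: set, checkpoint_metric: str, patience_metric: str
    ) -> set:
        """Adds any non-loss checkpoint or patience metric to the eval metrics."""
        eval_metrics = requested.copy()
        for metric in (checkpoint_metric, patience_metric):
            # The loss is always computed, so it never needs to be added.
            if metric != "loss":
                eval_metrics.add(metric)
        return eval_metrics

    # Even with the new empty default for --eval_metric, checkpointing on
    # accuracy still forces accuracy to be computed.
    assert merge_eval_metrics(set(), "accuracy", "loss") == {"accuracy"}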
