Skip to content

Commit

Permalink
Merge pull request #190 from kylebgorman/unique
Browse files Browse the repository at this point in the history
This enforces uniqueness on --eval_metric.
  • Loading branch information
kylebgorman authored May 8, 2024
2 parents a71a46e + 1beeca7 commit c064067
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 30 deletions.
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ information.

### Validation

Validation is run at intervals requested by the user using the lightning interface.
See `--val_check_interval` and `--check_val_every_n_epoch`
Validation is run at intervals requested by the user. See
`--val_check_interval` and `--check_val_every_n_epoch`
[here](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api).
Particular evaluation metrics can also be requested with `--eval_metric`. For example
Additional evaluation metrics can also be requested with `--eval_metric`. For
example

yoyodyne-train --eval_metric accuracy --eval_metric ser ...
yoyodyne-train --eval_metric ser ...

will compute both accuracy and symbol error rate (SER) each time validation is
requested. Additional metrics can be added in
will additionally compute symbol error rate (SER) each time validation is
performed. Additional metrics can be added to
[`evaluators.py`](yoyodyne/evaluators.py).

### Prediction
Expand Down
3 changes: 2 additions & 1 deletion yoyodyne/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
"ignore", ".*does not have many workers which may be a bottleneck.*"
)
warnings.filterwarnings(
"ignore", ".*option adds dropout after all but last recurrent layer*."
"ignore", ".*option adds dropout after all but last recurrent layer.*"
)
warnings.filterwarnings("ignore", ".*is a wandb run already in progress.*")
2 changes: 1 addition & 1 deletion yoyodyne/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
HIDDEN_SIZE = 512
MAX_SOURCE_LENGTH = 128
MAX_TARGET_LENGTH = 128
EVAL_METRICS = ["accuracy"]
EVAL_METRICS = set()

# Training arguments.
BATCH_SIZE = 32
Expand Down
8 changes: 4 additions & 4 deletions yoyodyne/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torch.nn import functional

from . import defaults
from . import defaults, util


class Error(Exception):
Expand Down Expand Up @@ -311,7 +311,7 @@ def get_evaluator(eval_metric: str) -> Evaluator:
"""
try:
return _eval_factory[eval_metric]
except KeyError(eval_metric):
except KeyError:
raise Error(f"No evaluation metric {eval_metric}")


Expand All @@ -323,8 +323,8 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
"""
parser.add_argument(
"--eval_metric",
action="append",
action=util.UniqueAddAction,
choices=_eval_factory.keys(),
default=defaults.EVAL_METRICS,
help="Which evaluation metrics to use. Default: %(default)s.",
help="Additional metrics to compute. Default: %(default)s.",
)
9 changes: 7 additions & 2 deletions yoyodyne/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ def __init__(self, metric):
"""Initializes the metrics.
Args:
metric (str): one of "accuracy" (maximizes validation accuracy)
or "loss" (minimizes validation loss).
metric (str): one of "accuracy" (maximizes validation accuracy),
"loss" (minimizes validation loss), or "ser" (minimizes
symbol error rate).
Raises:
Error: Unknown metric.
Expand All @@ -35,5 +36,9 @@ def __init__(self, metric):
self.filename = "model-{epoch:03d}-{val_loss:.3f}"
self.mode = "min"
self.monitor = "val_loss"
elif metric == "ser":
self.filename = "model-{epoch:03d}-{val_ser:.3f}"
self.mode = "min"
self.monitor = "val_ser"
else:
raise Error(f"Unknown metric: {metric}")
4 changes: 2 additions & 2 deletions yoyodyne/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import argparse
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, Optional, Set

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -51,7 +51,7 @@ class BaseEncoderDecoder(pl.LightningModule):
source_encoder_cls: modules.base.BaseModule
# Constructed inside __init__.
dropout_layer: nn.Dropout
eval_metrics: List[evaluators.Evaluator]
eval_metrics: Set[evaluators.Evaluator]
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

def __init__(
Expand Down
42 changes: 28 additions & 14 deletions yoyodyne/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Error(Exception):
pass


def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List:
def _get_loggers(experiment: str, model_dir: str, log_wandb: bool) -> List:
"""Creates the logger(s).
Args:
Expand All @@ -25,21 +25,20 @@ def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List:
Returns:
List: loggers.
"""
trainer_logger = [loggers.CSVLogger(model_dir, name=experiment)]
trainer_loggers = [loggers.CSVLogger(model_dir, name=experiment)]
if log_wandb:
trainer_logger.append(loggers.WandbLogger(project=experiment))
# Tells PTL to log the best validation accuracy.
wandb.define_metric("val_accuracy", summary="max")
trainer_loggers.append(loggers.WandbLogger(project=experiment))
# Logs the path to local artifacts made by PTL.
wandb.config["local_run_dir"] = trainer_logger[0].log_dir
return trainer_logger
wandb.config["local_run_dir"] = trainer_loggers[0].log_dir
return trainer_loggers


def _get_callbacks(
num_checkpoints: int = defaults.NUM_CHECKPOINTS,
checkpoint_metric: str = defaults.CHECKPOINT_METRIC,
patience: Optional[int] = None,
patience_metric: str = defaults.PATIENCE_METRIC,
log_wandb: bool = False,
) -> List[callbacks.Callback]:
"""Creates the callbacks.
Expand All @@ -56,6 +55,7 @@ def _get_callbacks(
early stopping.
patience_metric (string, optional): validation metric used to
trigger early stopping.
log_wandb (bool): if True, tells W&B to summarize the best value of the checkpointing metric.
Returns:
List[callbacks.Callback]: callbacks.
Expand Down Expand Up @@ -85,6 +85,9 @@ def _get_callbacks(
save_top_k=num_checkpoints,
)
)
# Logs the best value for the checkpointing metric.
if log_wandb:
wandb.define_metric(metric.monitor, summary=metric.mode)
return trainer_callbacks


Expand All @@ -106,10 +109,11 @@ def get_trainer_from_argparse_args(
args.checkpoint_metric,
args.patience,
args.patience_metric,
args.log_wandb,
),
default_root_dir=args.model_dir,
enable_checkpointing=True,
logger=_get_logger(args.experiment, args.model_dir, args.log_wandb),
logger=_get_loggers(args.experiment, args.model_dir, args.log_wandb),
)


Expand Down Expand Up @@ -167,6 +171,7 @@ def get_model_from_argparse_args(
source_encoder_cls = models.modules.get_encoder_cls(
encoder_arch=args.source_encoder_arch, model_arch=args.arch
)
# Loads expert if needed.
expert = (
models.expert.get_expert(
datamodule.train_dataloader().dataset,
Expand Down Expand Up @@ -198,6 +203,12 @@ def get_model_from_argparse_args(
if not separate_features
else datamodule.index.source_vocab_size
)
# This makes sure we compute all metrics that'll be needed.
eval_metrics = args.eval_metric.copy()
if args.checkpoint_metric != "loss":
eval_metrics.add(args.checkpoint_metric)
if args.patience_metric != "loss":
eval_metrics.add(args.patience_metric)
# Please pass all arguments by keyword and keep in lexicographic order.
return model_cls(
arch=args.arch,
Expand All @@ -211,7 +222,7 @@ def get_model_from_argparse_args(
embedding_size=args.embedding_size,
encoder_layers=args.encoder_layers,
end_idx=datamodule.index.end_idx,
eval_metrics=args.eval_metric,
eval_metrics=eval_metrics,
expert=expert,
features_encoder_cls=features_encoder_cls,
features_vocab_size=features_vocab_size,
Expand Down Expand Up @@ -298,10 +309,10 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
)
parser.add_argument(
"--checkpoint_metric",
choices=["accuracy", "loss"],
choices=["accuracy", "loss", "ser"],
default=defaults.CHECKPOINT_METRIC,
help="Selects checkpoints to maximize validation `accuracy` "
"or minimize validation `loss`. "
help="Selects checkpoints to maximize validation `accuracy`, "
"or to minimize validation `loss` or `ser`. "
"Default: %(default)s.",
)
parser.add_argument(
Expand All @@ -313,10 +324,11 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
)
parser.add_argument(
"--patience_metric",
choices=["accuracy", "loss"],
choices=["accuracy", "loss", "ser"],
default=defaults.PATIENCE_METRIC,
help="Stops early when validation `accuracy` stops increasing or "
"when validation `loss` stops decreasing. Default: %(default)s.",
"when validation `loss` or `ser` stops decreasing. "
"Default: %(default)s.",
)
parser.add_argument("--seed", type=int, help="Random seed.")
parser.add_argument(
Expand Down Expand Up @@ -366,6 +378,8 @@ def main() -> None:
add_argparse_args(parser)
args = parser.parse_args()
util.log_arguments(args)
if args.log_wandb:
wandb.init()
pl.seed_everything(args.seed)
trainer = get_trainer_from_argparse_args(args)
datamodule = get_datamodule_from_argparse_args(args)
Expand Down
21 changes: 21 additions & 0 deletions yoyodyne/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,27 @@
import argparse
import sys

from typing import Any, Optional


# Argument parsing.


class UniqueAddAction(argparse.Action):
    """Custom action that enforces uniqueness using a set.

    Each occurrence of the option adds its value to a set stored on the
    namespace under ``self.dest``, so repeated values collapse into one.

    The set stored on the namespace is always a fresh copy; the object
    passed as ``default`` to ``add_argument`` is never mutated. This mirrors
    how argparse's built-in ``append`` action copies its default before
    appending, and prevents values from leaking into a shared module-level
    default or between successive ``parse_args`` calls.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Any,
        option_string: Optional[str] = None,
    ) -> None:
        """Adds the parsed value to the destination set.

        Args:
            parser (argparse.ArgumentParser): the parser invoking the action.
            namespace (argparse.Namespace): namespace being populated.
            values (Any): the parsed value for this occurrence of the option.
            option_string (Optional[str]): the option string used, if any.
        """
        current = getattr(namespace, self.dest, None)
        # argparse places the default object itself on the namespace, so
        # adding in place would mutate the caller's default; copy instead.
        # A missing default (None) is treated as the empty set.
        updated = set() if current is None else set(current)
        updated.add(values)
        setattr(namespace, self.dest, updated)


# Logging.


def log_info(msg: str) -> None:
"""Logs msg to sys.stderr.
Expand Down

0 comments on commit c064067

Please sign in to comment.