Last bits of the equation
kylebgorman committed May 8, 2024
1 parent 3054142 commit 1beeca7
Showing 3 changed files with 17 additions and 14 deletions.
README.md (13 changes: 7 additions & 6 deletions)

@@ -97,15 +97,16 @@ information.
 
 ### Validation
 
-Validation is run at intervals requested by the user using the lightning interface.
-See `--val_check_interval` and `--check_val_every_n_epoch`
+Validation is run at intervals requested by the user. See
+`--val_check_interval` and `--check_val_every_n_epoch`
 [here](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api).
-Particular evaluation metrics can also be requested with `--eval_metric`. For example
+Additional evaluation metrics can also be requested with `--eval_metric`. For
+example
 
-    yoyodyne-train --eval_metric accuracy --eval_metric ser ...
+    yoyodyne-train --eval_metric ser ...
 
-will compute both accuracy and symbol error rate (SER) each time validation is
-requested. Additional metrics can be added in
+will additionally compute symbol error rate (SER) each time validation is
+performed. Additional metrics can be added to
 [`evaluators.py`](yoyodyne/evaluators.py).
 
 ### Prediction
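Note: the interval flags cited in this hunk compose with `--eval_metric` in a single command. A hypothetical invocation (the flag names come from the README; the value and trailing arguments are illustrative):

    yoyodyne-train --eval_metric ser --val_check_interval 500 ...

Under Lightning's semantics, an integer `--val_check_interval` runs validation, and hence the SER computation, every 500 training batches.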
yoyodyne/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -5,5 +5,6 @@
     "ignore", ".*does not have many workers which may be a bottleneck.*"
 )
 warnings.filterwarnings(
-    "ignore", ".*option adds dropout after all but last recurrent layer*."
+    "ignore", ".*option adds dropout after all but last recurrent layer.*"
 )
+warnings.filterwarnings("ignore", ".*is a wandb run already in progress.*")
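Note: the fix above is a one-character regex correction: `layer*.` matches "laye", zero or more "r"s, and exactly one further character, whereas the intended `layer.*` matches "layer" followed by any trailing text. A minimal sketch of the corrected filter, with the PyTorch warning text paraphrased (an assumption, not taken from this commit):

    import re
    import warnings

    # Corrected pattern: the trailing ".*" absorbs the rest of the message.
    pattern = ".*option adds dropout after all but last recurrent layer.*"
    # Assumed warning text, paraphrased from torch.nn.RNN's dropout warning.
    message = (
        "dropout option adds dropout after all but last recurrent layer, "
        "so non-zero dropout expects num_layers greater than 1"
    )
    assert re.match(pattern, message) is not None
    # warnings.filterwarnings() applies the same kind of anchored match.
    warnings.filterwarnings("ignore", message=pattern)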
yoyodyne/train.py (15 changes: 8 additions & 7 deletions)

@@ -14,7 +14,7 @@ class Error(Exception):
     pass
 
 
-def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List:
+def _get_loggers(experiment: str, model_dir: str, log_wandb: bool) -> List:
     """Creates the logger(s).
 
     Args:
@@ -25,12 +25,12 @@ def _get_logger(experiment: str, model_dir: str, log_wandb: bool) -> List:
     Returns:
         List: logger.
     """
-    trainer_logger = [loggers.CSVLogger(model_dir, name=experiment)]
+    trainer_loggers = [loggers.CSVLogger(model_dir, name=experiment)]
     if log_wandb:
-        trainer_logger.append(loggers.WandbLogger(project=experiment))
+        trainer_loggers.append(loggers.WandbLogger(project=experiment))
         # Logs the path to local artifacts made by PTL.
-        wandb.config["local_run_dir"] = trainer_logger[0].log_dir
-    return trainer_logger
+        wandb.config["local_run_dir"] = trainer_loggers[0].log_dir
+    return trainer_loggers
 
 
 def _get_callbacks(
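Note: the renamed helper returns a list that Lightning's `Trainer` accepts directly via its `logger` argument, as the `get_trainer_from_argparse_args` hunk below shows. A hypothetical standalone call (the values are illustrative, and the helper is private):

    from yoyodyne import train

    # One CSVLogger; a WandbLogger would be appended if log_wandb=True.
    trainer_loggers = train._get_loggers(
        experiment="demo", model_dir="/tmp/models", log_wandb=False
    )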
@@ -72,7 +72,6 @@ def _get_callbacks(
                 mode=metric.mode,
                 monitor=metric.monitor,
                 patience=patience,
-                verbose=True,
             )
         )
     # Checkpointing callback. Ensure that this is the last checkpoint,
@@ -114,7 +113,7 @@ def get_trainer_from_argparse_args(
         ),
         default_root_dir=args.model_dir,
         enable_checkpointing=True,
-        logger=_get_logger(args.experiment, args.model_dir, args.log_wandb),
+        logger=_get_loggers(args.experiment, args.model_dir, args.log_wandb),
     )
 
 
@@ -379,6 +378,8 @@ def main() -> None:
     add_argparse_args(parser)
     args = parser.parse_args()
     util.log_arguments(args)
+    if args.log_wandb:
+        wandb.init()
     pl.seed_everything(args.seed)
     trainer = get_trainer_from_argparse_args(args)
     datamodule = get_datamodule_from_argparse_args(args)
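Note: the two lines added to `main` start a Weights & Biases run up front when `--log_wandb` is passed. Plausibly this is what lets `_get_loggers` write to `wandb.config`, which requires an active run, before `WandbLogger` would otherwise initialize one; the new filter in `yoyodyne/__init__.py` then suppresses the resulting "wandb run already in progress" warning. The commit itself does not state this rationale. The ordering it enforces, sketched:

    import wandb

    wandb.init()  # must come before any write to wandb.config
    wandb.config["local_run_dir"] = "/tmp/models/demo"  # illustrative value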
