Skip to content

Commit

Permalink
saving time by not saving models during grid search
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Jun 17, 2024
1 parent 8017338 commit 10383b7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
16 changes: 15 additions & 1 deletion libmultilabel/nn/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def init_trainer(
limit_val_batches=1.0,
limit_test_batches=1.0,
save_checkpoints=True,
is_tune_mode=False,
):
"""Initialize a torch lightning trainer.
Expand All @@ -146,6 +147,7 @@ def init_trainer(
limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0.
limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0.
save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True.
is_tune_mode (bool): Whether a parameter search is running. If True, only the weights of the best model are checkpointed. Defaults to False.
Returns:
lightning.trainer: A torch lightning trainer.
Expand All @@ -163,7 +165,19 @@ def init_trainer(
strict=False,
)
callbacks = [early_stopping_callback]
if save_checkpoints:

if is_tune_mode:
callbacks += [
ModelCheckpoint(
dirpath=checkpoint_dir,
filename="best_model",
save_top_k=1,
save_weights_only=True,
monitor=val_metric,
mode="min" if val_metric == "Loss" else "max",
)
]
elif save_checkpoints:
callbacks += [
ModelCheckpoint(
dirpath=checkpoint_dir,
Expand Down
3 changes: 2 additions & 1 deletion search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
datasets=datasets,
classes=classes,
word_dict=word_dict,
save_checkpoints=True,
save_checkpoints=False,
is_tune_mode=True,
)
val_score = trainer.train()
return {f"val_{config.val_metric}": val_score}
Expand Down
2 changes: 2 additions & 0 deletions torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
word_dict: dict = None,
embed_vecs=None,
save_checkpoints: bool = True,
is_tune_mode: bool = False,
):
self.run_name = config.run_name
self.checkpoint_dir = config.checkpoint_dir
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(
limit_val_batches=config.limit_val_batches,
limit_test_batches=config.limit_test_batches,
save_checkpoints=save_checkpoints,
is_tune_mode=is_tune_mode,
)
callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
self.checkpoint_callback = callbacks[0] if callbacks else None
Expand Down

0 comments on commit 10383b7

Please sign in to comment.