update: pytorch_lightning and gpu training (#149)
* fix: deactivate default cudnn, update lightning

* fix: add variable for cudnn

* fix mypy, flag cudnn in trainer, fix argument granular
georgosgeorgos authored Oct 8, 2022
1 parent 8b6ad2d commit 6f3fd0a
Showing 12 changed files with 45 additions and 32 deletions. Taken together, the changes migrate the training pipelines from pytorch_lightning <=1.5.0 to >=1.7.0 (the Trainer argument accelerator becomes strategy for distributed backends, and ModelCheckpoint's every_n_val_epochs becomes every_n_epochs) and add a gt4sd_disable_cudnn configuration flag to work around cuDNN issues in GPU training.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,7 +10,8 @@ minio==7.0.1
 modlamp>=4.0.0
 numpy>=1.16.5
 protobuf<3.20
-pytorch_lightning<=1.5.0
+pyarrow<=6.0.1
+pytorch_lightning>=1.7.0
 pydantic>=1.7.3,<=1.9.2
 PyTDC>=0.3.7
 pyyaml>=5.4.1
3 changes: 3 additions & 0 deletions setup.cfg
@@ -238,4 +238,7 @@ ignore_missing_imports = True
 ignore_missing_imports = True

 [mypy-pdbfixer.*]
 ignore_missing_imports = True
+
+[mypy-packaging.*]
+ignore_missing_imports = True
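
The new [mypy-packaging.*] override implies the codebase now imports packaging, most plausibly to branch on the installed pytorch_lightning version during the 1.5 to 1.7 migration. A minimal sketch of that pattern, not code from this commit; the helper name is hypothetical:

import pytorch_lightning as pl
from packaging import version


# Hypothetical helper: gate newer Trainer keywords on the installed version.
def lightning_supports_strategy() -> bool:
    # Trainer(strategy=...) is available from pytorch_lightning 1.5 onwards.
    return version.parse(pl.__version__) >= version.parse("1.5.0")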
9 changes: 8 additions & 1 deletion src/gt4sd/cli/trainer.py
@@ -31,6 +31,7 @@
 from dataclasses import dataclass, field
 from typing import IO, Iterable, Optional, Tuple, cast

+from ..configuration import GT4SDConfiguration
 from ..training_pipelines import (
     TRAINING_PIPELINE_ARGUMENTS_MAPPING,
     TRAINING_PIPELINE_MAPPING,
@@ -44,6 +45,12 @@
     list(set(TRAINING_PIPELINE_ARGUMENTS_MAPPING) & set(TRAINING_PIPELINE_MAPPING))
 )

+# disable cudnn if issues with gpu training
+if GT4SDConfiguration.get_instance().gt4sd_disable_cudnn:
+    import torch
+
+    torch.backends.cudnn.enabled = False
+

 @dataclass
 class TrainerArguments:
@@ -121,7 +128,7 @@ def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:  # type: ignore
             if "gt4sd.cli.trainer.TrainerArguments" not in str(dataclass_type)
         ]
         try:
-            parsed_arguments = super().parse_json_file(
+            parsed_arguments = super().parse_json_file(  # type:ignore
                 json_file=json_file, allow_extra_keys=True
             )
         except Exception:
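
The new module-level guard gives users a kill switch for cuDNN, which can help when GPU training fails inside cuDNN kernels. A sketch of the intended use, assuming the standard pydantic BaseSettings mapping of field names to environment variables; the variable must be set before the configuration singleton is created:

import os

# Assumption: GT4SDConfiguration is a pydantic BaseSettings, so the
# gt4sd_disable_cudnn field can be toggled through the environment.
# Set it before any gt4sd import instantiates the singleton.
os.environ["GT4SD_DISABLE_CUDNN"] = "True"

import torch
from gt4sd.configuration import GT4SDConfiguration

if GT4SDConfiguration.get_instance().gt4sd_disable_cudnn:
    torch.backends.cudnn.enabled = False  # fall back to non-cuDNN kernels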
1 change: 1 addition & 0 deletions src/gt4sd/configuration.py
@@ -50,6 +50,7 @@ class GT4SDConfiguration(BaseSettings):
     gt4sd_max_number_of_samples: int = 1000000
     gt4sd_max_runtime: int = 86400
     gt4sd_create_unverified_ssl_context: bool = False
+    gt4sd_disable_cudnn: bool = False

     gt4sd_s3_host: str = "s3.par01.cloud-object-storage.appdomain.cloud"
     gt4sd_s3_access_key: str = "6e9891531d724da89997575a65f4592e"
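
Since GT4SDConfiguration extends pydantic's BaseSettings, the new flag needs no extra plumbing: it can be set programmatically or via the environment. A small sketch of both routes, illustrative rather than part of the commit:

from gt4sd.configuration import GT4SDConfiguration

# Programmatic override via the BaseSettings constructor.
configuration = GT4SDConfiguration(gt4sd_disable_cudnn=True)
assert configuration.gt4sd_disable_cudnn

# Equivalent environment-based override (read when the settings object
# is instantiated):
#   export GT4SD_DISABLE_CUDNN=true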
2 changes: 1 addition & 1 deletion src/gt4sd/frameworks/gflownet/arg_parser/parser.py
@@ -167,7 +167,7 @@ def parse_arguments_from_config(conf_file: Optional[str] = None) -> argparse.Namespace:
     parser.add_argument("--num_offline", type=int, default=10)
     parser.add_argument("--sampling_iterator", type=bool, default=True)
     parser.add_argument("--ratio", type=float, default=0.9)
-    parser.add_argument("--distributed_training_strategy", type=str, default="ddp")
+    parser.add_argument("--strategy", type=str, default="ddp")
     parser.add_argument("--development_mode", type=bool, default=False)

     args_dictionary = vars(parser.parse_args(remaining_argv))
2 changes: 1 addition & 1 deletion src/gt4sd/frameworks/gflownet/tests/test_gfn.py
@@ -148,6 +148,6 @@ def test_gfn():
         max_epochs=1,
         flush_logs_every_n_steps=100,
         fast_dev_run=True,
-        accelerator="ddp",
+        strategy="ddp",
     )
     trainer.fit(module, dm)
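
The rename tracks the Lightning API split: from 1.5 on (and enforced by 1.7), distributed backends such as "ddp" go to strategy, while accelerator selects the hardware. A sketch of the post-upgrade Trainer call; the values are illustrative:

from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=1,
    accelerator="gpu",  # hardware: "cpu", "gpu", ...
    devices=2,          # number of devices per node
    strategy="ddp",     # distributed backend, formerly accelerator="ddp"
)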
2 changes: 1 addition & 1 deletion src/gt4sd/frameworks/gflownet/train/core.py
@@ -127,7 +127,7 @@ def train_gflownet(
         max_epochs=getattr(arguments, "epoch", 10),
         check_val_every_n_epoch=getattr(arguments, "checkpoint_every_n_val_epochs", 5),
         fast_dev_run=getattr(arguments, "development_mode", False),
-        accelerator=getattr(arguments, "distributed_training_strategy", "ddp"),
+        strategy=getattr(arguments, "strategy", "ddp"),
     )
     trainer.fit(module, dm)
2 changes: 1 addition & 1 deletion src/gt4sd/frameworks/granular/train/core.py
@@ -91,7 +91,7 @@ def train_granular(configuration: Dict[str, Any]) -> None:
         "logs", name=getattr(arguments, "basename", "default")
     )
     checkpoint_callback = ModelCheckpoint(
-        every_n_val_epochs=getattr(arguments, "checkpoint_every_n_val_epochs", 5),
+        every_n_epochs=getattr(arguments, "checkpoint_every_n_val_epochs", 5),
         save_top_k=-1,
     )
     trainer = pl.Trainer.from_argparse_args(
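
ModelCheckpoint's every_n_val_epochs argument was deprecated during the 1.x line in favor of every_n_epochs and is gone by 1.7, hence this one-keyword fix. A sketch of the updated callback with illustrative values:

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    every_n_epochs=5,  # checkpoint cadence, formerly every_n_val_epochs
    save_top_k=-1,     # keep all checkpoints instead of only the best k
)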
25 changes: 16 additions & 9 deletions src/gt4sd/training_pipelines/pytorch_lightning/core.py
@@ -29,7 +29,6 @@

 import sentencepiece as _sentencepiece
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -75,7 +74,7 @@ def train(  # type: ignore
                 "save_top_k": pl_trainer_args["save_top_k"],
                 "mode": pl_trainer_args["mode"],
                 "every_n_train_steps": pl_trainer_args["every_n_train_steps"],
-                "every_n_val_epochs": pl_trainer_args["every_n_val_epochs"],
+                "every_n_epochs": pl_trainer_args["every_n_epochs"],
                 "save_last": pl_trainer_args["save_last"],
             }
         }
@@ -86,7 +85,7 @@
             pl_trainer_args["mode"],
             pl_trainer_args["every_n_train_steps"],
             pl_trainer_args["save_last"],
-            pl_trainer_args["every_n_val_epochs"],
+            pl_trainer_args["every_n_epochs"],
         )

         pl_trainer_args["callbacks"] = self.add_callbacks(pl_trainer_args["callbacks"])
@@ -120,7 +119,7 @@ def get_data_and_model_modules(
             "Can't get data and model modules for an abstract training pipeline."
         )

-    def add_callbacks(self, callback_args: Dict[str, Any]) -> List[Callback]:
+    def add_callbacks(self, callback_args: Dict[str, Any]) -> List[Any]:
         """Create the requested callbacks for training.

         Args:
@@ -130,7 +129,7 @@ def add_callbacks(self, callback_args: Dict[str, Any]) -> List[Callback]:
             list of pytorch lightning callbacks.
         """

-        callbacks: List[Callback] = []
+        callbacks: List[Any] = []
         if "early_stopping_callback" in callback_args:
             callbacks.append(EarlyStopping(**callback_args["early_stopping_callback"]))

@@ -150,8 +149,8 @@ class PytorchLightningTrainingArguments(TrainingPipelineArguments):

     __name__ = "pl_trainer_args"

-    accelerator: Optional[str] = field(
-        default="ddp", metadata={"help": "Accelerator type."}
+    strategy: Optional[str] = field(
+        default="ddp", metadata={"help": "Training strategy."}
     )
     accumulate_grad_batches: int = field(
         default=1,
@@ -199,8 +198,8 @@ class PytorchLightningTrainingArguments(TrainingPipelineArguments):
             "help": "When True, always saves the model at the end of the epoch to a file last.ckpt"
         },
     )
-    save_top_k: Optional[int] = field(
-        default=None,
+    save_top_k: int = field(
+        default=1,
         metadata={
             "help": "The best k models according to the quantity monitored will be saved."
         },
@@ -213,3 +212,11 @@ class PytorchLightningTrainingArguments(TrainingPipelineArguments):
         default=None,
         metadata={"help": "Number of training steps between checkpoints."},
     )
+    every_n_epochs: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of epochs between checkpoints."},
+    )
+    check_val_every_n_epoch: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of validation epochs between evaluations."},
+    )
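
Dropping pytorch_lightning.callbacks.base.Callback in favor of List[Any] sidesteps the removed import path. If precise typing is preferred, newer releases still export the class from the stable package namespace; a sketch, not what this commit does:

from typing import Any, Dict, List

from pytorch_lightning.callbacks import Callback, EarlyStopping


def add_callbacks(callback_args: Dict[str, Any]) -> List[Callback]:
    # Same logic as the pipeline method, but keeping the precise type by
    # importing Callback from pytorch_lightning.callbacks (PL >= 1.7).
    callbacks: List[Callback] = []
    if "early_stopping_callback" in callback_args:
        callbacks.append(EarlyStopping(**callback_args["early_stopping_callback"]))
    return callbacks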
19 changes: 7 additions & 12 deletions src/gt4sd/training_pipelines/pytorch_lightning/gflownet/core.py
@@ -103,7 +103,7 @@ def train(  # type: ignore
         log_every_n_steps=pl_trainer_args["trainer_log_every_n_steps"],
         callbacks=pl_trainer_args["callbacks"],
         max_epochs=pl_trainer_args["epochs"],
-        accelerator=pl_trainer_args["accelerator"],
+        strategy=pl_trainer_args["strategy"],
         fast_dev_run=pl_trainer_args["development_mode"],
     )

@@ -193,8 +193,8 @@ class GFlowNetPytorchLightningTrainingArguments(PytorchLightningTrainingArguments):

     __name__ = "pl_trainer_args"

-    accelerator: Optional[str] = field(
-        default="ddp", metadata={"help": "Accelerator type."}
+    strategy: Optional[str] = field(
+        default="ddp", metadata={"help": "Training strategy."}
     )
     accumulate_grad_batches: int = field(
         default=1,
@@ -253,8 +253,8 @@ class GFlowNetPytorchLightningTrainingArguments(PytorchLightningTrainingArguments):
             "help": "When True, always saves the model at the end of the epoch to a file last.ckpt"
         },
     )
-    save_top_k: Optional[int] = field(
-        default=-1,
+    save_top_k: int = field(
+        default=1,
         metadata={
             "help": "The best k models according to the quantity monitored will be saved."
         },
@@ -267,9 +267,9 @@ class GFlowNetPytorchLightningTrainingArguments(PytorchLightningTrainingArguments):
         default=None,
         metadata={"help": "Number of training steps between checkpoints."},
     )
-    every_n_val_epochs: Optional[int] = field(
+    check_val_every_n_epoch: Optional[int] = field(
         default=5,
-        metadata={"help": "Number of training epochs between checkpoints."},
+        metadata={"help": "Number of validation epochs between checkpoints."},
     )
     auto_lr_find: bool = field(
         default=True,
@@ -323,11 +323,6 @@ class GFlowNetPytorchLightningTrainingArguments(PytorchLightningTrainingArguments):
         default="cpu",
         metadata={"help": "The device to use."},
    )
-    distributed_training_strategy: str = field(
-        default="ddp",
-        metadata={"help": "The distributed training strategy. "},
-    )
-
     development_mode: bool = field(
         default=False,
         metadata={"help": "Whether to run in development mode. "},
@@ -127,7 +127,7 @@ class GranularPytorchLightningTrainingArguments(PytorchLightningTrainingArguments):

     __name__ = "pl_trainer_args"

-    every_n_val_epochs: Optional[int] = field(
+    check_val_every_n_epoch: Optional[int] = field(
         default=5,
         metadata={"help": "Number of training epochs between checkpoints."},
     )
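
Note the semantic shift behind this rename: check_val_every_n_epoch is a Trainer argument that controls how often validation runs, and epoch-based checkpointing rides on that cadence. A sketch of how the renamed knobs fit together in Lightning >= 1.7, with illustrative values:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

trainer = Trainer(
    max_epochs=50,
    check_val_every_n_epoch=5,  # run validation every 5 epochs
    callbacks=[ModelCheckpoint(every_n_epochs=5, save_top_k=1)],
)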
7 changes: 3 additions & 4 deletions src/gt4sd/training_pipelines/tests/test_training_gflownet.py
@@ -88,9 +88,9 @@ def _create_training_output_filepaths(directory: str) -> Dict[str, str]:
         "num_offline": 10,
     },
     "pl_trainer_args": {
-        "accelerator": "ddp",
+        "strategy": "ddp",
         "basename": "gflownet",
-        "every_n_val_epochs": 5,
+        "check_val_every_n_epoch": 5,
         "trainer_log_every_n_steps": 50,
         "auto_lr_find": True,
         "profiler": "simple",
@@ -103,10 +103,9 @@
         "validate_every": 1000,
         "seed": 142857,
         "device": "cpu",
-        "distributed_training_strategy": "ddp",
         "development_mode": True,
         "resume_from_checkpoint": None,
-        "save_top_k": -1,
+        "save_top_k": 1,
         "epochs": 3,
     },
     "dataset_args": {
