From 600ed3ea35070144043dd2bf9d492cf16e02d9e1 Mon Sep 17 00:00:00 2001
From: Mike Walmsley
Date: Mon, 15 Jan 2024 14:47:20 -0500
Subject: [PATCH] seems to block on multi-g

---
 .../training/train_with_pytorch_lightning.py | 39 ++++++++++---------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index 99c96f1e..c90a386d 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -341,24 +341,27 @@ def train_default_zoobot_from_scratch(
 
     best_model_path = trainer.checkpoint_callback.best_model_path
 
-    # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
-    # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
-    if datamodule.test_dataloader is not None:
-        logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-        # test_trainer.validate(
-        #     model=lightning_model,
-        #     datamodule=datamodule,
-        #     ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        # )
-        datamodule.setup(stage='test')
-        # temp
-        print(datamodule.test_urls)
-        test_trainer.test(
-            model=lightning_model,
-            datamodule=datamodule,
-            ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        )
-        # TODO may need to remake on 1 gpu only
+    if test_trainer.is_global_zero:
+        # can test as per the below, but note that datamodule must have a test dataset attribute as per pytorch lightning docs.
+        # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
+        if datamodule.test_dataloader is not None:
+            logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
+            # test_trainer.validate(
+            #     model=lightning_model,
+            #     datamodule=datamodule,
+            #     ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+            # )
+            datamodule.setup(stage='test')
+            test_trainer.test(
+                model=lightning_model,
+                datamodule=datamodule,
+                ckpt_path=checkpoint_callback.best_model_path # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
+            )
+            # TODO may need to remake on 1 gpu only
+        else:
+            logging.info('No test dataloader found, skipping test metrics')
+    else:
+        logging.info('Not global zero, skipping test metrics')
 
     # explicitly update the model weights to the best checkpoint before returning
     # (assumes only one checkpoint callback, very likely in practice)
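
For context, the pattern the hunk applies is roughly the following: gate the post-training test loop behind trainer.is_global_zero so that only one DDP rank runs it, since trainer.test() seems to block when the run used multiple GPUs. This is a minimal standalone sketch, not the exact Zoobot code: the single-device Trainer construction is an assumption, and lightning_model, datamodule and checkpoint_callback stand in for objects already defined earlier in train_default_zoobot_from_scratch.

import logging

import pytorch_lightning as pl


def run_test_on_rank_zero(lightning_model, datamodule, checkpoint_callback):
    # assumed construction: a fresh single-device Trainer for testing,
    # separate from the multi-GPU Trainer used for training
    test_trainer = pl.Trainer(accelerator='auto', devices=1, num_nodes=1, logger=False)

    if test_trainer.is_global_zero:
        # only the global-zero rank runs the test loop; other DDP processes skip it rather than block
        if datamodule.test_dataloader is not None:
            datamodule.setup(stage='test')
            test_trainer.test(
                model=lightning_model,
                datamodule=datamodule,
                ckpt_path=checkpoint_callback.best_model_path  # evaluate the best checkpoint, not the final weights
            )
        else:
            logging.info('No test dataloader found, skipping test metrics')
    else:
        logging.info('Not global zero, skipping test metrics')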