From 1e1c04d6b8f543801ae2a6ff2c422370d79d941c Mon Sep 17 00:00:00 2001
From: Antoine Broyelle
Date: Mon, 3 Jun 2024 19:48:31 +0200
Subject: [PATCH] Optional CheckpointSaver instantiation inside the Trainer (#3334)

---
 composer/trainer/trainer.py      | 23 ++++++++-
 tests/trainer/test_checkpoint.py | 84 +++++++++++++++++++++++++++++---
 2 files changed, 100 insertions(+), 7 deletions(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 8f6c3974de..353cd97258 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1397,7 +1397,28 @@ def __init__(
         # Checkpoint Saving
         self._checkpoint_saver = None
         latest_remote_file_name = None
-        if save_folder is not None:
+
+        _checkpoint_savers = [cb for cb in self.state.callbacks if isinstance(cb, CheckpointSaver)]
+        if len(_checkpoint_savers) >= 1:
+            if len(_checkpoint_savers) > 1:
+                log.info('Multiple CheckpointSaver provided as callbacks. Using the first one as reference.')
+            self._checkpoint_saver = _checkpoint_savers[0]
+
+            if self._checkpoint_saver.folder != save_folder:
+                log.info(f'Using {self._checkpoint_saver.folder} as save_folder.')
+                save_folder = self._checkpoint_saver.folder
+
+            if self._checkpoint_saver.latest_filename is None:
+                save_latest_filename = None
+                log.info(f'Using {save_latest_filename} as latest_filename.')
+            elif self._checkpoint_saver.latest_filename.filename != save_latest_filename:
+                save_latest_filename = str(self._checkpoint_saver.latest_filename.filename)
+                log.info(f'Using {save_latest_filename} as latest_filename.')
+
+            if self._checkpoint_saver.latest_remote_file_name is not None:
+                latest_remote_file_name = str(self._checkpoint_saver.latest_remote_file_name.filename)
+
+        if self._checkpoint_saver is None and save_folder is not None:
             if save_weights_only:
                 log.info(
                     'save_weights_only=True now also saves metadata and integrations! Please adjust your workflow accordingly.',
diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index 41411e9697..d23b55875f 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -33,6 +33,7 @@
 from composer.utils import dist, is_tar, reproducibility
 from composer.utils.checkpoint import (
     _COMPOSER_STATES_FILENAME,
+    PartialFilePath,
     _ensure_valid_checkpoint,
     _write_checkpoint_file,
     glob_filter,
@@ -394,9 +395,9 @@ def test_checkpoint_saver_properly_constructed(
     ):
         mock_validate_credentials = MagicMock()
         monkeypatch.setattr(remote_uploader_downloader, '_validate_credentials', mock_validate_credentials)
-        mock_checkpoint_saver = MagicMock()
-        monkeypatch.setattr(trainer, 'CheckpointSaver', mock_checkpoint_saver)
-        self.get_trainer(save_folder=save_folder)
+
+        trainer = self.get_trainer(save_folder=save_folder)
+
         expected_prefix = expected_path + '/' if expected_path != '' else expected_path
         rest_of_checkpoint_saver_kwargs = {
             'filename': 'ep{epoch}-ba{batch}-rank{rank}.pt',
@@ -409,8 +410,14 @@ def test_checkpoint_saver_properly_constructed(
             'num_checkpoints_to_keep': -1,
             'ignore_keys': None,
         }
-        expected_folder = expected_path.rstrip('/') if expected_path != '' else '.'
-        mock_checkpoint_saver.assert_called_once_with(folder=expected_folder, **rest_of_checkpoint_saver_kwargs)
+        for attr_name, value in rest_of_checkpoint_saver_kwargs.items():
+            attr = getattr(trainer._checkpoint_saver, attr_name)
+            if attr_name == 'save_interval':
+                assert attr.__closure__[-1].cell_contents == Time.from_timestring(value)
+            elif isinstance(attr, PartialFilePath):
+                assert attr.filename == value
+            else:
+                assert attr == value
 
     @pytest.mark.parametrize('save_interval', ['1tok', '64tok', '65tok'])
     @pytest.mark.parametrize('batch_size', [1, 4])
@@ -616,6 +623,29 @@ def test_checkpoint_intervals(
         # we should have one extra call from the fit end checkpoint
         assert trainer._checkpoint_saver._save_checkpoint.call_count == expected_save_calls
 
+    @pytest.mark.parametrize(('save_folder'), [None, 'local_checkpoints'])
+    @pytest.mark.parametrize(('save_latest_filename'), [None, 'latest.pt'])
+    def test_checkpoint_multiple_callbacks(
+        self,
+        save_folder: Optional[str],
+        save_latest_filename: Optional[str],
+        tmp_path: pathlib.Path,
+    ):
+        checkpoint_savers = [
+            CheckpointSaver(str(tmp_path / 'checkpoints1')),
+            CheckpointSaver(str(tmp_path / 'checkpoints2')),
+        ]
+
+        trainer = self.get_trainer(
+            max_duration='1ep',
+            callbacks=checkpoint_savers,
+            save_folder=save_folder,
+            save_latest_filename=save_latest_filename,
+        )
+
+        assert id(trainer._checkpoint_saver) == id(checkpoint_savers[0])
+        assert len([cb for cb in trainer.state.callbacks if isinstance(cb, CheckpointSaver)]) == len(checkpoint_savers)
+
 
 class TestCheckpointLoading:
 
@@ -647,6 +677,11 @@ def get_trainer(
         eval_dataset = RandomImageDataset()
         train_batch_size = 2
 
+        callbacks = [DummyStatefulCallback()]
+        if 'callbacks' in kwargs:
+            callbacks += kwargs['callbacks']
+            del kwargs['callbacks']
+
         return Trainer(
             model=model,
             train_dataloader=DataLoader(
@@ -670,7 +705,7 @@
             max_duration=max_duration,
             optimizers=optimizer,
             schedulers=ExponentialScheduler(gamma=0.9),
-            callbacks=[DummyStatefulCallback()],
+            callbacks=callbacks,
             **kwargs,
         )
 
@@ -769,6 +804,43 @@ def test_autoresume(
 
         assert trainer_1.state.run_name == trainer_2.state.run_name
 
+    @pytest.mark.parametrize(('save_folder'), [None, 'first'])
+    def test_autoresume_from_callback(
+        self,
+        save_folder: Optional[str],
+        tmp_path: pathlib.Path,
+    ):
+        checkpoint_saver = CheckpointSaver(str(tmp_path / 'checkpoints'), latest_filename='latest-rank{rank}.pt')
+
+        trainer_1 = self.get_trainer(
+            file_extension='.pt',
+            save_folder=save_folder,
+            device='cpu',
+            run_name='big-chungus',
+            autoresume=True,
+            callbacks=[checkpoint_saver],
+        )
+
+        # trains the model, saving the checkpoint files
+        trainer_1.fit()
+        trainer_1.close()
+
+        trainer_2 = self.get_trainer(
+            file_extension='.pt',
+            save_folder=save_folder,
+            device='cpu',
+            run_name='big-chungus',
+            autoresume=True,
+            callbacks=[checkpoint_saver],
+        )
+
+        self._assert_weights_equivalent(
+            trainer_1.state.model,
+            trainer_2.state.model,
+        )
+
+        assert trainer_1.state.run_name == trainer_2.state.run_name
+
     @pytest.mark.parametrize(
         'load_path,load_object_store',
         [
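
Usage sketch (not part of the patch): with this change, a CheckpointSaver constructed by the user and passed through `callbacks` is picked up by the Trainer, which then derives `save_folder` and `latest_filename` from it instead of building its own saver. A minimal example, assuming `model` and `train_dataloader` are a ComposerModel and DataLoader defined elsewhere:

from composer.callbacks import CheckpointSaver
from composer.trainer import Trainer

# Build the saver explicitly rather than passing save_folder to the Trainer.
checkpoint_saver = CheckpointSaver(
    folder='checkpoints',
    filename='ep{epoch}-ba{batch}-rank{rank}.pt',
    latest_filename='latest-rank{rank}.pt',
    save_interval='1ep',
    num_checkpoints_to_keep=-1,
)

# The Trainer reuses this callback as its checkpoint saver; no save_folder is needed.
trainer = Trainer(
    model=model,                        # assumed to be defined elsewhere
    train_dataloader=train_dataloader,  # assumed to be defined elsewhere
    max_duration='1ep',
    callbacks=[checkpoint_saver],
)
trainer.fit()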