Create callback to load checkpoint (#3641)
Showing 11 changed files with 253 additions and 54 deletions.
@@ -0,0 +1,76 @@
```python
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Load a checkpoint."""
import logging
from typing import Optional, Union

from composer.core import Callback, State
from composer.core.event import Event
from composer.loggers import Logger
from composer.models.huggingface import HuggingFaceModel
from composer.utils.checkpoint import load_checkpoint
from composer.utils.file_helpers import maybe_create_object_store_from_uri, parse_uri

log = logging.getLogger(__name__)


class LoadCheckpoint(Callback):
    """Callback that loads a checkpoint at the specified event.

    Args:
        load_path (str): The path to the checkpoint to load.
        load_weights_only (bool): Whether to load only the model weights, skipping the rest of
            the training state. Defaults to ``False``.
        strict_model_weights (bool): Whether to require the checkpoint weights to exactly match
            the model's parameters. Defaults to ``True``.
        ignore_keys (Optional[list[str]]): A list of checkpoint keys to skip when loading.
            Defaults to ``None``.
        event (Union[str, Event]): The event at which to load the checkpoint.
            Defaults to ``Event.BEFORE_LOAD``.
    """

    def __init__(
        self,
        load_path: str,
        load_weights_only: bool = False,
        strict_model_weights: bool = True,
        ignore_keys: Optional[list[str]] = None,
        event: Union[str, Event] = Event.BEFORE_LOAD,
    ):
        super().__init__()
        self.load_path = load_path
        self.load_object_store = maybe_create_object_store_from_uri(load_path)
        _, _, self.parsed_path = parse_uri(load_path)

        self.load_weights_only = load_weights_only
        self.strict_model_weights = strict_model_weights
        self.ignore_keys = ignore_keys

        self.event = event if isinstance(event, Event) else Event[event.upper()]

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        if event == self.event:
            log.info(f'Loading checkpoint from {self.load_path} at {self.event}.')
            self._load(state, logger)
            log.info(f'Finished loading checkpoint from {self.load_path} at {self.event}.')

        return super().run_event(event, state, logger)

    def _load(self, state: State, logger: Logger) -> None:
        # We need to temporarily disable the `should_save_peft_only` flag on the model
        # so that we can have access to the full model weights for loading.
        model = state.model
        original_should_save_peft_only = False
        if isinstance(model, HuggingFaceModel):
            original_should_save_peft_only = model.should_save_peft_only
            model.should_save_peft_only = False

        load_checkpoint(
            path=self.parsed_path,
            state=state,
            logger=logger,
            object_store=self.load_object_store,
            strict_model_weights=self.strict_model_weights,
            ignore_keys=self.ignore_keys,
            load_weights_only=self.load_weights_only,
        )

        # Restore the original `should_save_peft_only` flag on the model.
        if isinstance(model, HuggingFaceModel):
            model.should_save_peft_only = original_should_save_peft_only
```
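For reference, a minimal usage sketch of the new callback (the checkpoint URI, model, and dataloader here are hypothetical; `LoadCheckpoint`, the `Trainer` integration, and the string-to-`Event` resolution are as defined above):

```python
from composer import Trainer
from composer.callbacks import LoadCheckpoint

# Hypothetical remote URI: `maybe_create_object_store_from_uri` builds the matching
# object store, and `parse_uri` strips the backend/bucket to yield the object path.
callback = LoadCheckpoint(
    load_path='s3://my-bucket/checkpoints/base-model.pt',
    load_weights_only=True,  # restore only the model weights, not the full training state
    event='BEFORE_LOAD',  # string form; resolved via Event[event.upper()]
)

trainer = Trainer(
    model=model,  # assumption: a ComposerModel constructed elsewhere
    callbacks=[callback],
    train_dataloader=train_dataloader,  # assumption: defined elsewhere
    max_duration='1ep',
)
```

Because the default event is `Event.BEFORE_LOAD`, the callback's checkpoint is applied before the trainer's own `load_path` logic runs, which is exactly the ordering the test below asserts.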
@@ -0,0 +1,68 @@
```python
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from unittest import mock
from unittest.mock import call

from torch.utils.data import DataLoader

from composer.callbacks import LoadCheckpoint
from composer.core.state import State
from composer.models.huggingface import HuggingFaceModel
from composer.trainer.trainer import Trainer
from tests.common.datasets import RandomTextLMDataset


def test_load_checkpoint_callback(
    tiny_gpt2_model,
    tiny_gpt2_tokenizer,
    gpt2_peft_config,
):
    model = HuggingFaceModel(
        tiny_gpt2_model,
        tokenizer=tiny_gpt2_tokenizer,
        peft_config=gpt2_peft_config,
        should_save_peft_only=True,
    )

    # Check the arguments passed to the load_checkpoint function.
    def check_callback_load_args(state: State, **kwargs):
        assert state.model == model

        # Check that the `should_save_peft_only` flag on the model was set to
        # False while the checkpoint was being loaded.
        assert state.model.should_save_peft_only is False

    # Patch the load_checkpoint function to check the arguments passed to it.
    with mock.patch(
        'composer.callbacks.load_checkpoint.load_checkpoint',
        new=mock.MagicMock(wraps=check_callback_load_args),
    ) as callback_load:
        with mock.patch('composer.trainer.trainer.checkpoint.load_checkpoint') as trainer_load:
            calls = mock.MagicMock()
            calls.attach_mock(trainer_load, 'trainer_load')
            calls.attach_mock(callback_load, 'callback_load')

            Trainer(
                model=model,
                callbacks=[LoadCheckpoint(
                    load_path='fake-path',
                    event='BEFORE_LOAD',
                )],
                train_dataloader=DataLoader(RandomTextLMDataset()),
                max_duration='1ba',
                load_path='fake_path',
            )

            callback_load.assert_called_once()
            trainer_load.assert_called_once()

            # Assert that callback_load and trainer_load were called in the correct order.
            assert calls.mock_calls == [
                call.callback_load(**callback_load.call_args.kwargs),
                call.trainer_load(**trainer_load.call_args.kwargs),
            ]

    # Check that the `should_save_peft_only` flag on the model was reset to its
    # original value after the checkpoint was loaded.
    assert model.should_save_peft_only is True
```
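The cross-mock ordering assertion above uses a standard `unittest.mock` idiom: attaching child mocks to a single parent records every call, in order, in the parent's shared `mock_calls` list. A standalone sketch of just that pattern (all names here are illustrative, not from the commit):

```python
from unittest import mock

first = mock.MagicMock()
second = mock.MagicMock()

# Attaching both mocks to one parent merges their call histories.
parent = mock.MagicMock()
parent.attach_mock(first, 'first')
parent.attach_mock(second, 'second')

first(1)
second(x=2)

# The parent records both calls in the order they happened.
assert parent.mock_calls == [mock.call.first(1), mock.call.second(x=2)]
```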