Create callback to load checkpoint (#3641)
irenedea authored Oct 7, 2024
1 parent 4e1ec17 commit bb7ea43
Showing 11 changed files with 253 additions and 54 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -13,6 +13,7 @@
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.load_checkpoint import LoadCheckpoint
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.memory_snapshot import MemorySnapshot
@@ -44,4 +45,5 @@
    'FreeOutputs',
    'MemorySnapshot',
    'OOMObserver',
    'LoadCheckpoint',
]
76 changes: 76 additions & 0 deletions composer/callbacks/load_checkpoint.py
@@ -0,0 +1,76 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Load a checkpoint."""
import logging
from typing import Optional, Union

from composer.core import Callback, State
from composer.core.event import Event
from composer.loggers import Logger
from composer.models.huggingface import HuggingFaceModel
from composer.utils.checkpoint import load_checkpoint
from composer.utils.file_helpers import maybe_create_object_store_from_uri, parse_uri

log = logging.getLogger(__name__)


class LoadCheckpoint(Callback):
"""Callback that loads a checkpoint at the specified event.
Args:
load_path (str): The path to the checkpoint to load.
load_options (Optional[dict]): A dictionary of options to pass to the checkpoint loading function.
event (Union[str, Event]): The event at which to load the checkpoint. Defaults to ``Event.BEFORE_LOAD``.
"""

    def __init__(
        self,
        load_path: str,
        load_weights_only: bool = False,
        strict_model_weights: bool = True,
        ignore_keys: Optional[list[str]] = None,
        event: Union[str, Event] = Event.BEFORE_LOAD,
    ):
        super().__init__()
        self.load_path = load_path
        self.load_object_store = maybe_create_object_store_from_uri(load_path)
        _, _, self.parsed_path = parse_uri(load_path)

        self.load_weights_only = load_weights_only
        self.strict_model_weights = strict_model_weights
        self.ignore_keys = ignore_keys

        self.event = event if isinstance(event, Event) else Event[event.upper()]

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        if event == self.event:
            log.info(f'Loading checkpoint from {self.load_path} at {self.event}.')
            self._load(state, logger)
            log.info(f'Finished loading checkpoint from {self.load_path} at {self.event}.')

        return super().run_event(event, state, logger)

    def _load(self, state: State, logger: Logger) -> None:

        # We need to temporarily disable the `should_save_peft_only` flag on the model
        # so that we can have access to the full model weights for loading.
        model = state.model
        original_should_save_peft_only = False
        if isinstance(model, HuggingFaceModel):
            original_should_save_peft_only = model.should_save_peft_only
            model.should_save_peft_only = False

        load_checkpoint(
            path=self.parsed_path,
            state=state,
            logger=logger,
            object_store=self.load_object_store,
            strict_model_weights=self.strict_model_weights,
            ignore_keys=self.ignore_keys,
            load_weights_only=self.load_weights_only,
        )

        # Restore the original `should_save_peft_only` flag on the model
        if isinstance(model, HuggingFaceModel):
            model.should_save_peft_only = original_should_save_peft_only
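For reference, a minimal sketch of how this callback might be wired into a Trainer, mirroring the test added later in this commit. The checkpoint path is a placeholder, and `model` / `train_dataloader` stand in for a ComposerModel and DataLoader defined elsewhere; only the Trainer and callback APIs shown in this diff are assumed.

from composer import Trainer
from composer.callbacks import LoadCheckpoint

# Placeholder local path; a remote URI (e.g. an s3:// path) would also be accepted,
# in which case maybe_create_object_store_from_uri would create the matching object store.
load_cb = LoadCheckpoint(
    load_path='/path/to/checkpoints/ep10-rank0.pt',
    load_weights_only=True,
    event='BEFORE_LOAD',
)

trainer = Trainer(
    model=model,  # placeholder: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration='1ep',
    callbacks=[load_cb],
)
trainer.fit()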
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -168,7 +168,9 @@ filterwarnings = [
    '''ignore:'.*_state_dict' is deprecated and will be removed in future versions.*:UserWarning''',
    # Ignore mlflow warnings about transformers versions,
    '''ignore:The 'transformers' MLflow Models integration.*:UserWarning''',
    # Ignore our own deprecation warnings,
    # Ignore the flash v3 warnings from transformer engine
    '''ignore:To use flash-attn v3*:UserWarning''',
    # Ignore our own deprecation warnings
    '''ignore::composer.utils.warnings.VersionedDeprecationWarning''',
    # Ignore deprecation warning for torch.load
    '''ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning''',
14 changes: 14 additions & 0 deletions tests/callbacks/callback_settings.py
@@ -1,8 +1,10 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
import os
from typing import Any
from unittest import mock
from unittest.mock import MagicMock

import pytest
@@ -26,6 +28,7 @@
    SystemMetricsMonitor,
    ThresholdStopper,
)
from composer.callbacks.load_checkpoint import LoadCheckpoint
from composer.loggers import (
    CometMLLogger,
    ConsoleLogger,
@@ -155,6 +158,9 @@
        'trace_handlers': [MagicMock()],
        'schedule': composer.profiler.cyclic_schedule(),
    },
    LoadCheckpoint: {
        'load_path': 'fake-path',
    },
}

_callback_marks: dict[
@@ -201,6 +207,14 @@
    NeptuneLogger: [pytest.mark.skipif(not _NEPTUNE_INSTALLED, reason='neptune is optional')],
}

_callback_patches: dict[type[Callback], Any] = {
    LoadCheckpoint: mock.patch('composer.callbacks.load_checkpoint.load_checkpoint'),
}


def get_cb_patches(impl: type[Callback]):
    return _callback_patches.get(impl, contextlib.nullcontext())


def get_cb_kwargs(impl: type[Callback]):
    return _callback_kwargs.get(impl, {})
23 changes: 18 additions & 5 deletions tests/callbacks/test_callbacks.py
@@ -10,7 +10,12 @@
from composer.loggers import Logger, LoggerDestination
from composer.profiler import Profiler, ProfilerAction
from composer.trainer import Trainer
from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
from tests.callbacks.callback_settings import (
    get_cb_kwargs,
    get_cb_model_and_datasets,
    get_cb_patches,
    get_cbs_and_marks,
)
from tests.common import EventCounterCallback


@@ -154,8 +159,12 @@ def test_trains(self, cb_cls: type[Callback], device_train_microbatch_size: int,
        del _remote  # unused. `_remote` must be passed through to parameterize the test markers.
        cb_kwargs = get_cb_kwargs(cb_cls)
        cb = cb_cls(**cb_kwargs)
        trainer = self._get_trainer(cb, device_train_microbatch_size)
        trainer.fit()

        maybe_patch_context = get_cb_patches(cb_cls)

        with maybe_patch_context:
            trainer = self._get_trainer(cb, device_train_microbatch_size)
            trainer.fit()

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_trains_multiple_calls(self, cb_cls: type[Callback], device_train_microbatch_size: int, _remote: bool):
@@ -167,8 +176,12 @@ def test_trains_multiple_calls(self, cb_cls: type[Callback], device_train_microb
        del _remote  # unused. `_remote` must be passed through to parameterize the test markers.
        cb_kwargs = get_cb_kwargs(cb_cls)
        cb = cb_cls(**cb_kwargs)
        trainer = self._get_trainer(cb, device_train_microbatch_size)
        trainer.fit()

        maybe_patch_context = get_cb_patches(cb_cls)

        with maybe_patch_context:
            trainer = self._get_trainer(cb, device_train_microbatch_size)
            trainer.fit()

        assert trainer.state.max_duration is not None
        trainer.state.max_duration = cast(Time[int], trainer.state.max_duration * 2)
68 changes: 68 additions & 0 deletions tests/callbacks/test_load_checkpoint.py
@@ -0,0 +1,68 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from unittest import mock
from unittest.mock import call

from torch.utils.data import DataLoader

from composer.callbacks import LoadCheckpoint
from composer.core.state import State
from composer.models.huggingface import HuggingFaceModel
from composer.trainer.trainer import Trainer
from tests.common.datasets import RandomTextLMDataset


def test_load_checkpoint_callback(
    tiny_gpt2_model,
    tiny_gpt2_tokenizer,
    gpt2_peft_config,
):

    model = HuggingFaceModel(
        tiny_gpt2_model,
        tokenizer=tiny_gpt2_tokenizer,
        peft_config=gpt2_peft_config,
        should_save_peft_only=True,
    )

    # Function to check the arguments passed to the load_checkpoint function.
    def check_callback_load_args(state: State, **kwargs):
        assert state.model == model

        # Check that the `should_save_peft_only` flag on the model was set to False when loading the checkpoint.
        assert state.model.should_save_peft_only == False

    # Patch the load_checkpoint function to check the arguments passed to it.
    with mock.patch(
        'composer.callbacks.load_checkpoint.load_checkpoint',
        new=mock.MagicMock(wraps=check_callback_load_args),
    ) as callback_load:
        with mock.patch('composer.trainer.trainer.checkpoint.load_checkpoint') as trainer_load:

            calls = mock.MagicMock()
            calls.attach_mock(trainer_load, 'trainer_load')
            calls.attach_mock(callback_load, 'callback_load')

            Trainer(
                model=model,
                callbacks=[LoadCheckpoint(
                    load_path='fake-path',
                    event='BEFORE_LOAD',
                )],
                train_dataloader=DataLoader(RandomTextLMDataset()),
                max_duration='1ba',
                load_path='fake_path',
            )

            callback_load.assert_called_once()
            trainer_load.assert_called_once()

            # Assert that the callback_load and trainer_load functions were called in the correct order.
            assert calls.mock_calls == [
                call.callback_load(**callback_load.call_args.kwargs),
                call.trainer_load(**trainer_load.call_args.kwargs),
            ]

    # Check that the `should_save_peft_only` flag on the model was reset to its original value after loading the checkpoint.
    assert model.should_save_peft_only == True
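As the test's event='BEFORE_LOAD' string suggests, the constructor accepts either a string or an Event member and normalizes strings via Event[event.upper()]. A small sketch of that behavior; the local path is a placeholder, so no object store is created:

from composer.callbacks import LoadCheckpoint
from composer.core.event import Event

# Both spellings should resolve to the same Event member.
cb_from_str = LoadCheckpoint(load_path='/tmp/checkpoint.pt', event='before_load')
cb_from_enum = LoadCheckpoint(load_path='/tmp/checkpoint.pt', event=Event.BEFORE_LOAD)
assert cb_from_str.event == cb_from_enum.event == Event.BEFORE_LOAD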
28 changes: 18 additions & 10 deletions tests/callbacks/test_loggers_across_callbacks.py
@@ -9,7 +9,12 @@
from composer.loggers import ConsoleLogger, LoggerDestination, ProgressBarLogger, SlackLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.trainer import Trainer
from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
from tests.callbacks.callback_settings import (
    get_cb_kwargs,
    get_cb_model_and_datasets,
    get_cb_patches,
    get_cbs_and_marks,
)


@pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True))
@@ -27,12 +32,15 @@ def test_loggers_on_callbacks(logger_cls: type[LoggerDestination], callback_cls:
    callback_kwargs = get_cb_kwargs(callback_cls)
    callback = callback_cls(**callback_kwargs)
    model, train_dataloader, _ = get_cb_model_and_datasets(callback)
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        train_subset_num_batches=2,
        max_duration='1ep',
        callbacks=callback,
        loggers=logger,
    )
    trainer.fit()
    maybe_patch_context = get_cb_patches(callback_cls)

    with maybe_patch_context:
        trainer = Trainer(
            model=model,
            train_dataloader=train_dataloader,
            train_subset_num_batches=2,
            max_duration='1ep',
            callbacks=callback,
            loggers=logger,
        )
        trainer.fit()
18 changes: 18 additions & 0 deletions tests/fixtures/fixtures.py
@@ -424,6 +424,24 @@ def tiny_gpt2_model(_session_tiny_gpt2_model):
    return copy.deepcopy(_session_tiny_gpt2_model)


def _gpt2_peft_config():
    pytest.importorskip('peft')
    from peft import get_peft_config

    peft_config = get_peft_config({
        'peft_type': 'LORA',
        'task_type': 'CAUSAL_LM',
        'target_modules': ['c_attn'],
        'fan_in_fan_out': True,
    })
    return peft_config


@pytest.fixture
def gpt2_peft_config():
    return _gpt2_peft_config()


@pytest.fixture
def tiny_opt_config(_session_tiny_opt_config):
    return copy.deepcopy(_session_tiny_opt_config)
27 changes: 17 additions & 10 deletions tests/loggers/test_mosaicml_logger.py
@@ -21,7 +21,12 @@
)
from composer.trainer import Trainer
from composer.utils import dist, get_composer_env_dict
from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
from tests.callbacks.callback_settings import (
    get_cb_kwargs,
    get_cb_model_and_datasets,
    get_cb_patches,
    get_cbs_and_marks,
)
from tests.common import RandomClassificationDataset, SimpleModel
from tests.common.markers import world_size

@@ -121,15 +126,17 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: type[Callba
    callback = callback_cls(**callback_kwargs)
    train_dataset = RandomClassificationDataset()
    model, train_dataloader, _ = get_cb_model_and_datasets(callback, sampler=dist.get_sampler(train_dataset))
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        train_subset_num_batches=1,
        max_duration='1ep',
        callbacks=callback,
        loggers=MosaicMLLogger(),
    )
    trainer.fit()
    maybe_patch_context = get_cb_patches(callback_cls)
    with maybe_patch_context:
        trainer = Trainer(
            model=model,
            train_dataloader=train_dataloader,
            train_subset_num_batches=1,
            max_duration='1ep',
            callbacks=callback,
            loggers=MosaicMLLogger(),
        )
        trainer.fit()

    if dist.get_global_rank() == 0:
        assert len(mock_mapi.run_metadata[run_name].keys()) > 0