Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ckpt-rewr] Save state dict API #3372

Merged
merged 48 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
b7915b9
add stubs
eracah Jun 4, 2024
5700fd4
save progress
eracah Jun 5, 2024
cfa11d4
Merge branch 'dev' of https://github.com/mosaicml/composer into sv-sd
eracah Jun 6, 2024
eabb9a3
add full state dict saving and testing
eracah Jun 6, 2024
37cfed3
Merge branch 'dev' into sv-sd
eracah Jun 6, 2024
0c503dc
Merge branch 'dev' of https://github.com/mosaicml/composer into sv-sd
eracah Jun 6, 2024
13d7e0c
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 6, 2024
316504e
add stubs
eracah Jun 6, 2024
00a7ce2
implement sharded save and get tests to pass
eracah Jun 6, 2024
0fd0d53
Merge branch 'dev' into sv-sd
eracah Jun 6, 2024
3a6c185
add cpu sharded test
eracah Jun 6, 2024
9b79b30
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 6, 2024
56f9339
pre-commit
eracah Jun 7, 2024
03dbb6a
remove comment
eracah Jun 7, 2024
7efdc74
Merge branch 'dev' into sv-sd
eracah Jun 7, 2024
e437fa7
remove __init__
eracah Jun 7, 2024
b518cde
rm init
eracah Jun 7, 2024
e9bb7ef
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 7, 2024
5156572
fix
eracah Jun 7, 2024
4f516dc
remove torchmetrics
eracah Jun 7, 2024
66c84b8
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 7, 2024
8c076b9
Update composer/checkpoint/save.py
eracah Jun 10, 2024
dba8cdc
Update composer/checkpoint/save.py
eracah Jun 10, 2024
dc35df0
Update composer/checkpoint/save.py
eracah Jun 10, 2024
dc5cf9f
Update composer/checkpoint/save.py
eracah Jun 10, 2024
6c67f93
Merge branch 'dev' into sv-sd
eracah Jun 10, 2024
892536d
fix docstring
eracah Jun 10, 2024
fcd1789
remove time.sleep
eracah Jun 10, 2024
1b47500
fix cpu tests
eracah Jun 11, 2024
46e4bec
pre-commit
eracah Jun 11, 2024
c4a97ce
fix cpu test
eracah Jun 11, 2024
05b9903
remove cpu tests :(
eracah Jun 11, 2024
44b123f
Merge branch 'dev' into sv-sd
eracah Jun 11, 2024
3874e58
pre-commit
eracah Jun 11, 2024
037548a
pc
eracah Jun 11, 2024
8324589
add all check
eracah Jun 12, 2024
ae3911d
pre-commit
eracah Jun 12, 2024
3e6e60b
add world_size = 1 test
eracah Jun 12, 2024
c865023
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 12, 2024
bd885a6
pre-commit
eracah Jun 12, 2024
064d0ba
pc
eracah Jun 12, 2024
aa05ce5
Merge branch 'dev' into sv-sd
eracah Jun 12, 2024
c4ef047
Merge branch 'dev' into sv-sd
eracah Jun 17, 2024
fbb80e5
Merge branch 'dev' into sv-sd
eracah Jun 17, 2024
8ae5a52
Update composer/checkpoint/save.py
eracah Jun 17, 2024
2efeb9a
Update composer/checkpoint/save.py
eracah Jun 17, 2024
92ffa59
fix is_sharded
eracah Jun 17, 2024
42f56b9
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions composer/checkpoint/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for saving state dicts to disk."""

import logging
import os
import textwrap
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch.distributed.checkpoint as DCP
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor

from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file

log = logging.getLogger(__name__)


def save_state_dict_to_disk(
    state_dict: Dict[str, Any],
    destination_file_path: str,
    overwrite: bool = False,
    save_format: str = 'pt',  # or hf, safetensor
) -> Optional[str]:
    """Saves a state dict to local disk.

    A sharded state dict (see :func:`is_state_dict_sharded`) is saved by every rank through
    ``torch.distributed.checkpoint``; any other state dict is written to a single file by
    global rank 0 only.

    Args:
        state_dict (Dict[str,Any]): The state dict to save.
        destination_file_path (str): Where to save the state dict. For a full state dict this is the file to
            write. For a sharded state dict this is treated as a directory; any file suffixes are stripped
            (with a warning) before saving.
        overwrite (bool): If False, a ``ValueError`` is raised when the destination already exists.
        save_format (str): The format to save the state dict in. Only 'pt' is currently implemented; any other
            value raises ``NotImplementedError``.

    Returns:
        str: The full path to the saved state dict if sharded is false and rank 0 or if sharded is true,
        otherwise None. None is also returned when ``state_dict`` is empty.
    """
    # Nothing to write; this also guards is_state_dict_sharded against an empty dict.
    if not state_dict:
        return None

    if is_state_dict_sharded(state_dict):
        return _save_sharded_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)

    # Full state dicts are written by global rank 0 only; every other rank reports no path.
    if dist.get_global_rank() != 0:
        return None
    return _save_full_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)


def _save_sharded_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt', # or safetensor
eracah marked this conversation as resolved.
Show resolved Hide resolved
) -> Optional[str]:

if save_format != 'pt':
raise NotImplementedError(f'Saving sharded state dict to disk in format {save_format} is not supported.')
eracah marked this conversation as resolved.
Show resolved Hide resolved

if state_dict == {}:
return None
eracah marked this conversation as resolved.
Show resolved Hide resolved

# If used specifies filename instead of directory suffixes, strip them and warn.
eracah marked this conversation as resolved.
Show resolved Hide resolved
if len(Path(destination_file_path).suffixes) > 0:
stripped_path = _strip_suffixes(destination_file_path)
warnings.warn(
textwrap.dedent(
f"""Sharded checkpoints require a directory path not a file path:
{destination_file_path} will have its extensions stripped and checkpoints will be saved in {stripped_path}
as {stripped_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}""",
),
)
destination_file_path = stripped_path

if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')

log.debug(
f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
)

# For 2.3.0 and above you can use checkpoint_id, but this version works the best for all versions
# of torch (and makes pyright happier) that we support, so we use it for now.
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
eracah marked this conversation as resolved.
Show resolved Hide resolved

return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME


def _save_full_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:

if save_format != 'pt':
raise NotImplementedError(f'Saving full state dict to disk in format {save_format} is not supported.')
eracah marked this conversation as resolved.
Show resolved Hide resolved

if not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')

if dist.get_global_rank() == 0:
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
return destination_file_path
return None


def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
    """Determines if the state dict is sharded.

    The state dict is considered sharded if any of its values is a ``ShardedTensor`` or ``DTensor``.
    Values that are themselves dicts are searched recursively. An empty state dict is not sharded.

    Args:
        state_dict (Dict[str, Any]): The state dict to check.

    Returns:
        bool: Whether the state dict is sharded.
    """
    for value in state_dict.values():
        if isinstance(value, (ShardedTensor, DTensor)):
            return True
        # Look inside nested dicts instead of judging by the first top-level entry only.
        if isinstance(value, dict) and is_state_dict_sharded(value):
            return True
    return False
eracah marked this conversation as resolved.
Show resolved Hide resolved


def _strip_suffixes(path: Union[str, Path]) -> str:
path = Path(path)
for _ in path.suffixes:
path = path.with_suffix('')

return str(path)
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def _get_commit_sha() -> str:
'torch': ('https://pytorch.org/docs/stable/', None),
'torchvision': ('https://pytorch.org/vision/stable/', None),
'torchtext': ('https://pytorch.org/text/stable/', None),
'torchmetrics': ('https://torchmetrics.readthedocs.io/en/latest/', None),
'libcloud': ('https://libcloud.readthedocs.io/en/stable/', None),
'PIL': ('https://pillow.readthedocs.io/en/stable', None),
'coolname': ('https://coolname.readthedocs.io/en/latest/', None),
Expand Down
109 changes: 109 additions & 0 deletions tests/checkpoint/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict

eracah marked this conversation as resolved.
Show resolved Hide resolved
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import CPUOffload
from torch.optim import adam

from tests.common.models import EvenSimplerMLP, SimpleComposerMLP

__all__ = [
'init_model_and_optimizer',
'init_model',
'init_optimizer',
]


def init_model_and_optimizer(
    use_composer_model: bool,
    num_classes=3,
    batch_size=5,
    num_features=8,
    take_step=True,
    use_fsdp=False,
    tensor_type='sharded_tensor',
    device='cuda',
):
    """Creates a model with ``init_model`` and an optimizer for it with ``init_optimizer``.

    Returns:
        The ``(model, optimizer)`` pair.
    """
    # Arguments that both helpers take.
    shared_kwargs = dict(num_classes=num_classes, num_features=num_features, device=device)

    model, loss_fn = init_model(
        use_composer_model,
        use_fsdp=use_fsdp,
        tensor_type=tensor_type,
        **shared_kwargs,
    )
    optimizer = init_optimizer(
        model,
        loss_fn,
        use_composer_model=use_composer_model,
        batch_size=batch_size,
        take_step=take_step,
        **shared_kwargs,
    )
    return model, optimizer


def init_model(
    use_composer_model: bool = False,
    num_classes=3,
    num_features=8,
    use_fsdp=False,
    device='cuda',
    tensor_type='sharded_tensor',
    sync_module_states=True,
    cpu_offload=False,
):
    """Builds a small MLP test model and its loss function, optionally wrapped in FSDP.

    Returns:
        The ``(model, loss_fn)`` pair. ``model`` is FSDP-wrapped when ``use_fsdp`` is True.
    """
    if use_composer_model:
        model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device)
        loss_fn = model._loss_fn
    else:
        model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device)
        loss_fn = torch.nn.CrossEntropyLoss()

    if not use_fsdp:
        return model, loss_fn

    offload_config = CPUOffload(offload_params=True) if cpu_offload else None
    fsdp_kwargs: Dict[str, Any] = {
        'use_orig_params': True,
        # To enable easy comparison between rank 0 unsharded model and full state dict
        'sync_module_states': sync_module_states,
        'cpu_offload': offload_config,
    }

    if tensor_type == 'dtensor':
        from torch.distributed.device_mesh import init_device_mesh

        # NOTE(review): the mesh is hard-coded to device type 'cuda' and shape (2,), independent of `device`.
        fsdp_kwargs['device_mesh'] = init_device_mesh('cuda', (2,))

    return FSDP(model, **fsdp_kwargs), loss_fn


def init_optimizer(
    model,
    loss_fn,
    use_composer_model: bool = False,
    num_classes=3,
    batch_size=5,
    num_features=8,
    take_step=True,
    device='cuda',
):
    """Creates an Adam optimizer for ``model`` and runs one forward/backward pass on random data.

    Args:
        model: Module to optimize. It is called with ``(inputs, targets)`` when ``use_composer_model`` is True
            and with ``inputs`` alone otherwise.
        loss_fn: Callable invoked as ``loss_fn(outputs, targets)``.
        use_composer_model (bool): Selects the batch format passed to ``model``.
        num_classes: Exclusive upper bound for the random integer targets.
        batch_size: Number of random samples.
        num_features: Number of features per random sample.
        take_step: If True, ``optimizer.step()`` is called after the backward pass.
        device: Device on which the random inputs and targets are created.

    Returns:
        The Adam optimizer.
    """
    features = torch.randn(batch_size, num_features, device=device)
    labels = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long)
    optimizer = adam.Adam(model.parameters())

    if use_composer_model:
        model_input = (features, labels)
    else:
        model_input = features

    loss_fn(model(model_input), labels).backward()

    if take_step:
        optimizer.step()
    return optimizer
99 changes: 99 additions & 0 deletions tests/checkpoint/test_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import time
import uuid
from copy import deepcopy
from pathlib import Path

import pytest
import torch
import torch.distributed.checkpoint as DCP

from composer.checkpoint.save import save_state_dict_to_disk
from composer.checkpoint.state_dict import get_model_state_dict
from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
from tests.checkpoint.helpers import init_model
from tests.common.compare import deep_compare
from tests.common.markers import world_size


@world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('sharded_model', [False, True])
def test_save_full_state_dict_to_disk(world_size: int, tmp_path: Path, sharded_model: bool):
    """Saves a full state dict from a GPU model and checks that rank 0 can load back identical contents."""
    destination_file_path = os.path.join(tmp_path, 'test.pt')
    model, _ = init_model(use_fsdp=sharded_model, device='cuda', sync_module_states=True)

    state_dict = get_model_state_dict(model, sharded_state_dict=False)
    # The save returns only after rank 0 has written the file, and only rank 0 inspects it below,
    # so no sleep is needed before checking.
    path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path)

    if dist.get_global_rank() != 0:
        assert path_saved is None
        return

    assert path_saved is not None
    assert path_saved == destination_file_path
    assert os.path.exists(destination_file_path), f'{destination_file_path} does not exist'
    loaded_state_dict = torch.load(path_saved, map_location='cuda')
    deep_compare(state_dict, loaded_state_dict)


@world_size(2)
@pytest.mark.filterwarnings('ignore:The passed')  # Torch issues a warning for wrapping a CPU model in FSDP
@pytest.mark.parametrize('sharded_model', [False, True])
def test_save_full_state_dict_to_disk_cpu(world_size: int, tmp_path: Path, sharded_model: bool):
    """Saves a full state dict from a CPU model and checks that rank 0 can load back identical contents."""
    destination_file_path = os.path.join(tmp_path, 'test.pt')
    model, _ = init_model(use_fsdp=sharded_model, device='cpu', sync_module_states=False, cpu_offload=True)

    state_dict = get_model_state_dict(model, sharded_state_dict=False)
    # The save returns only after rank 0 has written the file, and only rank 0 inspects it below,
    # so no sleep is needed before checking.
    path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path)

    if dist.get_global_rank() != 0:
        assert path_saved is None
        return

    # Check for None first so the comparison below is made against a real path.
    assert path_saved is not None
    assert path_saved == destination_file_path
    assert os.path.exists(destination_file_path), f'{destination_file_path} does not exist'
    loaded_state_dict = torch.load(path_saved, map_location='cpu')
    deep_compare(state_dict, loaded_state_dict)


@world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor'])
def test_save_sharded_state_dict_to_disk(world_size: int, tmp_path: Path, tensor_type: str):
    """Saves a sharded state dict from a GPU FSDP model and checks it round-trips through ``DCP.load``."""
    local_destination = os.path.join(tmp_path, str(uuid.uuid4())[:8])
    # Sync the path across all ranks: every rank adopts the path gathered from rank 0.
    destination_file_path = dist.all_gather_object(local_destination)[0]

    model, _ = init_model(use_fsdp=True, device='cuda', tensor_type=tensor_type)
    state_dict = get_model_state_dict(model, sharded_state_dict=True)
    # Copy taken before saving; it is the target that DCP.load fills in.
    loaded_in_state_dict = deepcopy(state_dict)

    path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path, overwrite=True)
    assert path_saved is not None
    assert path_saved == f'{destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}'

    # The returned path points inside the checkpoint directory; DCP loads from the directory itself.
    load_path = str(Path(path_saved).parent)
    DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
    deep_compare(state_dict, loaded_in_state_dict)


@pytest.mark.filterwarnings('ignore:The passed')  # Torch issues a warning for wrapping a CPU model in FSDP
@world_size(2)
def test_save_sharded_state_dict_to_disk_cpu(world_size: int, tmp_path: Path):
    """Saves a sharded state dict from a CPU FSDP model and checks it round-trips through ``DCP.load``."""
    local_destination = os.path.join(tmp_path, str(uuid.uuid4())[:8])
    # Sync the path across all ranks: every rank adopts the path gathered from rank 0.
    destination_file_path = dist.all_gather_object(local_destination)[0]

    model, _ = init_model(use_fsdp=True, device='cpu', sync_module_states=False, cpu_offload=True)
    state_dict = get_model_state_dict(model, sharded_state_dict=True)
    # Copy taken before saving; it is the target that DCP.load fills in.
    loaded_in_state_dict = deepcopy(state_dict)

    path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path, overwrite=True)
    assert isinstance(path_saved, str)
    assert path_saved == f'{destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}'

    # The returned path points inside the checkpoint directory; DCP loads from the directory itself.
    load_path = str(Path(path_saved).parent)
    DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
    deep_compare(state_dict, loaded_in_state_dict)
Loading
Loading