diff --git a/composer/checkpoint/save.py b/composer/checkpoint/save.py
new file mode 100644
index 0000000000..72e5311d0f
--- /dev/null
+++ b/composer/checkpoint/save.py
@@ -0,0 +1,145 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Useful functions for saving state dicts to disk."""
+
+import logging
+import os
+import textwrap
+import warnings
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import torch
+import torch.distributed.checkpoint as DCP
+from packaging import version
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._tensor import DTensor
+
+from composer.utils import dist
+from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file
+
+log = logging.getLogger(__name__)
+
+
+def save_state_dict_to_disk(
+    state_dict: Dict[str, Any],
+    destination_file_path: str,
+    overwrite: bool = False,
+    save_format: str = 'pt',  # or hf, safetensor
+) -> Optional[str]:
+    """Saves a state dict to local disk.
+
+    Args:
+        state_dict (Dict[str,Any]): The state dict to save.
+        destination_file_path (str): The path to save the state dict to. If sharded,
+            this should be the path to a directory. Otherwise, it should be a path to a file.
+        overwrite (bool): If True, the file will be overwritten if it exists.
+        save_format (str): The format to save the state dict in. One of 'pt', 'hf', or 'safetensor'.
+
+    Returns:
+        Optional[str]: The full path to the saved state dict if the state dict is sharded, or if it is unsharded and this is global rank 0; otherwise None.
+    """
+    if state_dict == {}:
+        return None
+    if is_state_dict_sharded(state_dict):
+        path_saved = _save_sharded_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)
+    else:
+        if dist.get_global_rank() == 0:
+            path_saved = _save_full_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)
+        else:
+            path_saved = None
+
+    return path_saved
+
+
+def _save_sharded_state_dict_to_disk(
+    state_dict: Dict[str, Any],
+    destination_file_path: str,
+    overwrite: bool = False,
+    save_format: str = 'pt',
+) -> Optional[str]:
+
+    if save_format != 'pt':
+        raise NotImplementedError(
+            f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
+        )
+
+    if state_dict == {}:
+        return None
+
+    # If the user specifies a filename instead of a directory, strip the suffixes and warn
+    if len(Path(destination_file_path).suffixes) > 0:
+        stripped_path = _strip_suffixes(destination_file_path)
+        warnings.warn(
+            textwrap.dedent(
+                f"""Sharded checkpoints require a directory path, not a file path:
+                {destination_file_path} will have its extensions stripped and checkpoints will be saved in {stripped_path}
+                as {stripped_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}""",
+            ),
+        )
+        destination_file_path = stripped_path
+
+    if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
+        raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')
+
+    log.debug(
+        f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
+    )
+
+    # Torch 2.3.0 and above supports passing a checkpoint_id, but this call pattern works across all
+    # torch versions we support (and keeps pyright happy), so we use it for now.
+    if version.parse(torch.__version__) < version.parse('2.2.0'):
+        DCP.save_state_dict(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
+    else:
+        DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
+
+    return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
+
+
+def _save_full_state_dict_to_disk(
+    state_dict: Dict[str, Any],
+    destination_file_path: str,
+    overwrite: bool = False,
+    save_format: str = 'pt',  # or hf, safetensor
+) -> Optional[str]:
+
+    if save_format != 'pt':
+        raise NotImplementedError(
+            f"Saving full state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
+        )
+
+    if not overwrite and os.path.exists(destination_file_path):
+        raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')
+
+    if dist.get_global_rank() == 0:
+        _write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
+        return destination_file_path
+    return None
+
+
+def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
+    """Determines if the state dict is sharded.
+
+    Args:
+        state_dict (Dict[str, Any]): The state dict to check.
+
+    Returns:
+        bool: Whether the state dict is sharded.
+    """
+    for value in state_dict.values():
+        if isinstance(value, ShardedTensor) or isinstance(value, DTensor):
+            return True
+        if isinstance(value, Dict):
+            is_sharded = is_state_dict_sharded(value)
+            if is_sharded:
+                return True
+    return False
+
+
+def _strip_suffixes(path: Union[str, Path]) -> str:
+    path = Path(path)
+    for _ in path.suffixes:
+        path = path.with_suffix('')
+
+    return str(path)
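For reference, a minimal usage sketch of the new API (illustrative only and not part of the diff; it assumes torch.distributed is already initialized and uses a placeholder FSDP-wrapped module):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    from composer.checkpoint.save import save_state_dict_to_disk
    from composer.checkpoint.state_dict import get_model_state_dict

    model = FSDP(torch.nn.Linear(8, 3, device='cuda'))  # placeholder module, not from this diff

    # An unsharded state dict is written to a single file by rank 0 only.
    full_sd = get_model_state_dict(model, sharded_state_dict=False)
    save_state_dict_to_disk(full_sd, destination_file_path='/tmp/model.pt', overwrite=True)

    # A sharded state dict (ShardedTensor/DTensor values) is written to a directory via torch DCP;
    # the returned path points at the distributed checkpoint file inside that directory.
    sharded_sd = get_model_state_dict(model, sharded_state_dict=True)
    saved_path = save_state_dict_to_disk(sharded_sd, destination_file_path='/tmp/sharded_ckpt', overwrite=True)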
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 45affa4a0e..533ce95b78 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -219,7 +219,6 @@ def _get_commit_sha() -> str:
     'torch': ('https://pytorch.org/docs/stable/', None),
     'torchvision': ('https://pytorch.org/vision/stable/', None),
     'torchtext': ('https://pytorch.org/text/stable/', None),
-    'torchmetrics': ('https://torchmetrics.readthedocs.io/en/latest/', None),
     'libcloud': ('https://libcloud.readthedocs.io/en/stable/', None),
     'PIL': ('https://pillow.readthedocs.io/en/stable', None),
     'coolname': ('https://coolname.readthedocs.io/en/latest/', None),
diff --git a/tests/checkpoint/helpers.py b/tests/checkpoint/helpers.py
new file mode 100644
index 0000000000..047d30e813
--- /dev/null
+++ b/tests/checkpoint/helpers.py
@@ -0,0 +1,110 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict
+
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.api import CPUOffload
+from torch.optim import adam
+
+from tests.common.models import EvenSimplerMLP, SimpleComposerMLP
+
+__all__ = [
+    'init_model_and_optimizer',
+    'init_model',
+    'init_optimizer',
+]
+
+
+def init_model_and_optimizer(
+    use_composer_model: bool,
+    num_classes=3,
+    batch_size=5,
+    num_features=8,
+    take_step=True,
+    use_fsdp=False,
+    tensor_type='sharded_tensor',
+    device='cuda',
+):
+    model, loss_fn = init_model(
+        use_composer_model,
+        num_classes=num_classes,
+        num_features=num_features,
+        use_fsdp=use_fsdp,
+        tensor_type=tensor_type,
+        device=device,
+    )
+
+    optimizer = init_optimizer(
+        model,
+        loss_fn,
+        use_composer_model=use_composer_model,
+        num_classes=num_classes,
+        batch_size=batch_size,
+        num_features=num_features,
+        take_step=take_step,
+        device=device,
+    )
+
+    return model, optimizer
+
+
+def init_model(
+    use_composer_model: bool = False,
+    num_classes=3,
+    num_features=8,
+    use_fsdp=False,
+    device='cuda',
+    tensor_type='sharded_tensor',
+    sync_module_states=True,
+    cpu_offload=False,
+):
+    if use_composer_model:
+        model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device)
+        loss_fn = model._loss_fn
+    else:
+        model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device)
+        loss_fn = torch.nn.CrossEntropyLoss()
+
+    if use_fsdp:
+        fsdp_kwargs: Dict[str, Any] = dict(
+            use_orig_params=True,
+            sync_module_states=sync_module_states,  # To enable easy comparison between rank 0 unsharded model and full state dict
+            cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None,
+            device_id=torch.device('cpu') if device == 'cpu' else None,
+        )
+
+        if tensor_type == 'dtensor':
+            from torch.distributed.device_mesh import init_device_mesh
+            device_mesh = init_device_mesh('cuda', (2,))
+            fsdp_kwargs['device_mesh'] = device_mesh
+
+        model = FSDP(
+            model,
+            **fsdp_kwargs,
+        )
+
+    return model, loss_fn
+
+
+def init_optimizer(
+    model,
+    loss_fn,
+    use_composer_model: bool = False,
+    num_classes=3,
+    batch_size=5,
+    num_features=8,
+    take_step=True,
+    device='cuda',
+):
+    inputs = torch.randn(batch_size, num_features, device=device)
+    targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long)
+    batch = (inputs, targets) if use_composer_model else inputs
+    optimizer = adam.Adam(model.parameters())
+    outputs = model(batch)
+    loss = loss_fn(outputs, targets)
+    loss.backward()
+    if take_step:
+        optimizer.step()
+    return optimizer
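A brief sketch of how these helpers are intended to be combined in a multi-GPU test (hypothetical call site; the argument values mirror the tests later in this diff rather than any required defaults):

    from tests.checkpoint.helpers import init_model_and_optimizer

    # Builds an FSDP-wrapped model sharded with DTensor plus an Adam optimizer that has
    # already taken one step, so per-parameter state (e.g. exp_avg) exists to checkpoint.
    model, optimizer = init_model_and_optimizer(
        use_composer_model=True,
        take_step=True,
        use_fsdp=True,
        tensor_type='dtensor',
        device='cuda',
    )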
diff --git a/tests/checkpoint/test_save.py b/tests/checkpoint/test_save.py
new file mode 100644
index 0000000000..03b12bbcbc
--- /dev/null
+++ b/tests/checkpoint/test_save.py
@@ -0,0 +1,79 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import time
+import uuid
+from copy import deepcopy
+from pathlib import Path
+
+import pytest
+import torch
+import torch.distributed.checkpoint as DCP
+from packaging import version
+
+from composer.checkpoint.save import save_state_dict_to_disk
+from composer.checkpoint.state_dict import get_model_state_dict
+from composer.utils import dist
+from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
+from tests.checkpoint.helpers import init_model
+from tests.common.compare import deep_compare
+from tests.common.markers import world_size
+
+
+@world_size(1, 2)
+@pytest.mark.gpu
+@pytest.mark.parametrize('sharded_model', [False, True])
+def test_save_full_state_dict_to_disk(world_size: int, tmp_path: str, sharded_model: bool):
+    if world_size == 1 and sharded_model:
+        pytest.skip("Can't have a sharded model for world_size = 1")
+    destination_file_path = os.path.join(tmp_path, 'test.pt')
+    use_fsdp = sharded_model
+    model, _ = init_model(use_fsdp=use_fsdp, device='cuda', sync_module_states=True)
+
+    state_dict = get_model_state_dict(model, sharded_state_dict=False)
+    path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path)
+    time.sleep(1)
+    if dist.get_global_rank() == 0:
+        assert path_saved is not None
+        assert path_saved == destination_file_path
+        assert os.path.exists(destination_file_path), f'{destination_file_path} does not exist'
+        loaded_state_dict = torch.load(path_saved, map_location='cuda')
+        deep_compare(state_dict, loaded_state_dict)
+    else:
+        assert path_saved is None
+
+
+@world_size(2)
+@pytest.mark.gpu
+@pytest.mark.parametrize(
+    'tensor_type',
+    [
+        'sharded_tensor',
+        pytest.param(
+            'dtensor',
+            marks=pytest.mark.skipif(
+                version.parse(torch.__version__) < version.parse('2.2.0'),
+                reason='Requires torch>=2.2.0 for dtensor',
+            ),
+        ),
+    ],
+)
+def test_save_sharded_state_dict_to_disk(world_size: int, tmp_path: str, tensor_type: str):
+
+    destination_file_path = os.path.join(tmp_path, str(uuid.uuid4())[:8])
+    # Sync the path across all ranks
+    destination_file_path = dist.all_gather_object(destination_file_path)[0]
+    model, _ = init_model(use_fsdp=True, device='cuda', tensor_type=tensor_type)
+
+    state_dict = get_model_state_dict(model, sharded_state_dict=True)
+    loaded_in_state_dict = deepcopy(state_dict)
+    path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path, overwrite=True)
+    assert path_saved == f'{destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}'
+    assert path_saved is not None
+    load_path = str(Path(path_saved).parent)
+    if version.parse(torch.__version__) < version.parse('2.2.0'):
+        DCP.load_state_dict(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
+    else:
+        DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
+    deep_compare(state_dict, loaded_in_state_dict)
diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py
index bd14154dc9..e010440836 100644
--- a/tests/checkpoint/test_state_dict.py
+++ b/tests/checkpoint/test_state_dict.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import datetime
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
@@ -10,7 +10,6 @@
 import torch.distributed as torch_dist
 from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.optim import adam
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DataLoader
 
@@ -25,6 +24,7 @@
 from composer.core import State
 from composer.devices import DeviceCPU, DeviceGPU
 from composer.utils import dist, reproducibility
+from tests.checkpoint.helpers import init_model_and_optimizer
 from tests.common.compare import deep_compare
 from tests.common.markers import world_size
 from tests.common.models import EvenSimplerMLP, SimpleComposerMLP, configure_tiny_gpt2_hf_model
@@ -247,101 +247,10 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp
         assert tens.dtype == precision
 
 
-def _init_model_and_optimizer(
-    use_composer_model: bool,
-    num_classes=3,
-    batch_size=5,
-    num_features=8,
-    take_step=True,
-    use_fsdp=False,
-    tensor_type='sharded_tensor',
-    device='cuda',
-):
-    model, loss_fn = _init_model(
-        use_composer_model,
-        num_classes=num_classes,
-        batch_size=batch_size,
-        num_features=num_features,
-        use_fsdp=use_fsdp,
-        tensor_type=tensor_type,
-        device=device,
-    )
-
-    optimizer = _init_optimizer(
-        model,
-        loss_fn,
-        use_composer_model=use_composer_model,
-        num_classes=num_classes,
-        batch_size=batch_size,
-        num_features=num_features,
-        take_step=take_step,
-        device=device,
-    )
-
-    return model, optimizer
-
-
-def _init_model(
-    use_composer_model: bool = False,
-    num_classes=3,
-    batch_size=5,
-    num_features=8,
-    use_fsdp=False,
-    device='cuda',
-    tensor_type='sharded_tensor',
-):
-    if use_composer_model:
-        model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device)
-        loss_fn = model._loss_fn
-    else:
-        model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device)
-        loss_fn = torch.nn.CrossEntropyLoss()
-
-    if use_fsdp:
-        fsdp_kwargs: Dict[str, Any] = dict(
-            use_orig_params=True,
-            sync_module_states=True,  # To enable easy comparison between rank 0 unsharded model and full state dict
-        )
-
-        if tensor_type == 'dtensor':
-            from torch.distributed.device_mesh import init_device_mesh
-            device_mesh = init_device_mesh('cuda', (2,))
-            fsdp_kwargs['device_mesh'] = device_mesh
-
-        model = FSDP(
-            model,
-            **fsdp_kwargs,
-        )
-
-    return model, loss_fn
-
-
-def _init_optimizer(
-    model,
-    loss_fn,
-    use_composer_model: bool = False,
-    num_classes=3,
-    batch_size=5,
-    num_features=8,
-    take_step=True,
-    device='cuda',
-):
-    inputs = torch.randn(batch_size, num_features, device=device)
-    targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long)
-    batch = (inputs, targets) if use_composer_model else inputs
-    optimizer = adam.Adam(model.parameters())
-    outputs = model(batch)
-    loss = loss_fn(outputs, targets)
-    loss.backward()
-    if take_step:
-        optimizer.step()
-    return optimizer
-
-
 @pytest.mark.gpu
 @pytest.mark.parametrize('use_composer_model', [True, False])
 def test_get_optim_state_dict_unsharded_model(use_composer_model: bool):
-    model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True)
+    model, optimizer = init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True)
     optim_state_dict = get_optim_state_dict(model, optimizer)
 
     # Dict mapping parameter index to optimizer state for that parameter.
@@ -385,7 +294,7 @@ def test_get_optim_state_dict_unsharded_model(use_composer_model: bool):
 )
 @pytest.mark.parametrize('use_composer_model', [True, False])
 def test_get_optim_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool):
-    model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True)
+    model, optimizer = init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True)
     optim_state_dict = get_optim_state_dict(model, optimizer, precision=precision)
     for param_state in optim_state_dict['state'].values():
         assert param_state['exp_avg'].dtype == precision
@@ -400,7 +309,7 @@ def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_comp
     if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
         pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
 
-    model, optimizer = _init_model_and_optimizer(
+    model, optimizer = init_model_and_optimizer(
         use_composer_model=use_composer_model,
         take_step=True,
         use_fsdp=True,
@@ -427,7 +336,7 @@ def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_c
     if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'):
         pytest.skip('DTensor is only supported in PyTorch >= 2.2.0')
 
-    model, optimizer = _init_model_and_optimizer(
+    model, optimizer = init_model_and_optimizer(
         use_composer_model=use_composer_model,
         take_step=True,
         use_fsdp=True,
@@ -540,7 +449,7 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz
 
 @pytest.mark.filterwarnings('ignore:SWA has')
 def test_get_resumption_state_dict():
-    model, optimizer = _init_model_and_optimizer(use_composer_model=True, take_step=True, device='cpu')
+    model, optimizer = init_model_and_optimizer(use_composer_model=True, take_step=True, device='cpu')
 
     rank_zero_seed = 10
     run_name = 'test_run'
@@ -605,7 +514,7 @@ def test_get_resumption_state_dict_gpu():
     else:
         from torch.cuda.amp.grad_scaler import GradScaler
 
-    model, _ = _init_model_and_optimizer(use_composer_model=True, take_step=False, device='cuda')
+    model, _ = init_model_and_optimizer(use_composer_model=True, take_step=False, device='cuda')
 
     rank_zero_seed = 10
     run_name = 'test_run'
diff --git a/tests/common/compare.py b/tests/common/compare.py
index 432ac55dfd..79dfe573bb 100644
--- a/tests/common/compare.py
+++ b/tests/common/compare.py
@@ -7,6 +7,8 @@
 import numpy as np
 import torch
 import torchmetrics
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._tensor import DTensor
 
 from composer import Time
 from composer.core.time import TimeUnit
@@ -39,7 +41,7 @@ def _check_item(
         assert type(item1) == type(item2)
         assert item1 == item2, f'{path} differs: {item1} != {item2}'
         return
-    if isinstance(item1, torch.Tensor):
+    if isinstance(item1, torch.Tensor) and not (isinstance(item1, ShardedTensor) or isinstance(item1, DTensor)):
         assert isinstance(item2, torch.Tensor)
         if item1.device != item2.device:
             item1 = item1.cpu()
@@ -58,6 +60,16 @@ def _check_item(
         assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}'
         _check_list_recursively(item1, item2, path, atol=atol, rtol=rtol)
         return
+    if isinstance(item1, ShardedTensor):
+        assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}'
+        _check_sharded_tensor_recursively(item1, item2, path, atol=atol, rtol=rtol)
+        return
+
+    if isinstance(item1, DTensor):
+        assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}'
+        _check_dtensor_recursively(item1, item2, path, atol=atol, rtol=rtol)
+        return
+
     if isinstance(item1, torchmetrics.Metric):
         assert isinstance(item2, torchmetrics.Metric), f'{path} differs: {item1} != {item2}'
         # Increase update count so Torchmetrics doesn't throw warning when computing two metrics which haven't been updated
@@ -84,6 +96,28 @@ def _check_item(
     raise NotImplementedError(f'Unsupported item type: {type(item1)}')
 
 
+def _check_dtensor_recursively(
+    dtensor1: DTensor,
+    dtensor2: DTensor,
+    path: str,
+    atol: float,
+    rtol: float,
+):
+    tensor1, tensor2 = dtensor1.to_local(), dtensor2.to_local()
+    _check_item(tensor1, tensor2, path, atol=atol, rtol=rtol)
+
+
+def _check_sharded_tensor_recursively(
+    sharded_tensor1: ShardedTensor,
+    sharded_tensor2: ShardedTensor,
+    path: str,
+    atol: float,
+    rtol: float,
+):
+    tensor1, tensor2 = sharded_tensor1.local_tensor(), sharded_tensor2.local_tensor()
+    _check_item(tensor1, tensor2, path, atol=atol, rtol=rtol)
+
+
 def _check_list_recursively(
     list1: Union[tuple[Any], list[Any]],
     list2: Union[tuple[Any], list[Any]],