Skip to content

Commit

Permalink
Make sharded checkpoint loading backwards-compatible (mosaicml#3240)
Browse files Browse the repository at this point in the history
* backwards compatible

* fix fsdp old tests

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
snarayan21 and dakinggg authored May 1, 2024
1 parent 86b0083 commit 5eddaf3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
20 changes: 14 additions & 6 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,13 +620,21 @@ def load_sharded_checkpoint(

# We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph.
with torch.no_grad():
# 1. Load model and metadata first
# 1. Load metadata first for backwards compatability check
# We need to check if the "optimizers" is at the root of the state dict to determine
# how to load the optimizer state.
metadata = storage_reader.read_metadata()
# Retrieve all top-level keys of the metadata.
top_level_keys = [v[0] for v in metadata.planner_data.values()]
optimizers_at_root = 'optimizers' in top_level_keys

# 2. Load model and metadata
if load_weights_only:
state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
# For older versions of torch, we load optimizer separately.
if version.parse(torch.__version__) < version.parse('2.2.3'):
# If 'optimizers' is at root-level, we load it separately.
if optimizers_at_root:
cur_state_dict.pop('optimizers')
num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
state_dict: Dict[str, Any] = {
Expand Down Expand Up @@ -661,9 +669,9 @@ def load_sharded_checkpoint(
algorithm_passes=algorithm_passes,
)

# 2. Optionally load optimizer
# if we are using later than 2.2.3 then optimizer will already be loaded
if version.parse(torch.__version__) < version.parse('2.2.3') and not load_weights_only:
# 3. Optionally load optimizer
# If 'optimizers' was not at root-level, then it will already be loaded
if optimizers_at_root and not load_weights_only:
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=state.state_dict()['model'],
optimizer_key='optimizers',
Expand Down
12 changes: 9 additions & 3 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,6 @@ def test_fsdp_load_old_checkpoint(
'state': trainer2.state.state_dict(),
'rng': get_rng_state(),
}
if version.parse(torch.__version__) < version.parse('2.2.3'):
state_dict['state'].pop('optimizers')

object_store = S3ObjectStore(bucket=f'{s3_bucket}')
storage_reader = DistCPObjectStoreReader(
Expand All @@ -538,14 +536,22 @@ def test_fsdp_load_old_checkpoint(
device_mesh=None,
)

# Load metadata first, and check if 'optimizers' is a top-level key. Pop if it is.
metadata = storage_reader.read_metadata()
# Retrieve all top-level keys of the metadata.
top_level_keys = [v[0] for v in metadata.planner_data.values()]
optimizers_at_root = 'optimizers' in top_level_keys
if optimizers_at_root:
state_dict['state'].pop('optimizers')

process_group = None
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=None,
process_group=process_group,
)
if version.parse(torch.__version__) < version.parse('2.2.3'):
if optimizers_at_root:
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model_state_dict = state_dict['state']['model']
Expand Down

0 comments on commit 5eddaf3

Please sign in to comment.