Skip to content

Commit

Permalink
Merge branch 'main' into mvpatel2000/remove-legacy
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Sep 23, 2024
2 parents ec84fbb + 17304a0 commit 8729c12
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 44 deletions.
6 changes: 1 addition & 5 deletions composer/trainer/_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,8 +945,7 @@ def unshard_with_sync(self):

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
) < version.parse('2.4.1'):
# 2.4.0 only patch
) < version.parse('2.4.2'):
# PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from typing import Mapping, Collection
Expand Down Expand Up @@ -1003,9 +1002,6 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
for key, value in state_dict.items():
_traverse_obj((str(key),), value)

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
) < version.parse('2.4.2'):
# Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
from torch.distributed.fsdp._flat_param import FlatParamHandle
original_unshard = FlatParamHandle.unshard
Expand Down
69 changes: 30 additions & 39 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,50 +596,41 @@ def dist_cp_load(
load_planner: Optional[LoadPlanner] = None,
):
if version.parse(torch.__version__) >= version.parse('2.4.0'):
if version.parse(torch.__version__) < version.parse('2.4.1'):
# PyTorch 2.4.0
from torch.distributed.checkpoint.utils import CheckpointException
try:
dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
except CheckpointException as e:
checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
# Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
# Torch issue: https://github.com/pytorch/pytorch/issues/133923.
# We override the traverse_state_dict so that the load planner could
# use the old way of flattening the state dict
log.debug('Trying to load checkpointing saved before torch 2.4')

import torch.distributed.checkpoint._nested_dict as nested_dict
import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0

from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse

nested_dict.traverse_state_dict = backward_compatible_traverse
sharded_tensor_util.traverse_state_dict = backward_compatible_traverse

dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
# Revert the override
nested_dict.traverse_state_dict = traverse_2_4_0
sharded_tensor_util.traverse_state_dict = traverse_2_4_0
else:
raise e
else:
# PyTorch 2.4.1
from torch.distributed.checkpoint.utils import CheckpointException
try:
dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
except CheckpointException as e:
checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
# Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
# Torch issue: https://github.com/pytorch/pytorch/issues/133923.
# We override the traverse_state_dict so that the load planner could
# use the old way of flattening the state dict
log.debug('Trying to load checkpointing saved before torch 2.4')

import torch.distributed.checkpoint._nested_dict as nested_dict
import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0

from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse

nested_dict.traverse_state_dict = backward_compatible_traverse
sharded_tensor_util.traverse_state_dict = backward_compatible_traverse

dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
# Revert the override
nested_dict.traverse_state_dict = traverse_2_4_0
sharded_tensor_util.traverse_state_dict = traverse_2_4_0
else:
raise e
else:
dist_cp.load_state_dict(
state_dict=state_dict,
Expand Down

0 comments on commit 8729c12

Please sign in to comment.