End-to-end test for saving/loading a module with DTensor
Summary: Add a test that takes and restores a snapshot of an HSDP-sharded model, which uses DTensors that are partially replicated (placements = [Replicate(), Shard(0)]).

Reviewed By: ananthsub
Differential Revision: D49470645
fbshipit-source-id: 5ec5b8af6eb04eb60ff25cd1707e6be0d577472e
1 parent 8bce08e · commit 430641f
Showing 1 changed file with 120 additions and 0 deletions.
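For context on the placements mentioned in the summary, the sketch below (not part of this commit) shows how an HSDP-wrapped module surfaces its parameters as DTensors placed as [Replicate(), Shard(0)] once a sharded state dict is requested. It assumes a 4-rank job with the default process group initialized and the current CUDA device set per rank; the mesh shape and FSDP settings mirror the test code in the diff, while the single Linear layer is just a stand-in model.

import torch
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.api import ShardedStateDictConfig

# 2x2 mesh over 4 ranks: the outer dim holds replicas, the inner dim shards.
mesh_2d = init_device_mesh("cuda", (2, 2))
model = FSDP(
    torch.nn.Linear(8, 8).cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
FSDP.set_state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    state_dict_config=ShardedStateDictConfig(),
)

# Each parameter in the sharded state dict is a DTensor that is replicated
# across the outer mesh dimension and sharded on dim 0 across the inner one.
for value in model.state_dict().values():
    if isinstance(value, DTensor):
        assert tuple(value.placements) == (Replicate(), Shard(0))

This replicate-then-shard combination is exactly what the test in the diff snapshots and restores.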
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import uuid
from typing import Optional

import torch
from torch import distributed as dist, nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.api import (
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torchsnapshot import Snapshot
from torchsnapshot.test_utils import check_state_dict_eq
from torchsnapshot.tricks.fsdp import FSDPOptimizerAdapter

logger: logging.Logger = logging.getLogger(__name__)


WORLD_SIZE: int = 4


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(4, 8, device="cuda")


# TODO: Test different world sizes (may require not using DTensorTestBase)
# TODO: Test FSDP + TP once dim_map is updated for [Shard(0), Shard(0)] cases
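# Test harness notes: DTensorTestBase runs the test body in spawned worker
# processes (world size 4 by default, matching WORLD_SIZE above), @with_comms
# sets up and tears down the default process group around each test, and
# @skip_if_lt_x_gpu skips the test on hosts with fewer GPUs than required.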
class TestSnapshotWithDTensor(DTensorTestBase):
    def _create_model(
        self, seed: int, optim_lr: float, device_mesh: Optional[DeviceMesh] = None
    ):
        torch.manual_seed(seed)
        # Using HSDP model as an example model that uses DTensor
        # This should create model with placements
        # [Replicate(), Shard(0)]
        if device_mesh:
            model = FSDP(
                DummyModel().cuda(),
                device_mesh=device_mesh,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        else:
            mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2))
            intra_node_pg = mesh_2d.get_dim_groups(mesh_dim=1)
            inter_node_pg = mesh_2d.get_dim_groups(mesh_dim=0)
            model = FSDP(
                DummyModel().cuda(),
                process_group=(intra_node_pg, inter_node_pg),
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(),
            optim_state_dict_config=ShardedOptimStateDictConfig(),
        )

        # Need to step and zero_grad in order to initialize all the optimizer parameters
        optim = torch.optim.Adam(model.parameters(), lr=optim_lr)
        optim.step(closure=None)
        optim.zero_grad(set_to_none=True)

        optim = FSDPOptimizerAdapter(model, optim)

        return model, optim

    @with_comms
    @skip_if_lt_x_gpu(WORLD_SIZE)
    def test_save_and_load_same_world_size(self):
        mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2))
        src_model, src_optim = self._create_model(
            seed=42, optim_lr=0.1, device_mesh=mesh_2d
        )
        dst_model, dst_optim = self._create_model(
            seed=24, optim_lr=0.2, device_mesh=mesh_2d
        )
        assert not check_state_dict_eq(src_model.state_dict(), dst_model.state_dict())
        assert not check_state_dict_eq(src_optim.state_dict(), dst_optim.state_dict())

        tmp_path = f"/tmp/{uuid.uuid4()}"
        if dist.get_rank() == 0:
            logger.info(f"Saving to {tmp_path}")

        snapshot = Snapshot.take(
            str(tmp_path), {"model": src_model, "optim": src_optim}
        )
        snapshot.restore({"model": dst_model, "optim": dst_optim})
        logger.info(f"{dst_model.state_dict()}")
        assert check_state_dict_eq(dst_model.state_dict(), src_model.state_dict())
        assert check_state_dict_eq(dst_optim.state_dict(), src_optim.state_dict())
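Stripped of the test harness, the save/restore pattern being exercised reduces to a few calls. A minimal sketch, with `model` standing in for an FSDP-wrapped module already set to SHARDED_STATE_DICT, `optim` for its optimizer, and an illustrative snapshot path (none of these names or the path come from the diff):

from torchsnapshot import Snapshot
from torchsnapshot.tricks.fsdp import FSDPOptimizerAdapter

# Same wrapping the test applies in _create_model: the adapter routes the
# optimizer's state dict through FSDP-aware handling.
app_state = {"model": model, "optim": FSDPOptimizerAdapter(model, optim)}

# Both calls are collective: every rank must call them with the same keys.
snapshot = Snapshot.take("/tmp/example-snapshot", app_state)
snapshot.restore(app_state)

As in the test, restore loads in place into already-constructed, identically-sharded objects, which is what lets each rank receive the DTensor shards it owns.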