End-to-end test for saving/loading a module with DTensor
Summary: Add a test that takes and restores a snapshot of an HSDP-sharded model, which uses DTensors that are partially replicated (placements = [Replicate(), Shard(0)]).
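
For context, here is a minimal sketch (not part of this commit) of what the [Replicate(), Shard(0)] placement means on a 2 x 2 device mesh: the tensor is replicated across mesh dim 0 and sharded along its first tensor dim across mesh dim 1. The snippet is illustrative only and assumes a 4-rank process group is already initialized.

# Hedged sketch: give a plain tensor the same placements that the HSDP-sharded
# parameters end up with ([Replicate(), Shard(0)]).
import torch
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed._tensor.device_mesh import init_device_mesh

torch.manual_seed(0)  # same "global" tensor on every rank
mesh_2d = init_device_mesh("cuda", (2, 2))
full = torch.randn(16, 8, device="cuda")
dtensor = distribute_tensor(full, mesh_2d, placements=[Replicate(), Shard(0)])
# Each rank holds an (8, 8) local shard; ranks that differ only along mesh
# dim 0 hold identical data.
print(dtensor.placements, dtensor.to_local().shape)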

Reviewed By: ananthsub

Differential Revision: D49470645

fbshipit-source-id: 5ec5b8af6eb04eb60ff25cd1707e6be0d577472e
Rafi Ayub authored and facebook-github-bot committed Oct 26, 2023
1 parent 8bce08e commit 430641f
120 changes: 120 additions & 0 deletions tests/gpu_tests/test_snapshot_dtensor.py
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import uuid
from typing import Optional

import torch
from torch import distributed as dist, nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.api import (
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torchsnapshot import Snapshot
from torchsnapshot.test_utils import check_state_dict_eq
from torchsnapshot.tricks.fsdp import FSDPOptimizerAdapter

logger: logging.Logger = logging.getLogger(__name__)


WORLD_SIZE: int = 4


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(4, 8, device="cuda")


# TODO: Test different world sizes (may require not using DTensorTestBase)
# TODO: Test FSDP + TP once dim_map is updated for [Shard(0), Shard(0)] cases
class TestSnapshotWithDTensor(DTensorTestBase):
    def _create_model(
        self, seed: int, optim_lr: float, device_mesh: Optional[DeviceMesh] = None
    ):
        torch.manual_seed(seed)
        # Use an HSDP model as an example model that uses DTensor.
        # This should create a model with placements
        # [Replicate(), Shard(0)].
        if device_mesh:
            model = FSDP(
                DummyModel().cuda(),
                device_mesh=device_mesh,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        else:
            mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2))
            intra_node_pg = mesh_2d.get_dim_groups(mesh_dim=1)
            inter_node_pg = mesh_2d.get_dim_groups(mesh_dim=0)
            model = FSDP(
                DummyModel().cuda(),
                process_group=(intra_node_pg, inter_node_pg),
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(),
            optim_state_dict_config=ShardedOptimStateDictConfig(),
        )

        # Need to step and zero_grad in order to initialize all the optimizer state
        optim = torch.optim.Adam(model.parameters(), lr=optim_lr)
        optim.step(closure=None)
        optim.zero_grad(set_to_none=True)

        optim = FSDPOptimizerAdapter(model, optim)

        return model, optim

    @with_comms
    @skip_if_lt_x_gpu(WORLD_SIZE)
    def test_save_and_load_same_world_size(self):
        mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2))
        src_model, src_optim = self._create_model(
            seed=42, optim_lr=0.1, device_mesh=mesh_2d
        )
        dst_model, dst_optim = self._create_model(
            seed=24, optim_lr=0.2, device_mesh=mesh_2d
        )
        assert not check_state_dict_eq(src_model.state_dict(), dst_model.state_dict())
        assert not check_state_dict_eq(src_optim.state_dict(), dst_optim.state_dict())

        tmp_path = f"/tmp/{uuid.uuid4()}"
        if dist.get_rank() == 0:
            logger.info(f"Saving to {tmp_path}")

        snapshot = Snapshot.take(
            str(tmp_path), {"model": src_model, "optim": src_optim}
        )
        snapshot.restore({"model": dst_model, "optim": dst_optim})
        logger.info(f"{dst_model.state_dict()}")
        assert check_state_dict_eq(dst_model.state_dict(), src_model.state_dict())
        assert check_state_dict_eq(dst_optim.state_dict(), src_optim.state_dict())
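
As a possible follow-up (not part of the committed test), one could spot-check the restored DTensor layout right after snapshot.restore; below is a hedged sketch that reuses the module-level logger and assumes the same HSDP setup that _create_model builds.

def _log_dtensor_layout(model: nn.Module) -> None:
    # Hypothetical helper, not in this commit. Because _create_model already
    # selected SHARDED_STATE_DICT, state_dict() yields DTensor values whose
    # placements should be (Replicate(), Shard(dim=0)) on the 2 x 2 mesh.
    from torch.distributed._tensor import DTensor

    for name, value in model.state_dict().items():
        if isinstance(value, DTensor):
            logger.info(
                f"{name}: placements={value.placements}, "
                f"local shape={tuple(value.to_local().shape)}"
            )

For example, calling _log_dtensor_layout(dst_model) at the end of test_save_and_load_same_world_size would log one line per sharded parameter on each rank.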
