
add method to get state dict
Reviewed By: JKSenthil

Differential Revision: D54447983

fbshipit-source-id: b458639aab4bdf2825865304eda6a06d70600393
galrotem authored and facebook-github-bot committed Mar 6, 2024
1 parent e0184bf commit 6bd9dc6
Showing 3 changed files with 169 additions and 0 deletions.
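For readers skimming the diff, the new Snapshot.get_state_dict_for_key API added in this commit can be exercised roughly as in the sketch below. This is a minimal illustration based on the docstring and tests in the diff; the Linear module and the /tmp path are placeholders, not part of the change itself.

import torch
import torchsnapshot

# Any nn.Module (or other Stateful) registered under a top-level app_state key.
module = torch.nn.Linear(8, 8)

# Write a snapshot with the module stored under the key "model".
snapshot = torchsnapshot.Snapshot.take(
    path="/tmp/my_snapshot",  # placeholder path
    app_state={"model": module},
)

# New in this commit: read the state dict back directly from storage,
# without instantiating a module and calling restore() on it.
state_dict = snapshot.get_state_dict_for_key("model")
print(sorted(state_dict.keys()))  # e.g. ['bias', 'weight']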
61 changes: 61 additions & 0 deletions tests/gpu_tests/test_state_dict_fsdp.py
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path

import pytest

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torchsnapshot import Snapshot
from torchsnapshot.test_utils import check_state_dict_eq, run_with_pet


def _create_fsdp_model(
    seed: int,
    device: torch.device,
) -> torch.nn.Module:
    torch.manual_seed(seed)
    model = torch.nn.Linear(32, 32)

    fsdp_model = FSDP(
        module=model,
        device_id=device,
    )
    FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
    return fsdp_model


@pytest.mark.skipif(
    bool(not torch.cuda.is_available()), reason="The test requires GPUs to run."
)
@pytest.mark.skipif(
    bool(torch.cuda.device_count() < 2), reason="At least two GPUs are required."
)
@run_with_pet(nproc=2)
def test_model_and_optim_fsdp(tmp_path: Path) -> None:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    fsdp_model = _create_fsdp_model(17, device)

    snapshot = Snapshot.take(
        path=str(tmp_path),
        app_state={"fsdp_model": fsdp_model},
    )
    state_dict_from_method = snapshot.get_state_dict_for_key("fsdp_model")
    FSDP.set_state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT)

    full_state_dict = fsdp_model.state_dict()
    for k, v in full_state_dict.items():
        full_state_dict[k] = v.cpu()

    assert check_state_dict_eq(full_state_dict, state_dict_from_method)
69 changes: 69 additions & 0 deletions tests/test_state_dict.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from typing import cast, Dict

import torch
import torchsnapshot
from torchsnapshot import Stateful


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.foo = torch.nn.Parameter(torch.randn(20, 20))


class MyStateful(Stateful):
    def __init__(self) -> None:
        self.foo = 1
        self.bar = "bar"

    def state_dict(self) -> Dict[str, object]:
        return {"foo": self.foo, "bar": self.bar}

    def load_state_dict(self, state_dict: Dict[str, object]) -> None:
        self.foo = cast(int, state_dict["foo"])
        self.bar = cast(str, state_dict["bar"])


class StateDictTest(unittest.TestCase):
    def test_get_state_dict(self) -> None:
        my_module = MyModule()
        with tempfile.TemporaryDirectory() as path:
            torchsnapshot.Snapshot.take(
                path=path,
                app_state={"my_module": my_module},
            )
            snapshot = torchsnapshot.Snapshot(path)
            state_dict = snapshot.get_state_dict_for_key("my_module")
            self.assertTrue(torch.allclose(state_dict["foo"], my_module.foo))

    def test_get_state_dict_with_invalid_key(self) -> None:
        my_module = MyModule()
        with tempfile.TemporaryDirectory() as path:
            torchsnapshot.Snapshot.take(
                path=path,
                app_state={"my_module": my_module},
            )
            snapshot = torchsnapshot.Snapshot(path)
            with self.assertRaisesRegex(
                AssertionError, "is absent in both manifest and flattened"
            ):
                snapshot.get_state_dict_for_key("invalid_key")

    def test_generic_stateful(self) -> None:
        my_stateful = MyStateful()
        my_stateful.foo = 2
        my_stateful.bar = "baz"
        with tempfile.TemporaryDirectory() as path:
            snapshot = torchsnapshot.Snapshot(path)
            snapshot.take(path, app_state={"my_stateful": my_stateful})
            state_dict = snapshot.get_state_dict_for_key("my_stateful")
            self.assertDictEqual(state_dict, my_stateful.state_dict())
39 changes: 39 additions & 0 deletions torchsnapshot/snapshot.py
@@ -680,6 +680,45 @@ def _validate_app_state(app_state: AppState) -> None:
f"Expected Stateful in app_state for key {key}, got {value_type}."
)

# pyre-fixme: inflate returns Dict[Any,Any]
# Missing return annotation [3]: Return type must be specified as type that does not contain `Any`
def get_state_dict_for_key(self, key: str) -> Dict[Any, Any]:
"""
Gets the state dict for a selected key in the snapshot.
This is useful in case you want to get the state dict without loading it to the stateful.
Args:
key (str): The key to get the state dict for. Assumes the key was stored as a topline
key in the snapshot.
Returns:
The state dict associated with the key.
Below is a usage example
.. code-block:: python
snapshot = Snapshot.take(path=..., app_state={"stateful_key": module})
module_state_dict = snapshot.get_state_dict_for_key("stateful_key")
"""
event_loop = asyncio.new_event_loop()
pg = PGWrapper(self.pg)

manifest, _ = get_manifest_for_rank(metadata=self.metadata, rank=pg.get_rank())

# filter out irrelevant entries from the manifest
manifest = {k: v for k, v in manifest.items() if k.split("/")[0] == key}

storage = url_to_storage_plugin_in_event_loop(
url_path=self.path,
event_loop=event_loop,
storage_options=self._storage_options,
)

return self._get_state_dict_for_manifest(
key, manifest, {}, pg, storage, event_loop
)

def _load_stateful( # noqa
self,
stateful_key: str,
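A note on how the new method works: entries in a snapshot manifest are path-like strings whose first component is the top-level app_state key, so restricting the manifest to one stateful reduces to a prefix match on that component. Below is a minimal sketch with hypothetical manifest keys (the real values are manifest entry objects, not strings).

# Hypothetical manifest keys, mimicking what Snapshot.take writes: the
# first path component is the top-level app_state key.
manifest = {
    "model/weight": "entry-0",
    "model/bias": "entry-1",
    "optim/state": "entry-2",
}

key = "model"
# The same filter used by get_state_dict_for_key above.
filtered = {k: v for k, v in manifest.items() if k.split("/")[0] == key}
assert set(filtered) == {"model/weight", "model/bias"}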
