diff --git a/tests/gpu_tests/test_state_dict_fsdp.py b/tests/gpu_tests/test_state_dict_fsdp.py
new file mode 100644
index 0000000..7478b43
--- /dev/null
+++ b/tests/gpu_tests/test_state_dict_fsdp.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from pathlib import Path
+
+import pytest
+
+import torch
+import torch.distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
+from torchsnapshot import Snapshot
+from torchsnapshot.test_utils import check_state_dict_eq, run_with_pet
+
+
+def _create_fsdp_model(
+    seed: int,
+    device: torch.device,
+) -> torch.nn.Module:
+    torch.manual_seed(seed)
+    model = torch.nn.Linear(32, 32)
+
+    fsdp_model = FSDP(
+        module=model,
+        device_id=device,
+    )
+    FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
+    return fsdp_model
+
+
+@pytest.mark.skipif(
+    bool(not torch.cuda.is_available()), reason="The test requires GPUs to run."
+)
+@pytest.mark.skipif(
+    bool(torch.cuda.device_count() < 2), reason="At least two GPUs are required."
+)
+@run_with_pet(nproc=2)
+def test_model_fsdp(tmp_path: Path) -> None:
+    dist.init_process_group(backend="nccl")
+    local_rank = int(os.environ["LOCAL_RANK"])
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+
+    fsdp_model = _create_fsdp_model(17, device)
+
+    snapshot = Snapshot.take(
+        path=str(tmp_path),
+        app_state={"fsdp_model": fsdp_model},
+    )
+    state_dict_from_method = snapshot.get_state_dict_for_key("fsdp_model")
+    FSDP.set_state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT)
+
+    full_state_dict = fsdp_model.state_dict()
+    for k, v in full_state_dict.items():
+        full_state_dict[k] = v.cpu()
+
+    assert check_state_dict_eq(full_state_dict, state_dict_from_method)
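
Note on the test above: the snapshot is taken while the FSDP model is in
SHARDED_STATE_DICT mode, and the output of get_state_dict_for_key() is then
compared against the FULL_STATE_DICT result, so the test verifies that the
method inflates the saved shards back into whole tensors. For reviewers, a
minimal sketch of the state-dict-type switch the test relies on (illustrative
only; assumes an initialized NCCL process group and a CUDA device, which the
test obtains via run_with_pet):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

    # Assumes dist.init_process_group("nccl") has already run on every rank.
    fsdp_model = FSDP(torch.nn.Linear(32, 32).cuda())

    # Sharded mode: state_dict() returns each rank's local shards.
    FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
    sharded_sd = fsdp_model.state_dict()

    # Full mode: state_dict() gathers parameters so each rank sees
    # the complete, unsharded tensors.
    FSDP.set_state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT)
    full_sd = fsdp_model.state_dict()
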
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py
new file mode 100644
index 0000000..3c44e52
--- /dev/null
+++ b/tests/test_state_dict.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+from typing import cast, Dict
+
+import torch
+import torchsnapshot
+from torchsnapshot import Stateful
+
+
+class MyModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.foo = torch.nn.Parameter(torch.randn(20, 20))
+
+
+class MyStateful(Stateful):
+    def __init__(self) -> None:
+        self.foo = 1
+        self.bar = "bar"
+
+    def state_dict(self) -> Dict[str, object]:
+        return {"foo": self.foo, "bar": self.bar}
+
+    def load_state_dict(self, state_dict: Dict[str, object]) -> None:
+        self.foo = cast(int, state_dict["foo"])
+        self.bar = cast(str, state_dict["bar"])
+
+
+class StateDictTest(unittest.TestCase):
+    def test_get_state_dict(self) -> None:
+        my_module = MyModule()
+        with tempfile.TemporaryDirectory() as path:
+            torchsnapshot.Snapshot.take(
+                path=path,
+                app_state={"my_module": my_module},
+            )
+            snapshot = torchsnapshot.Snapshot(path)
+            state_dict = snapshot.get_state_dict_for_key("my_module")
+            self.assertTrue(torch.allclose(state_dict["foo"], my_module.foo))
+
+    def test_get_state_dict_with_invalid_key(self) -> None:
+        my_module = MyModule()
+        with tempfile.TemporaryDirectory() as path:
+            torchsnapshot.Snapshot.take(
+                path=path,
+                app_state={"my_module": my_module},
+            )
+            snapshot = torchsnapshot.Snapshot(path)
+            with self.assertRaisesRegex(
+                AssertionError, "is absent in both manifest and flattened"
+            ):
+                snapshot.get_state_dict_for_key("invalid_key")
+
+    def test_generic_stateful(self) -> None:
+        my_stateful = MyStateful()
+        my_stateful.foo = 2
+        my_stateful.bar = "baz"
+        with tempfile.TemporaryDirectory() as path:
+            app_state = {"my_stateful": my_stateful}
+            snapshot = torchsnapshot.Snapshot.take(path=path, app_state=app_state)
+            state_dict = snapshot.get_state_dict_for_key("my_stateful")
+            self.assertDictEqual(state_dict, my_stateful.state_dict())
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index 2f281c3..8d53ae9 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -680,6 +680,45 @@ def _validate_app_state(app_state: AppState) -> None:
                     f"Expected Stateful in app_state for key {key}, got {value_type}."
                 )
 
+    # pyre-fixme: inflate returns Dict[Any, Any]
+    # Missing return annotation [3]: Return type must be specified as a type that does not contain `Any`
+    def get_state_dict_for_key(self, key: str) -> Dict[Any, Any]:
+        """
+        Gets the state dict for a selected key in the snapshot.
+        This is useful for accessing a state dict without loading it into a stateful object.
+
+        Args:
+            key (str): The key to get the state dict for. Assumes the key was stored as a
+                top-level key in the snapshot.
+
+        Returns:
+            The state dict associated with the key.
+
+        Below is a usage example:
+
+        .. code-block:: python
+
+            snapshot = Snapshot.take(path=..., app_state={"stateful_key": module})
+            module_state_dict = snapshot.get_state_dict_for_key("stateful_key")
+        """
+        event_loop = asyncio.new_event_loop()
+        pg = PGWrapper(self.pg)
+
+        manifest, _ = get_manifest_for_rank(metadata=self.metadata, rank=pg.get_rank())
+
+        # Filter out entries that do not belong to the requested key.
+        manifest = {k: v for k, v in manifest.items() if k.split("/")[0] == key}
+
+        storage = url_to_storage_plugin_in_event_loop(
+            url_path=self.path,
+            event_loop=event_loop,
+            storage_options=self._storage_options,
+        )
+
+        return self._get_state_dict_for_manifest(
+            key, manifest, {}, pg, storage, event_loop
+        )
+
     def _load_stateful(  # noqa
         self,
         stateful_key: str,
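
For context, a minimal end-to-end sketch of the API this change adds
(illustrative only; the module, key, and temporary path are placeholders):

    import tempfile
    import torch
    import torchsnapshot

    my_model = torch.nn.Linear(4, 4)
    with tempfile.TemporaryDirectory() as path:
        # Register the module under a top-level key when taking the snapshot.
        snapshot = torchsnapshot.Snapshot.take(
            path=path, app_state={"my_model": my_model}
        )
        # Read the state dict back without restoring it into a live module.
        state_dict = snapshot.get_state_dict_for_key("my_model")
        assert torch.allclose(state_dict["weight"], my_model.weight)

One thing to double-check in get_state_dict_for_key: the event loop and
storage plugin it creates are not visibly closed before the method returns,
so unless _get_state_dict_for_manifest cleans them up, repeated calls could
leak resources.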