From 5ac2bde5a738daa067b9f7e262fcd69323dcf17a Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Thu, 9 Jan 2025 08:55:57 +0000
Subject: [PATCH] wip add ckpt shampoo

---
 src/zeroband/checkpoint.py        | 25 ++++++++++++++++---------
 src/zeroband/utils/__init__.py    |  4 ++++
 tests/test_torchrun/test_train.py |  9 ++++++---
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py
index b240c4b5..4acd9da3 100644
--- a/src/zeroband/checkpoint.py
+++ b/src/zeroband/checkpoint.py
@@ -38,6 +38,7 @@
     send_state_dict,
     send_tensor_and_state_dict,
 )
+from distributed_shampoo import DistributedShampoo
 
 from zeroband.utils.world_info import get_world_info
 
@@ -80,17 +81,23 @@ def __init__(
         self.optim = optim
 
     def state_dict(self) -> dict[str, Any]:
-        return get_optimizer_state_dict(
-            model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
-        )
+        if isinstance(self.optim, DistributedShampoo):
+            return self.optim.distributed_state_dict(key_to_param=self.model.named_parameters())
+        else:
+            return get_optimizer_state_dict(
+                model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
+            )
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        set_optimizer_state_dict(
-            model=self.model,
-            optimizers=self.optim,
-            optim_state_dict=state_dict,
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
+        if isinstance(self.optim, DistributedShampoo):
+            self.optim.load_distributed_state_dict(state_dict, key_to_param=self.model.named_parameters())
+        else:
+            set_optimizer_state_dict(
+                model=self.model,
+                optimizers=self.optim,
+                optim_state_dict=state_dict,
+                options=StateDictOptions(flatten_optimizer_state_dict=True),
+            )
 
 
 def cast_dtensor_to_tensor(state_dict: dict[str, Any]) -> dict[str, Any]:
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index c0ea3699..f6b1c915 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -4,6 +4,7 @@
 import torch
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed._tensor.api import DTensor
+from distributed_shampoo import DistributedShampoo
 
 __all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]
 
@@ -138,6 +139,9 @@ def get_optimizer_signature(optimizer: torch.optim.Optimizer, compress: bool = T
     Get the optimizer signature
     """
 
+    if isinstance(optimizer, DistributedShampoo):
+        return "mocked signature because shampoo does not support state_dict()"
+
     def unwrap_tensor(state_dict: dict) -> dict:
         new_dict = {}
         for key, value in state_dict.items():
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index 4398c613..7b33a620 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -127,7 +127,8 @@ def test_soap(diloco: bool):
     )
 
 
-def test_ckpt(tmp_path: Path):
+@pytest.mark.parametrize("soap", [False, True])
+def test_ckpt(tmp_path: Path, soap: bool):
     num_gpus = [1, 2]
     v1_file = tmp_path / "v1.log"
     v2_file = tmp_path / "v2.log"
@@ -157,7 +158,8 @@ def test_ckpt(tmp_path: Path):
             "--no-train.sequence_packing",
             "--train.attn_fn",
             "math",
-        ],
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     _test_multi_gpu(
@@ -178,7 +180,8 @@ def test_ckpt(tmp_path: Path):
             "--no-train.sequence_packing",
             "--train.attn_fn",
             "math",
-        ],
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     # 
     _test_multi_gpu(
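
Reviewer note (not part of the patch): the sketch below mirrors the branching that the patched
OptimizerWrapper.state_dict / load_state_dict now perform, so the Shampoo path can be exercised in
isolation. It is a minimal single-process sketch under a few assumptions: the distributed_shampoo
package is installed and works without a process group, get_optimizer_state_dict /
set_optimizer_state_dict / StateDictOptions come from torch.distributed.checkpoint.state_dict (as in
zeroband.checkpoint), and the toy nn.Linear model, hyperparameters, and /tmp path are purely
illustrative, with torch.save / torch.load standing in for the project's real checkpoint I/O. Only
distributed_state_dict / load_distributed_state_dict are taken verbatim from the diff.

import torch
import torch.nn as nn
from distributed_shampoo import DistributedShampoo
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)


def optim_state_dict(model: nn.Module, optim: torch.optim.Optimizer) -> dict:
    # Same branch as the patched OptimizerWrapper.state_dict:
    # DistributedShampoo exposes its own state-dict API keyed by parameter name.
    if isinstance(optim, DistributedShampoo):
        return optim.distributed_state_dict(key_to_param=model.named_parameters())
    return get_optimizer_state_dict(
        model=model, optimizers=optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
    )


def load_optim_state_dict(model: nn.Module, optim: torch.optim.Optimizer, state: dict) -> None:
    # Same branch as the patched OptimizerWrapper.load_state_dict.
    if isinstance(optim, DistributedShampoo):
        optim.load_distributed_state_dict(state, key_to_param=model.named_parameters())
    else:
        set_optimizer_state_dict(
            model=model,
            optimizers=optim,
            optim_state_dict=state,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )


if __name__ == "__main__":
    # Toy setup; constructor arguments are illustrative only.
    model = nn.Linear(8, 8)
    optim = DistributedShampoo(model.parameters(), lr=1e-3, precondition_frequency=1)

    # One dummy step so the optimizer has state worth checkpointing.
    model(torch.randn(4, 8)).sum().backward()
    optim.step()

    # Save, then resume into a fresh model/optimizer pair (hypothetical path).
    torch.save(optim_state_dict(model, optim), "/tmp/shampoo_optim.pt")
    model2 = nn.Linear(8, 8)
    optim2 = DistributedShampoo(model2.parameters(), lr=1e-3, precondition_frequency=1)
    load_optim_state_dict(model2, optim2, torch.load("/tmp/shampoo_optim.pt", weights_only=False))

This is also why the patch mocks get_optimizer_signature for DistributedShampoo: the optimizer does
not go through the standard state_dict() protocol, so its state has to be captured and restored via
these dedicated calls instead.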