
wip add ckpt shampoo
samsja committed Jan 9, 2025
1 parent 4221cc0 commit 5ac2bde
Showing 3 changed files with 26 additions and 12 deletions.
src/zeroband/checkpoint.py (25 changes: 16 additions & 9 deletions)
@@ -38,6 +38,7 @@
     send_state_dict,
     send_tensor_and_state_dict,
 )
+from distributed_shampoo import DistributedShampoo
 
 from zeroband.utils.world_info import get_world_info
 
@@ -80,17 +81,23 @@ def __init__(
         self.optim = optim
 
     def state_dict(self) -> dict[str, Any]:
-        return get_optimizer_state_dict(
-            model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
-        )
+        if isinstance(self.optim, DistributedShampoo):
+            return self.optim.distributed_state_dict(key_to_param=self.model.named_parameters())
+        else:
+            return get_optimizer_state_dict(
+                model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
+            )
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        set_optimizer_state_dict(
-            model=self.model,
-            optimizers=self.optim,
-            optim_state_dict=state_dict,
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
+        if isinstance(self.optim, DistributedShampoo):
+            self.optim.load_distributed_state_dict(state_dict, key_to_param=self.model.named_parameters())
+        else:
+            set_optimizer_state_dict(
+                model=self.model,
+                optimizers=self.optim,
+                optim_state_dict=state_dict,
+                options=StateDictOptions(flatten_optimizer_state_dict=True),
+            )
 
 
 def cast_dtensor_to_tensor(state_dict: dict[str, Any]) -> dict[str, Any]:
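As a reference for how the new branch changes a checkpoint round trip, here is a minimal single-process sketch of the same dispatch written as free functions. torch.save/torch.load and the save_optim/load_optim helpers are illustrative only, not how zeroband actually persists checkpoints:

    from typing import Any

    import torch
    from distributed_shampoo import DistributedShampoo
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )


    def save_optim(path: str, model: torch.nn.Module, optim: torch.optim.Optimizer) -> None:
        # DistributedShampoo ships its own checkpoint API instead of the flattened
        # state dict produced by torch.distributed.checkpoint.state_dict.
        if isinstance(optim, DistributedShampoo):
            state = optim.distributed_state_dict(key_to_param=model.named_parameters())
        else:
            state = get_optimizer_state_dict(
                model=model,
                optimizers=optim,
                options=StateDictOptions(flatten_optimizer_state_dict=True),
            )
        torch.save(state, path)


    def load_optim(path: str, model: torch.nn.Module, optim: torch.optim.Optimizer) -> None:
        state: dict[str, Any] = torch.load(path, weights_only=False)
        if isinstance(optim, DistributedShampoo):
            optim.load_distributed_state_dict(state, key_to_param=model.named_parameters())
        else:
            set_optimizer_state_dict(
                model=model,
                optimizers=optim,
                optim_state_dict=state,
                options=StateDictOptions(flatten_optimizer_state_dict=True),
            )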
src/zeroband/utils/__init__.py (4 changes: 4 additions & 0 deletions)
@@ -4,6 +4,7 @@
 import torch
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed._tensor.api import DTensor
+from distributed_shampoo import DistributedShampoo
 
 
 __all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]
@@ -138,6 +139,9 @@ def get_optimizer_signature(optimizer: torch.optim.Optimizer, compress: bool = T
     Get the optimizer signature
     """
 
+    if isinstance(optimizer, DistributedShampoo):
+        return "mocked signature because shampoo does not support state_dict()"
+
     def unwrap_tensor(state_dict: dict) -> dict:
         new_dict = {}
         for key, value in state_dict.items():
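get_optimizer_signature itself is only partially shown in this diff. To illustrate why the short-circuit above is needed, here is a rough sketch of the general idea behind such a signature (hash a CPU copy of the optimizer state); the details differ from zeroband's real implementation, and the sketch calls optimizer.state_dict(), which is exactly what DistributedShampoo does not support, hence the placeholder string:

    import hashlib
    import io

    import torch
    from torch.distributed._tensor.api import DTensor


    def optimizer_signature_sketch(optimizer: torch.optim.Optimizer) -> str:
        """Hash a CPU copy of the optimizer state dict (illustrative only)."""

        def unwrap(value):
            # Materialize DTensor shards and move tensors to CPU so the bytes
            # being hashed do not depend on device placement.
            if isinstance(value, DTensor):
                value = value.to_local()
            if isinstance(value, torch.Tensor):
                return value.detach().cpu()
            return value

        state = {
            param_id: {key: unwrap(value) for key, value in param_state.items()}
            for param_id, param_state in optimizer.state_dict()["state"].items()
        }
        buffer = io.BytesIO()
        torch.save(state, buffer)
        return hashlib.md5(buffer.getvalue()).hexdigest()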
tests/test_torchrun/test_train.py (9 changes: 6 additions & 3 deletions)
@@ -127,7 +127,8 @@ def test_soap(diloco: bool):
     )
 
 
-def test_ckpt(tmp_path: Path):
+@pytest.mark.parametrize("soap", [False, True])
+def test_ckpt(tmp_path: Path, soap: bool):
     num_gpus = [1, 2]
     v1_file = tmp_path / "v1.log"
     v2_file = tmp_path / "v2.log"
@@ -157,7 +158,8 @@ def test_ckpt(tmp_path: Path):
             "--no-train.sequence_packing",
             "--train.attn_fn",
             "math",
-        ],
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     _test_multi_gpu(
@@ -178,7 +180,8 @@ def test_ckpt(tmp_path: Path):
             "--no-train.sequence_packing",
             "--train.attn_fn",
             "math",
-        ],
+        ]
+        + (["--optim.optim.precondition_frequency", "1"] if soap else []),
         diloco=True,
     )
     # _test_multi_gpu(
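The checkpoint test now runs twice, with and without the SOAP/Shampoo optimizer; --optim.optim.precondition_frequency 1 is only appended for the SOAP run, presumably so the short training runs update the Shampoo preconditioner (and therefore its checkpointed state) at every step. A minimal sketch of the conditional-argument pattern, using an illustrative helper rather than zeroband's actual test harness:

    import pytest


    def build_extra_args(soap: bool) -> list[str]:
        # Mirror the diff: only the SOAP run gets the precondition_frequency flag.
        base = ["--no-train.sequence_packing", "--train.attn_fn", "math"]
        return base + (["--optim.optim.precondition_frequency", "1"] if soap else [])


    @pytest.mark.parametrize("soap", [False, True])
    def test_build_extra_args(soap: bool):
        args = build_extra_args(soap)
        assert ("--optim.optim.precondition_frequency" in args) == soap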
