Skip to content

Commit

Permalink
add topk compression ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Jan 7, 2025
1 parent fa228bc commit 7a46817
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
10 changes: 9 additions & 1 deletion src/zeroband/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TypeAlias
from pydantic import model_validator
from pydantic_config import BaseConfig
import torch
from distributed_shampoo.shampoo_types import EigenvalueCorrectedShampooPreconditionerConfig
Expand All @@ -20,7 +21,14 @@ class SoapConfig(BaseConfig):
max_preconditioner_dim: int = 8192
precondition_frequency: int = 100

topk_compression: int | None = None
topk_compression: int | float | None = None

@model_validator(mode="after")
def validate_topk_compression(self):
if isinstance(self.topk_compression, float):
if not 0 < self.topk_compression <= 1:
raise ValueError("If topk_compression is float, it must be between 0 and 1")
return self


OptimizersConfig: TypeAlias = AdamConfig | MuonConfig | SoapConfig
Expand Down
2 changes: 1 addition & 1 deletion tests/test_torchrun/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_muon(diloco: bool):


@pytest.mark.parametrize("diloco", [False, True])
@pytest.mark.parametrize("topk_compression", [None, 5])
@pytest.mark.parametrize("topk_compression", [None, 5, 0.1])
def test_soap(diloco: bool, topk_compression: int | None):
num_gpus = [1, 2] if diloco else [2, 1]

Expand Down
2 changes: 1 addition & 1 deletion third_party/optimizers

0 comments on commit 7a46817

Please sign in to comment.