From ea239849181ff3d730c663981c97375a4f6507f2 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 6 Jan 2025 09:55:46 +0000 Subject: [PATCH] upgrade to latest torch shampoo commit --- src/zeroband/optimizers/__init__.py | 6 ++---- uv.lock | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/zeroband/optimizers/__init__.py b/src/zeroband/optimizers/__init__.py index 2199668f..c311b900 100644 --- a/src/zeroband/optimizers/__init__.py +++ b/src/zeroband/optimizers/__init__.py @@ -3,7 +3,7 @@ import torch from zeroband.optimizers.muon import Muon, AdamConfig, MuonConfig from distributed_shampoo import ( - EighEigenvalueCorrectionConfig, + DefaultSOAPConfig, DistributedShampoo, FullyShardShampooConfig, ShampooPT2CompileConfig, @@ -52,9 +52,7 @@ def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> max_preconditioner_dim=config.max_preconditioner_dim, precondition_frequency=config.precondition_frequency, use_decoupled_weight_decay=True, - # This can also be set to `QREigenvalueCorrectionConfig` which is less expensive - # and might therefore allow for a smaller `precondition_frequency`. - preconditioner_computation_config=EighEigenvalueCorrectionConfig(), + preconditioner_config=DefaultSOAPConfig, distributed_config=FullyShardShampooConfig(), shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False), ) diff --git a/uv.lock b/uv.lock index 3fa628ef..cc6fdf7a 100644 --- a/uv.lock +++ b/uv.lock @@ -2416,7 +2416,7 @@ wheels = [ [[package]] name = "torch-shampoo" version = "1.0.0" -source = { git = "https://github.com/facebookresearch/optimizers.git?rev=main#f9d2a8cb526709bd4b5ef71f8cca3705906a0f94" } +source = { git = "https://github.com/facebookresearch/optimizers.git?rev=main#c51e4e6c0a9a6e93163441a9b32bb65cc1c736a8" } dependencies = [ { name = "torch" }, ]