diff --git a/src/zeroband/train.py b/src/zeroband/train.py index d54391e3..551bdad2 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -133,6 +133,7 @@ def log_hash_training_state( inner_optimizer: torch.optim.Optimizer, diloco: Diloco | None, metric_logger: MetricLogger, + step: int, id: str = "", ): """Log the hash of the model and optimizer. This function is slow""" @@ -143,10 +144,11 @@ def log_hash_training_state( logger.debug(f"inner diloco model {id} : {inner_model_hash}") logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}") - if world_info.rank == 0: - metric_logger.log( - {"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash} - ) + metrics = { + "step": step, + "inner_model_hash_{id}": inner_model_hash, + "inner_optimizer_hash_{id}": inner_optimizer_hash, + } if config.diloco is not None and diloco is not None: outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer) @@ -155,10 +157,11 @@ def log_hash_training_state( logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}") logger.debug(f"outer diloco model hash {id} : {outer_model_hash}") - if world_info.rank == 0: - metric_logger.log( - {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash} - ) + metrics.update( + {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash} + ) + if world_info.rank == 0: + metric_logger.log(metrics) def train(config: Config): @@ -300,7 +303,9 @@ def train(config: Config): skip_dataloader=config.ckpt.skip_dataloader, data_path=config.ckpt.data_path, ) - log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume") + log_hash_training_state( + config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume" + ) if config.train.memory_monitor: gpu_mem_monitor = GPUMemoryMonitor() @@ -349,7 +354,15 @@ def train(config: Config): 
def test_ckpt(tmp_path: Path):
    """Run the DiLoCo debug config twice and require identical per-step metrics.

    Each run gets its own log file (passed as ``--project``) and its own
    checkpoint directory under ``tmp_path``. After both runs finish, the two
    log files are loaded, noisy keys are dropped, ``Perplexity`` and ``Loss``
    are rounded, and the per-step results are compared for equality.

    NOTE(review): this assumes the path given to ``--project`` ends up holding
    a pickled iterable of metric dicts, each with a ``"step"`` key -- that is
    what the reader below expects; confirm against the metric logger in use.
    """
    num_gpus = [1, 2]
    log_files = [tmp_path / "v1.log", tmp_path / "v2.log"]
    ckpt_dirs = [tmp_path / "v1_ckpt", tmp_path / "v2_ckpt"]

    # Create both checkpoint directories before launching any run.
    for ckpt_dir in ckpt_dirs:
        os.mkdir(ckpt_dir)

    # The two runs are configured identically apart from their output paths.
    for log_file, ckpt_dir in zip(log_files, ckpt_dirs):
        _test_multi_gpu(
            num_gpus,
            "debug/diloco.toml",
            extra_args=[
                "--project",
                str(log_file),
                "--ckpt.path",
                str(ckpt_dir),
                "--ckpt.interval",
                "5",
                "--optim.total_steps",
                "20",
                "--train.log_model_hash",
            ],
            diloco=True,
        )

    # Keys that are removed from every entry before the comparison.
    ignored_keys = (
        "remaining_cpu_ram",
        "time",
        "tokens_per_second",
        "mfu",
        "outer_mfu",
        "outer_tokens_per_second",
        "all_reduce_step",
    )
    # Metric name -> number of digits it is rounded to before the comparison.
    rounding = {"Perplexity": 0, "Loss": 3}

    def read_logs(path: Path):
        """Load the pickled entries at *path* and merge them into one dict per step."""
        with path.open("rb") as handle:
            entries = pickle.load(handle)

        per_step = {}
        for entry in entries:
            # A missing "step" key is an error: pop without a default.
            step = entry.pop("step")
            cleaned = {
                name: round(value, rounding[name]) if name in rounding else value
                for name, value in entry.items()
                if name not in ignored_keys
            }
            # Several entries may share a step; later ones overwrite earlier keys.
            per_step.setdefault(step, {}).update(cleaned)

        return per_step

    first_run, second_run = (read_logs(log_file) for log_file in log_files)

    assert first_run == second_run