add ckpt tests
samsja committed Dec 19, 2024
1 parent 1a3b439 commit 93c46db
Showing 2 changed files with 116 additions and 12 deletions.
41 changes: 29 additions & 12 deletions src/zeroband/train.py
@@ -133,6 +133,7 @@ def log_hash_training_state(
     inner_optimizer: torch.optim.Optimizer,
     diloco: Diloco | None,
     metric_logger: MetricLogger,
+    step: int,
     id: str = "",
 ):
     """Log the hash of the model and optimizer. This function is slow"""
@@ -143,10 +144,11 @@ def log_hash_training_state(
         logger.debug(f"inner diloco model {id} : {inner_model_hash}")
         logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}")

-        if world_info.rank == 0:
-            metric_logger.log(
-                {"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash}
-            )
+        metrics = {
+            "step": step,
+            f"inner_model_hash_{id}": inner_model_hash,
+            f"inner_optimizer_hash_{id}": inner_optimizer_hash,
+        }

         if config.diloco is not None and diloco is not None:
             outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
@@ -155,10 +157,11 @@
             logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
             logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")

-            if world_info.rank == 0:
-                metric_logger.log(
-                    {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
-                )
+            metrics.update(
+                {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+            )
+        if world_info.rank == 0:
+            metric_logger.log(metrics)


 def train(config: Config):
@@ -300,7 +303,9 @@ def train(config: Config):
             skip_dataloader=config.ckpt.skip_dataloader,
             data_path=config.ckpt.data_path,
         )
-        log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume")
+        log_hash_training_state(
+            config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
+        )

     if config.train.memory_monitor:
         gpu_mem_monitor = GPUMemoryMonitor()
@@ -349,7 +354,15 @@

                     ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)

-                    log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="live_reco_recv")
+                    log_hash_training_state(
+                        config,
+                        model,
+                        inner_optimizer,
+                        diloco,
+                        metric_logger,
+                        step=training_progress.step,
+                        id="live_reco_recv",
+                    )
                     need_live_recovery = False

         if config.ckpt.remote_data_load:
@@ -483,7 +496,9 @@
                 diloco.step(model=model, flag=training_progress.outer_step)
                 diloco_time = time.perf_counter() - time_start_inner

-                log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="outer_step")
+                log_hash_training_state(
+                    config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step"
+                )

                 training_progress.outer_step += 1
@@ -496,7 +511,9 @@

             do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0
             ckpt_manager.save(remote=do_remote)
-            log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="ckpt save")
+            log_hash_training_state(
+                config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="ckpt save"
+            )

         if config.diloco:
             tokens_per_second = (
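
The change to log_hash_training_state above replaces two separate rank-0 metric_logger.log calls with a single metrics dict that is tagged with the training step and logged once at the end, so the inner and (when DiLoCo is enabled) outer hashes land in one step-keyed row. A minimal sketch of that pattern, assuming only that MetricLogger.log accepts a flat dict; the stand-in class and the log_hashes helper below are hypothetical, not zeroband code:

# Hypothetical stand-in for zeroband's MetricLogger; only the dict-taking
# log() call is assumed, mirroring how the diff uses it.
class MetricLogger:
    def __init__(self):
        self.rows = []

    def log(self, metrics: dict):
        self.rows.append(metrics)


def log_hashes(step: int, rank: int, inner_hash: str, outer_hash: str | None,
               metric_logger: MetricLogger, id: str = ""):
    # Accumulate everything for this step into one dict...
    metrics = {"step": step, f"inner_model_hash_{id}": inner_hash}
    if outer_hash is not None:  # outer hash only exists when DiLoCo is enabled
        metrics.update({f"outer_model_hash_{id}": outer_hash})
    # ...then log exactly once, and only on rank 0.
    if rank == 0:
        metric_logger.log(metrics)


ml = MetricLogger()
log_hashes(step=5, rank=0, inner_hash="abc", outer_hash="def", metric_logger=ml, id="resume")
assert ml.rows == [{"step": 5, "inner_model_hash_resume": "abc", "outer_model_hash_resume": "def"}]

One consolidated, step-keyed row per call is what lets the new test below merge log entries by step when comparing two runs.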
87 changes: 87 additions & 0 deletions tests/test_torchrun/test_train.py
@@ -1,5 +1,7 @@
 import copy
 import os
+from pathlib import Path
+import pickle
 import subprocess
 import pytest
 import socket
@@ -112,3 +114,88 @@ def test_packing(packing: bool):
     num_gpus = [2, 1]
     packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing"
     _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])
+
+
+def test_ckpt(tmp_path: Path):
+    num_gpus = [1, 2]
+    v1_file = tmp_path / "v1.log"
+    v2_file = tmp_path / "v2.log"
+
+    v1_ckpt = tmp_path / "v1_ckpt"
+    v2_ckpt = tmp_path / "v2_ckpt"
+
+    os.mkdir(v1_ckpt)
+    os.mkdir(v2_ckpt)
+
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml",
+        extra_args=[
+            "--project",
+            str(v1_file),
+            "--ckpt.path",
+            str(v1_ckpt),
+            "--ckpt.interval",
+            "5",
+            "--optim.total_steps",
+            "20",
+            "--train.log_model_hash",
+        ],
+        diloco=True,
+    )
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml",
+        extra_args=[
+            "--project",
+            str(v2_file),
+            "--ckpt.path",
+            str(v2_ckpt),
+            "--ckpt.interval",
+            "5",
+            "--optim.total_steps",
+            "20",
+            "--train.log_model_hash",
+        ],
+        diloco=True,
+    )
+
+    key_to_remove = [
+        "remaining_cpu_ram",
+        "time",
+        "tokens_per_second",
+        "mfu",
+        "outer_mfu",
+        "outer_tokens_per_second",
+        "all_reduce_step",
+    ]
+    key_to_round = ["Perplexity", "Loss"]
+    digit_to_round = [0, 3]
+
+    def read_logs(path: Path):
+        with path.open("rb") as f:
+            data = pickle.load(f)
+
+        filtered_data = {}
+        for entry in data:
+            step = entry.pop("step")
+            # Remove unwanted columns
+            for key in key_to_remove:
+                entry.pop(key, None)
+
+            # Round perplexity and loss
+            for key, digit in zip(key_to_round, digit_to_round):
+                if key in entry:
+                    entry[key] = round(entry[key], digit)
+
+            if step in filtered_data:
+                filtered_data[step].update(entry)
+            else:
+                filtered_data[step] = entry
+
+        return filtered_data
+
+    v1_data = read_logs(v1_file)
+    v2_data = read_logs(v2_file)
+
+    assert v1_data == v2_data
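
The test drives two identical 20-step DiLoCo runs, checkpointing every 5 steps with --train.log_model_hash enabled, and asserts that their metric logs match once nondeterministic columns (timings, throughput, RAM) are dropped and Loss/Perplexity are rounded. read_logs assumes the --project file is a pickle holding a list of per-step metric dicts. A self-contained sketch of that filtering under the same assumption; filter_entries is a hypothetical standalone rewrite of read_logs, and the log contents are fabricated for illustration:

import pickle
import tempfile
from pathlib import Path

KEYS_TO_REMOVE = ["tokens_per_second", "mfu"]  # nondeterministic columns (subset)
KEYS_TO_ROUND = {"Perplexity": 0, "Loss": 3}   # tolerate float noise between runs

def filter_entries(path: Path) -> dict:
    """Collapse a pickled list of per-step metric dicts into {step: comparable_row}."""
    with path.open("rb") as f:
        data = pickle.load(f)
    out: dict = {}
    for entry in data:
        step = entry.pop("step")
        for key in KEYS_TO_REMOVE:
            entry.pop(key, None)
        for key, digits in KEYS_TO_ROUND.items():
            if key in entry:
                entry[key] = round(entry[key], digits)
        out.setdefault(step, {}).update(entry)
    return out

tmp = Path(tempfile.mkdtemp())
# Loss agrees to 3 decimals, throughput differs: the runs should still compare equal.
for name, loss, tps in [("a.log", 2.34567, 900.0), ("b.log", 2.34561, 1200.0)]:
    with (tmp / name).open("wb") as f:
        pickle.dump([{"step": 1, "Loss": loss, "Perplexity": 10.5, "tokens_per_second": tps}], f)

assert filter_entries(tmp / "a.log") == filter_entries(tmp / "b.log")

Rounding Loss to 3 decimals while dropping timing and throughput columns tolerates harmless run-to-run noise yet still catches any real divergence in training state.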
