allow ckpt during inner steps
samsja committed Oct 24, 2024
1 parent f4b7c85 commit 180cc16
Showing 2 changed files with 27 additions and 14 deletions.
17 changes: 16 additions & 1 deletion src/zeroband/checkpoint.py
@@ -327,6 +327,22 @@ def save(self, remote: bool = False) -> None:
if remote and self.config.remote is not None:
self._async_save_remote(step_ckpt_path, remote_ckpt_path)

def save_inner(self, ckpt_path: str):
self.wait_for_blocking_job()

step_ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}", "inner_ckpt")

states = {
"model": ModelWrapper(self.model),
"optimizer": OptimizerWrapper(self.model, self.optimizer),
"scheduler": self.scheduler,
"training_progress": self.training_progress,
}

dcp.save(states, checkpoint_id=step_ckpt_path)

self._logger.info(f"Saved inner checkpoint to {step_ckpt_path}")

def _save(self, ckpt_path: str):
self.wait_for_blocking_job()

@@ -351,7 +367,6 @@ def _save(self, ckpt_path: str):
## 1. v1: save the dataloader in the same file as the outer optimizer
## 2. v2: save the dataloader in a data folder inside the ckpt path

## the next part is a fix so that each rank saves its own dataloader state. It is not efficient because it reads the state from disk twice
with open(os.path.join(ckpt_path, f"__{self.world_info.local_rank}_0.pt"), "wb") as f:
state = {"data_loader": self.dataloader.state_dict()} if self.config.data_version == "v1" else {}
if self.diloco_offloaded_optimizer:
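
For context, a minimal sketch of how an inner checkpoint written by save_inner could be read back with torch.distributed.checkpoint. The load_inner helper below is hypothetical (it is not part of this commit) and assumes the same ModelWrapper/OptimizerWrapper stateful wrappers from checkpoint.py are reused on the load side; dcp.load delegates to each wrapper's load_state_dict, which is what lets the same wrapper objects serve both save and restore.

import os

import torch.distributed.checkpoint as dcp

def load_inner(self, ckpt_path: str, step: int):
    """Hypothetical counterpart to save_inner: restore the inner model and
    optimizer saved at a given step. Not part of this commit."""
    step_ckpt_path = os.path.join(ckpt_path, f"step_{step}", "inner_ckpt")

    # Build the same stateful wrappers used by save_inner so dcp can
    # restore each component's state dict in place.
    states = {
        "model": ModelWrapper(self.model),
        "optimizer": OptimizerWrapper(self.model, self.optimizer),
        "scheduler": self.scheduler,
        "training_progress": self.training_progress,
    }

    dcp.load(states, checkpoint_id=step_ckpt_path)
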
24 changes: 11 additions & 13 deletions src/zeroband/train.py
@@ -107,14 +107,6 @@ class Config(BaseConfig):

ckpt: CkptConfig = CkptConfig()

@model_validator(mode="after")
def ckpt_diloco_step(self):
if self.ckpt is not None and self.ckpt.interval is not None and self.diloco is not None:
assert (
self.ckpt.interval % self.diloco.inner_steps == 0
), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"
return self

@model_validator(mode="after")
def validate_live_recovery_rank_src(self):
if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None:
@@ -130,11 +122,6 @@ def train(config: Config):
assert batch_size % config.train.micro_bs == 0
gradient_accumulation_steps = batch_size // config.train.micro_bs

if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None:
assert (
config.ckpt.interval % config.diloco.inner_steps == 0
), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step"

if config.type_model == "llama2":
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
elif config.type_model == "llama3":
@@ -455,6 +442,17 @@ def train(config: Config):
if config.train.memory_profiler is not None:
memory_profiler.step()

if (
config.ckpt is not None
and training_progress.step > 0
and training_progress.step % num_inner_steps != 0
and training_progress.step % config.ckpt.interval == 0
):
logger.info(
f"Saving inner step ckpt at {training_progress.step}. This will only save the inner model and optimizer"
)
ckpt_manager.save_inner(ckpt_path=config.ckpt.path)

if config.diloco is not None:
if config.train.log_model_hash:
logger.debug("Pre diloco model: %s", get_module_signature(model))
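
To make the new trigger concrete, here is a small standalone illustration of which steps hit the inner-only save versus the regular outer-step save path; the interval and inner-step values are made up for the example:

# Illustration only: with inner_steps=10 and ckpt interval=5, steps 5, 15, 25, ...
# fall between outer boundaries and take the new inner-only checkpoint path,
# while steps 10, 20, 30, ... land on an outer-step boundary and are handled
# by the existing save.
num_inner_steps = 10
interval = 5

for step in range(1, 31):
    if step % interval == 0:
        if step % num_inner_steps != 0:
            print(f"step {step}: inner-only checkpoint (save_inner)")
        else:
            print(f"step {step}: outer-step checkpoint (regular save path)")
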
