Skip to content

Commit

Permalink
Make live recovery send blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Nov 14, 2024
1 parent f9106d9 commit a01904c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/zeroband/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def recv_ckpt_from_peer(self, global_pg: dist.ProcessGroup):
)

@torch.no_grad()
def send_ckpt_to_peer(self, global_pg: dist.ProcessGroup, dest_rank: int):
def send_ckpt_to_peer(self, global_pg: dist.ProcessGroup, dest_rank: int, blocking: bool = False):
def async_send():
assert self.diloco_offloaded_param_list is not None, "send_ckpt_to_peers is only supported with diloco"
time_start = time.perf_counter()
Expand Down Expand Up @@ -582,8 +582,10 @@ def async_send():
thread = threading.Thread(target=async_send)
thread.start()
self._logger.debug("Live recovery thread started")

self._live_reco_thread = thread
if blocking:
thread.join()
else:
self._live_reco_thread = thread


def delete_topk(ckpt_path: str, topk: int):
Expand Down
2 changes: 1 addition & 1 deletion src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def train(config: Config):
)
logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")

ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank)
ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True)

elastic_device_mesh.live_recovery.reset()
else:
Expand Down

0 comments on commit a01904c

Please sign in to comment.