diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index de25afe5..b240c4b5 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -546,7 +546,7 @@ def recv_ckpt_from_peer(self, global_pg: dist.ProcessGroup): ) @torch.no_grad() - def send_ckpt_to_peer(self, global_pg: dist.ProcessGroup, dest_rank: int): + def send_ckpt_to_peer(self, global_pg: dist.ProcessGroup, dest_rank: int, blocking: bool = False): def async_send(): assert self.diloco_offloaded_param_list is not None, "send_ckpt_to_peers is only supported with diloco" time_start = time.perf_counter() @@ -582,8 +582,10 @@ def async_send(): thread = threading.Thread(target=async_send) thread.start() self._logger.debug("Live recovery thread started") - - self._live_reco_thread = thread + if blocking: + thread.join() + else: + self._live_reco_thread = thread def delete_topk(ckpt_path: str, topk: int): diff --git a/src/zeroband/train.py b/src/zeroband/train.py index af40bc45..fc47503b 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -310,7 +310,7 @@ def train(config: Config): ) logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") - ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank) + ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) elastic_device_mesh.live_recovery.reset() else: