From c3fa9d226fef9ae65c0db8fca70ec6499280c8d2 Mon Sep 17 00:00:00 2001
From: Jackmin801
Date: Wed, 25 Sep 2024 17:31:08 +0000
Subject: [PATCH] refactor: move pg concerns into edm

---
 src/zeroband/comms.py  | 30 ++++++++++++++++++++++++++++++
 src/zeroband/diloco.py | 25 +------------------------
 src/zeroband/train.py  |  6 ++----
 3 files changed, 33 insertions(+), 28 deletions(-)
 create mode 100644 src/zeroband/comms.py

diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py
new file mode 100644
index 00000000..0c2af6f8
--- /dev/null
+++ b/src/zeroband/comms.py
@@ -0,0 +1,30 @@
+from torch.distributed.device_mesh import init_device_mesh
+from zeroband.utils.world_info import get_world_info
+from zeroband.utils.logging import get_logger
+import torch.distributed as dist
+
+
+class ElasticDeviceMesh:
+    """Initialize two process groups through device mesh: one local on GPU and one global on CPU."""
+
+    def __init__(self):
+        self._logger = get_logger()
+
+        self.world_info = get_world_info()
+
+        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
+        # device mesh does not yet support two backends, so we create two identical meshes that differ only in backend
+        self.device_mesh = init_device_mesh(
+            "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
+        )
+        self.device_mesh_cpu = init_device_mesh(
+            "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
+        )
+
+        self.global_pg = self.device_mesh_cpu.get_group("global")
+        self.local_pg = self.device_mesh.get_group("local")
+
+        self._logger.debug(f"global pg world size: {self.global_pg.size()}, local pg: {self.local_pg.size()}")
+
+    def __del__(self):
+        dist.destroy_process_group()
diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 26bba0a3..5d218176 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -1,9 +1,9 @@
 from pydantic_config import BaseConfig
 import torch
-from torch.distributed.device_mesh import init_device_mesh
 from torch import nn
 from zeroband.utils.world_info import get_world_info
 from zeroband.utils.logging import get_logger
+from zeroband.comms import ElasticDeviceMesh
 from torch.distributed.fsdp import ShardingStrategy
 import torch.distributed as dist
 
@@ -12,29 +12,6 @@ class DilocoConfig(BaseConfig):
     outer_lr: float = 0.7
     inner_steps: int
 
-
-class ElasticDeviceMesh:
-    """Init two process group through device mesh, one local on gpu and one global on cpu"""
-
-    def __init__(self):
-        self._logger = get_logger()
-
-        self.world_info = get_world_info()
-
-        # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend
-        self.device_mesh = init_device_mesh(
-            "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
-        )
-        self.device_mesh_cpu = init_device_mesh(
-            "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
-        )
-
-        self.global_pg = self.device_mesh_cpu.get_group("global")
-        self.local_pg = self.device_mesh.get_group("local")
-
-        self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
-
-
 class Diloco:
     """
     This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 6f52d94d..7520b025 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -5,7 +5,6 @@
 import torch
 from pydantic_config import parse_argv, BaseConfig
-from torch.distributed import destroy_process_group, init_process_group
 from einops import rearrange
 from torch.nn import functional as F
 
 
@@ -19,7 +18,8 @@
 )
 import torch.distributed as dist
 from zeroband import utils
-from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh
+from zeroband.diloco import Diloco, DilocoConfig
+from zeroband.comms import ElasticDeviceMesh
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 
@@ -248,11 +248,9 @@ def train(config: Config):
     world_info = get_world_info()
     logger = get_logger()
 
-    init_process_group()
     torch.cuda.set_device(world_info.local_rank)
 
     config = Config(**parse_argv())
     logger.debug(f"config: {config.model_dump()}")
 
     train(config)
-    destroy_process_group()
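
Note (sketch, not part of the patch): after this refactor the training entry point no longer calls init_process_group()/destroy_process_group() itself; ElasticDeviceMesh owns the process-group lifetime. A minimal usage sketch, assuming the caller only needs the two process groups (the names below are illustrative, not taken from train.py):

    from zeroband.comms import ElasticDeviceMesh

    def main():
        edm = ElasticDeviceMesh()    # __init__ initializes the cpu:gloo,cuda:nccl groups
        local_pg = edm.local_pg      # intra-node GPU group (nccl), e.g. for FSDP sharding
        global_pg = edm.global_pg    # cross-node CPU group (gloo), e.g. for the DiLoCo outer step
        ...                          # build the model and run training with these groups
        # groups are torn down in ElasticDeviceMesh.__del__ when edm is released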