refactor: move pg concerns into edm
Jackmin801 committed Sep 25, 2024
1 parent be995d9 commit c3fa9d2
Showing 3 changed files with 33 additions and 28 deletions.
30 changes: 30 additions & 0 deletions src/zeroband/comms.py
@@ -0,0 +1,30 @@
+from torch.distributed.device_mesh import init_device_mesh
+from zeroband.utils.world_info import get_world_info
+from zeroband.utils.logging import get_logger
+import torch.distributed as dist
+
+
+class ElasticDeviceMesh:
+    """Initialize two process groups via device mesh: one local on GPU and one global on CPU."""
+
+    def __init__(self):
+        self._logger = get_logger()
+
+        self.world_info = get_world_info()
+
+        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
+        # right now device mesh does not support two backends, so we create two identical meshes that differ only in backend
+        self.device_mesh = init_device_mesh(
+            "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
+        )
+        self.device_mesh_cpu = init_device_mesh(
+            "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
+        )
+
+        self.global_pg = self.device_mesh_cpu.get_group("global")
+        self.local_pg = self.device_mesh.get_group("local")
+
+        self._logger.debug(f"global pg world size: {self.global_pg.size()}, local pg size: {self.local_pg.size()}")
+
+    def __del__(self):
+        dist.destroy_process_group()
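
For orientation, a minimal usage sketch of the new class (not part of this commit; it assumes a torchrun launch so that get_world_info() can read the usual rank/world-size environment variables):

# Minimal usage sketch of ElasticDeviceMesh (illustrative, not repo code).
import torch
import torch.distributed as dist
from zeroband.comms import ElasticDeviceMesh

edm = ElasticDeviceMesh()  # __init__ runs dist.init_process_group("cpu:gloo,cuda:nccl")

# cross-node all-reduce of a CPU tensor over the gloo group
t_cpu = torch.ones(4)
dist.all_reduce(t_cpu, op=dist.ReduceOp.SUM, group=edm.global_pg)

# intra-node all-reduce of a GPU tensor over the nccl group
t_gpu = torch.ones(4, device="cuda")
dist.all_reduce(t_gpu, op=dist.ReduceOp.SUM, group=edm.local_pg)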
25 changes: 1 addition & 24 deletions src/zeroband/diloco.py
@@ -1,9 +1,9 @@
 from pydantic_config import BaseConfig
 import torch
-from torch.distributed.device_mesh import init_device_mesh
 from torch import nn
 from zeroband.utils.world_info import get_world_info
 from zeroband.utils.logging import get_logger
+from zeroband.comms import ElasticDeviceMesh
 from torch.distributed.fsdp import ShardingStrategy
 import torch.distributed as dist
 
@@ -12,29 +12,6 @@ class DilocoConfig(BaseConfig):
     outer_lr: float = 0.7
     inner_steps: int
-
-
-class ElasticDeviceMesh:
-    """Initialize two process groups via device mesh: one local on GPU and one global on CPU."""
-
-    def __init__(self):
-        self._logger = get_logger()
-
-        self.world_info = get_world_info()
-
-        # right now device mesh does not support two backends, so we create two identical meshes that differ only in backend
-        self.device_mesh = init_device_mesh(
-            "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
-        )
-        self.device_mesh_cpu = init_device_mesh(
-            "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
-        )
-
-        self.global_pg = self.device_mesh_cpu.get_group("global")
-        self.local_pg = self.device_mesh.get_group("local")
-
-        self._logger.debug(f"global pg world size: {self.global_pg.size()}, local pg size: {self.local_pg.size()}")
-
 
 class Diloco:
     """
     This class implements the DiLoCo algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
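
The Diloco class body is collapsed above; its docstring points to the DiLoCo papers. Below is a hedged sketch of the outer step those papers describe, using only outer_lr and inner_steps from DilocoConfig plus the global_pg from ElasticDeviceMesh; the helper names (outer_step, offloaded_params, outer_optimizer) are illustrative assumptions, not this repo's API.

# Hedged sketch of a DiLoCo outer step (illustrative, not this repo's code).
# After `inner_steps` local optimizer steps, each node forms a pseudo-gradient
# (last global weights minus current local weights), averages it across nodes
# over the CPU gloo group, and applies an outer optimizer.
import torch
import torch.distributed as dist

@torch.no_grad()
def outer_step(model, offloaded_params, outer_optimizer, global_pg):
    for p, p_global in zip(model.parameters(), offloaded_params):
        pseudo_grad = p_global.data - p.detach().cpu()
        dist.all_reduce(pseudo_grad, op=dist.ReduceOp.SUM, group=global_pg)
        pseudo_grad /= global_pg.size()  # gloo has no ReduceOp.AVG, so sum then divide
        p_global.grad = pseudo_grad
    outer_optimizer.step()
    outer_optimizer.zero_grad()
    # sync the updated global weights back into the live model
    for p, p_global in zip(model.parameters(), offloaded_params):
        p.data.copy_(p_global.data.to(p.device))

The DiLoCo paper reports SGD with Nesterov momentum as the outer optimizer, with lr 0.7 matching the outer_lr default in DilocoConfig above.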
6 changes: 2 additions & 4 deletions src/zeroband/train.py
@@ -5,7 +5,6 @@

 import torch
 from pydantic_config import parse_argv, BaseConfig
-from torch.distributed import destroy_process_group, init_process_group
 from einops import rearrange
 from torch.nn import functional as F
 
@@ -19,7 +18,8 @@
 )
 import torch.distributed as dist
 from zeroband import utils
-from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh
+from zeroband.diloco import Diloco, DilocoConfig
+from zeroband.comms import ElasticDeviceMesh
 
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
@@ -248,11 +248,9 @@ def train(config: Config):
     world_info = get_world_info()
     logger = get_logger()
 
-    init_process_group()
     torch.cuda.set_device(world_info.local_rank)
 
     config = Config(**parse_argv())
     logger.debug(f"config: {config.model_dump()}")
 
     train(config)
-    destroy_process_group()
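
With the process-group lifecycle moved into ElasticDeviceMesh (init in __init__, destroy in __del__), train.py no longer calls init_process_group() or destroy_process_group() itself. A hedged sketch of one plausible resulting entrypoint; where exactly ElasticDeviceMesh is constructed falls outside the hunks shown above:

# Assumed shape of the entrypoint after this commit (sketch, not verbatim).
import torch
from zeroband.comms import ElasticDeviceMesh
from zeroband.utils.world_info import get_world_info

if __name__ == "__main__":
    world_info = get_world_info()
    torch.cuda.set_device(world_info.local_rank)

    # constructing the mesh replaces the old explicit init_process_group()
    elastic_device_mesh = ElasticDeviceMesh()

    # ... parse config and call train(config) as before ...
    # no destroy_process_group() here: ElasticDeviceMesh.__del__ tears it down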
