From 0ee6612df78a4c30f143e8000cd0b414551fd83b Mon Sep 17 00:00:00 2001
From: Jackmin801
Date: Wed, 25 Sep 2024 17:31:08 +0000
Subject: [PATCH 01/32] refactor: move pg concerns into edm

---
 src/zeroband/comms.py  | 30 ++++++++++++++++++++++++++++++
 src/zeroband/diloco.py | 25 +------------------------
 src/zeroband/train.py  |  6 ++----
 3 files changed, 33 insertions(+), 28 deletions(-)
 create mode 100644 src/zeroband/comms.py

diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py
new file mode 100644
index 00000000..0c2af6f8
--- /dev/null
+++ b/src/zeroband/comms.py
@@ -0,0 +1,30 @@
+from torch.distributed.device_mesh import init_device_mesh
+from zeroband.utils.world_info import get_world_info
+from zeroband.utils.logging import get_logger
+import torch.distributed as dist
+
+
+class ElasticDeviceMesh:
+    """Init two process groups through device mesh, one local on GPU and one global on CPU"""
+
+    def __init__(self):
+        self._logger = get_logger()
+
+        self.world_info = get_world_info()
+
+        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
+        # right now device mesh does not support two backends, so we just create two identical meshes except for the backend
+        self.device_mesh = init_device_mesh(
+            "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
+        )
+        self.device_mesh_cpu = init_device_mesh(
+            "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
+        )
+
+        self.global_pg = self.device_mesh_cpu.get_group("global")
+        self.local_pg = self.device_mesh.get_group("local")
+
+        self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
+
+    def __del__(self):
+        dist.destroy_process_group()
diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 114284f0..9df29f91 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -1,9 +1,9 @@
from pydantic_config import BaseConfig
import torch
-from torch.distributed.device_mesh import init_device_mesh
from torch import nn
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
+from zeroband.comms import ElasticDeviceMesh
from torch.distributed.fsdp import ShardingStrategy
import torch.distributed as dist
@@ -12,29 +12,6 @@ class DilocoConfig(BaseConfig):
    outer_lr: float = 0.7
    inner_steps: int
-
-class ElasticDeviceMesh:
-    """Init two process group through device mesh, one local on gpu and one global on cpu"""
-
-    def __init__(self):
-        self._logger = get_logger()
-
-        self.world_info = get_world_info()
-
-        # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend
-        self.device_mesh = init_device_mesh(
-            "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
-        )
-        self.device_mesh_cpu = init_device_mesh(
-            "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
-        )
-
-        self.global_pg = self.device_mesh_cpu.get_group("global")
-        self.local_pg = self.device_mesh.get_group("local")
-
-        self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
-
-
class Diloco:
    """
    This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
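(For orientation, a minimal usage sketch of the `ElasticDeviceMesh` this patch introduces, assuming the usual `torchrun` environment; the tensors and sizes below are invented for illustration, and the divide-before-SUM workaround for gloo's missing AVG op mirrors what the later patches do.)

```python
# Illustrative sketch only, not part of the patch; tensor shapes are made up.
import torch
import torch.distributed as dist

from zeroband.comms import ElasticDeviceMesh

edm = ElasticDeviceMesh()  # sets up the cpu:gloo + cuda:nccl process groups

# Cross-node traffic (e.g. DiLoCo pseudo-gradients) goes over the CPU/gloo "global" group.
pseudo_grad = torch.randn(8)                 # stays on CPU
pseudo_grad /= edm.global_pg.size()          # gloo has no AVG op, so divide then SUM
dist.all_reduce(pseudo_grad, op=dist.ReduceOp.SUM, group=edm.global_pg)

# Intra-node collectives stay on the CUDA/NCCL "local" group.
local_tensor = torch.randn(8, device="cuda")
dist.all_reduce(local_tensor, op=dist.ReduceOp.SUM, group=edm.local_pg)
```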
diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 16087c71..3cc60c68 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -4,7 +4,6 @@ import torch from pydantic_config import parse_argv, BaseConfig -from torch.distributed import destroy_process_group, init_process_group from einops import rearrange from torch.nn import functional as F @@ -18,7 +17,8 @@ ) import torch.distributed as dist from zeroband import utils -from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh +from zeroband.diloco import Diloco, DilocoConfig +from zeroband.comms import ElasticDeviceMesh from zeroband.utils import PerfCounter, get_sharding_strategy from zeroband.utils.monitor import WandbMonitor, DummyMonitor @@ -248,11 +248,9 @@ def train(config: Config): world_info = get_world_info() logger = get_logger() - init_process_group() torch.cuda.set_device(world_info.local_rank) config = Config(**parse_argv()) logger.debug(f"config: {config.model_dump()}") train(config) - destroy_process_group() From 9c32225931237f979f83acb5009d1f02124076f8 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Thu, 26 Sep 2024 00:29:03 +0000 Subject: [PATCH 02/32] working but only rank 0 syncs --- src/zeroband/comms.py | 239 +++++++++++++++++++++++++++++++++++++++-- src/zeroband/diloco.py | 13 ++- 2 files changed, 236 insertions(+), 16 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index 0c2af6f8..d6679906 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -2,29 +2,246 @@ from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger import torch.distributed as dist +import os +from datetime import timedelta +import time +from typing import List, Tuple, Optional +import uuid +TCPSTORE_TIMEOUT = timedelta(seconds=10) +MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit +MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit + +def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str: + while True: + try: + ret = store.get("status").decode("utf-8") + if status is None or ret == status: + return ret + time.sleep(0.1) + except dist.DistStoreError as e: + if status is not None: + raise e + time.sleep(0.1) + +def _queue_join(store: dist.Store, unique_id: str): + for i in range(MAX_JOINERS): + joiner_id = store.get(f"joiner_{i}").decode("utf-8") + if joiner_id == "null": + store.set(f"joiner_{i}", unique_id) + store.set(f"joiner_{i + 1}", "null") + break + else: + raise RuntimeError("Too many joiners") + +def _queue_leave(store: dist.Store, unique_id: str): + for i in range(MAX_LEAVERS): + leaver_id = store.get(f"leaver_{i}").decode("utf-8") + if leaver_id == "null": + store.set(f"leaver_{i}", unique_id) + store.set(f"leaver_{i + 1}", "null") + break + else: + raise RuntimeError("Too many leavers") + +def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]: + joiners = [] + leavers = [] + for i in range(MAX_JOINERS): + joiner_id = store.get(f"joiner_{i}").decode("utf-8") + if joiner_id == "null": + break + joiners.append(joiner_id) + for i in range(MAX_LEAVERS): + leaver_id = store.get(f"leaver_{i}").decode("utf-8") + if leaver_id == "null": + break + leavers.append(leaver_id) + print(f"Joiners: {joiners}, Leavers: {leavers}") + return joiners, leavers + +def _clear_joiners_and_leavers(store: dist.Store): + store.set("joiner_0", "null") + store.set("leaver_0", "null") + class ElasticDeviceMesh: - """Init two process group through device mesh, 
one local on gpu and one global on cpu""" + """A class to manage the process groups for elastic training without restarts. + + The way it works is rank 0 coordinates the joining and leaving of nodes. + Rank 0 manages the status to coordinate the creation and recreation of the process groups. + When a node wants to join, rank 0 will setup the store so that all nodes know the new world size and their respective ranks. + + Store keys used: + - status: "init", "running", "reinit" + - world_size: The current world size + - mesh_count: The version of the mesh + - rank_{uuid}: The rank of the node with the given uuid + - rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave. + - joiner_{i}: The uuid of the ith joiner. Its a KV implmentation of a queue. + - leaver_{i}: The uuid of the ith leaver. Its a KV implmentation of a queue. + """ + + local_pg: dist.ProcessGroup + global_pg: dist.ProcessGroup def __init__(self): self._logger = get_logger() - self.world_info = get_world_info() + # Initialize global process group + self._init_unique_id() + if self.world_info.rank == 0: + self.global_pg = self._init_global_pg() + else: + self.global_pg = None + + # Initialize local process group dist.init_process_group(backend="cpu:gloo,cuda:nccl") - # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend - self.device_mesh = init_device_mesh( - "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") + self._device_mesh = init_device_mesh( + "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("internode", "intranode") ) - self.device_mesh_cpu = init_device_mesh( - "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") + self.local_pg = self._device_mesh.get_group("intranode") + + if self.world_info.rank == 0: + self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") + else: + self._logger.debug(f"local pg world : {self.local_pg.size()}") + + def __del__(self): + dist.destroy_process_group() + + def _init_global_pg(self) -> dist.Store: + global_addr = os.environ["GLOBAL_ADDR"] + global_port = int(os.environ["GLOBAL_PORT"]) + global_world_size = int(os.environ["GLOBAL_WORLD_SIZE"]) + global_rank = int(os.environ["GLOBAL_RANK"]) + + store = dist.TCPStore( + host_name=global_addr, + port=global_port, + timeout=TCPSTORE_TIMEOUT, + is_master=(global_rank == 0), ) - self.global_pg = self.device_mesh_cpu.get_group("global") - self.local_pg = self.device_mesh.get_group("local") + # Initialize store + if global_rank == 0: + store.set("mesh_count", "0") + store.set("joiner_0", "null") + store.set("leaver_0", "null") + store.set("status", "init") + status = "init" + else: + status = _wait_for_status(store) + + if status == "init": + # First time initialization + self.mesh_count = 0 + self.prefix_store = dist.PrefixStore("mesh_0", store) + pg = dist.ProcessGroupGloo(self.prefix_store, global_rank, global_world_size, TCPSTORE_TIMEOUT) + if global_rank == 0: + store.set("status", "running") + store.set(f"rank_{self.unique_id}", str(global_rank)) + elif status == "running": + # Node wants to join + _queue_join(store, self.unique_id) + _wait_for_status(store, "reinit") + # Get assigned rank + global_rank = int(store.get(f"rank_{self.unique_id}").decode("utf-8")) + # Get updated world_size + global_world_size = int(store.get("world_size").decode("utf-8")) + 
self.mesh_count = int(store.get("mesh_count").decode("utf-8")) + self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", store) + pg = dist.ProcessGroupGloo(self.prefix_store, global_rank, global_world_size, TCPSTORE_TIMEOUT) + else: + # TODO: Could be in "reinit" status + raise RuntimeError(f"Unknown status {status}") + + # Setting instance variables + self.global_store = store + self.global_rank = global_rank + self.leaving = False + return pg + + def _init_unique_id(self): + """Initialize a unique ID for the node. + If TORCH_UNIQUE_ID is set, use that. + Otherwise, local rank 0 generates an ID and broadcasts to other nodes. + """ + if "TORCH_UNIQUE_ID" in os.environ: + self.unique_id = os.environ["TORCH_UNIQUE_ID"] + return + if self.local_rank == 0: + self.unique_id = str(uuid.uuid4()) + with open('/tmp/torch_unique_id', 'w') as f: + f.write(self.unique_id) + else: + while True: + try: + with open('/tmp/torch_unique_id', 'r') as f: + self.unique_id = f.read() + break + except FileNotFoundError: + time.sleep(0.1) - self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") - def __del__(self): + def _resolve_world(self): + """Set the new world size and ranks for all nodes.""" + # Find joiners and leavers + joiners, leavers = _get_joiners_and_leavers(self.global_store) + # If no joiners or leavers, no resolution needed + if len(joiners) == 0 and len(leavers) == 0: + return + + # Remap live ranks to smaller world_size caused by leavers + leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers} + live_ranks = [i for i in range(0, self.world_size, self.local_world_size) if i not in leaving_ranks] + for i, rank in enumerate(live_ranks): + self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size)) + new_world_size = len(live_ranks) * self.local_world_size + + # Give joiners new ranks + for joiner_id in joiners: + self.global_store.set(f"rank_{joiner_id}", str(new_world_size)) + new_world_size += self.local_world_size + + # Update world_size + self.global_store.set("world_size", str(new_world_size)) + self.global_store.set("mesh_count", str(self.mesh_count + 1)) + # Set status to "reinit" + self.global_store.set("status", "reinit") + + def maybe_reinit_device_mesh(self): + """Reinitialize the device mesh if there are joiners or leavers.""" + if self.rank == 0: + self._resolve_world() + dist.barrier() + status = self.global_store.get("status").decode("utf-8") + if status == "running": + return + + print("Reinitializing device mesh") dist.destroy_process_group() + print("Destroyed process group") + if self.leaving: + print("Leaving") + return + + # Check if we got remapped + prev_uuid_rank = int(self.global_store.get(f"rank_{self.unique_id}").decode("utf-8")) + new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8")) + self.rank = new_uuid_rank + self.local_rank + + self.world_size = int(self.global_store.get("world_size").decode("utf-8")) + self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8")) + self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) + dist.init_process_group(backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size) + + if self.rank == 0: + _clear_joiners_and_leavers(self.global_store) + self.global_store.set("status", "running") + # Update rank if needed (otherwise, the next remap will do the lookup incorrectly) + if self.local_rank == 0 and 
new_uuid_rank != prev_uuid_rank: + self.global_store.set(f"rank_{self.unique_id}", str(new_uuid_rank)) + # Reinitialize sub process groups + self.world_rank = self.rank // self.local_world_size diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 9df29f91..83fdd36c 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -69,14 +69,17 @@ def sync_pseudo_gradient(self, model: nn.Module): Sync the pseudo gradient from the local process group to the global process group """ self._logger.debug("sync pseudo gradient") - for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices - param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) + if self.elastic_device_mesh.global_pg is not None: + for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): + # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices + param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) # gloo does not support AVG param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() - dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg) - # todo maybe do async here + dist.all_reduce( + param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True + ) + # todo async here def sync_inner_model(self, model: nn.Module): """ From 6da6f10d4571be6cb1562c6b07bdac9fbc94ca81 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Thu, 26 Sep 2024 19:22:15 +0000 Subject: [PATCH 03/32] use fake pg instead of None --- src/zeroband/comms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index d6679906..abb57d50 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -7,6 +7,7 @@ import time from typing import List, Tuple, Optional import uuid +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup TCPSTORE_TIMEOUT = timedelta(seconds=10) @@ -94,7 +95,7 @@ def __init__(self): if self.world_info.rank == 0: self.global_pg = self._init_global_pg() else: - self.global_pg = None + self.global_pg = FakeProcessGroup(self.world_info.rank, self.world_info.world_size) # Initialize local process group dist.init_process_group(backend="cpu:gloo,cuda:nccl") From 40a3e2a8a204da8e2cc88558a62da1f18831a396 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Thu, 26 Sep 2024 20:29:56 +0000 Subject: [PATCH 04/32] testing utils --- src/zeroband/testing.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/zeroband/testing.py diff --git a/src/zeroband/testing.py b/src/zeroband/testing.py new file mode 100644 index 00000000..0e9f5285 --- /dev/null +++ b/src/zeroband/testing.py @@ -0,0 +1,31 @@ +import torch +import hashlib + +TENSOR_SIG_SAMPLE_SIZE = 100 + +def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: + """ + Get the tensor signature + """ + while isinstance(a, torch.nn.Parameter): + a = a.data + if a.numel() < TENSOR_SIG_SAMPLE_SIZE: + b = a.as_strided(size=(a.numel(),), stride=(1,)) + else: + step_size = a.numel() // TENSOR_SIG_SAMPLE_SIZE + b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,)) + element_str = ''.join([f'{x:.3e}' for x in b]) + element_hash = hashlib.md5(element_str.encode("utf-8")).hexdigest() + 
return f"{a.dtype}{a.shape}{a.stride()}<{element_hash}>" + +def get_module_signature(module: torch.nn.Module, compress: bool=True) -> str: + """ + Get the module signature + """ + state_dict_sig = {name: get_tensor_signature(param) for name, param in module.named_parameters()} + if compress: + return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest() + else: + return '\n'.join(f"{name}: {sig}" for name, sig in state_dict_sig.items()) + + \ No newline at end of file From 22db84a783420092693c78cca024086d778ce813 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Thu, 26 Sep 2024 23:17:38 +0000 Subject: [PATCH 05/32] syncing correctly but ugly --- src/zeroband/diloco.py | 20 +++++++++++++------- src/zeroband/train.py | 27 ++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 83fdd36c..55dd553a 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -69,17 +69,20 @@ def sync_pseudo_gradient(self, model: nn.Module): Sync the pseudo gradient from the local process group to the global process group """ self._logger.debug("sync pseudo gradient") + works = [] if self.elastic_device_mesh.global_pg is not None: for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) - # gloo does not support AVG - param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() - dist.all_reduce( - param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True - ) - # todo async here + # gloo does not support AVG + param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() + work = dist.all_reduce( + param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True + ) + works.append(work) + for work in works: + work.wait() def sync_inner_model(self, model: nn.Module): """ @@ -88,7 +91,10 @@ def sync_inner_model(self, model: nn.Module): self._logger.debug("sync inner model") for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - param.data = param_offloaded.data.to("cuda") # todo: use copy_ here + param.data.copy_(param_offloaded.data) + + for param in model.parameters(): + dist.broadcast(param.data, src=0, group=self.elastic_device_mesh.local_pg) def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: """ diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 3cc60c68..4a616277 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -27,6 +27,8 @@ from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger +from zeroband.testing import get_module_signature + class DataConfig(BaseConfig): seq_length: int = 1024 @@ -110,8 +112,21 @@ def train(config: Config): config.data.seq_length, ) + from torch.distributed.distributed_c10d import BroadcastOptions elastic_device_mesh = ElasticDeviceMesh() - + if world_info.rank == 0: + for param in model.parameters(): + # TODO: Kinda ugly but somethings wrong with the world registration + #dist.broadcast(param.data, src=0, group=elastic_device_mesh.global_pg) + opts = BroadcastOptions() + opts.rootRank = 0 + opts.rootTensor = 0 + elastic_device_mesh.global_pg.broadcast([param.data], opts) + + for param in model.parameters(): + 
dist.broadcast(param.data, src=0, group=elastic_device_mesh.local_pg) + + print(f"[Rank {world_info.rank}] {get_module_signature(model)}") model = FSDP( model, sharding_strategy=sharding_strategy, @@ -128,7 +143,8 @@ def train(config: Config): if world_info.local_world_size == 1: raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug") - diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) + with FSDP.summon_full_params(model): + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) # Setup optimizers inner_optimizer = torch.optim.AdamW( @@ -163,6 +179,8 @@ def train(config: Config): logger.info(f"outer_step step: {outer_step}") for inner_step in range(num_inner_steps): + with FSDP.summon_full_params(model): + print(f"[Rank {world_info.rank}] {get_module_signature(model)}") loss_batch = 0 for grad_acc_step in range(gradient_accumulation_steps): @@ -225,7 +243,10 @@ def train(config: Config): logger.info(log) if config.diloco is not None: - diloco.step(model) + with FSDP.summon_full_params(model): + print(f"[Rank {world_info.rank}] pre diloco step {get_module_signature(model)}") + diloco.step(model) + print(f"[Rank {world_info.rank}] post diloco step {get_module_signature(model)}") outer_step += 1 From 84b41f7662111d32571cbe090d52e73cd03a67bb Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 02:09:00 +0000 Subject: [PATCH 06/32] make cpu offload use mmaped file --- src/zeroband/diloco.py | 57 ++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 55dd553a..5ac0b6a0 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -1,3 +1,5 @@ +import os +import shutil from pydantic_config import BaseConfig import torch from torch import nn @@ -7,6 +9,7 @@ from torch.distributed.fsdp import ShardingStrategy import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup class DilocoConfig(BaseConfig): outer_lr: float = 0.7 @@ -57,7 +60,7 @@ def __init__( self._init_offloaded_optimizer(model=model) - def _init_offloaded_optimizer(self, model): + def _init_offloaded_optimizer(self, model: nn.Module): self.param_list_cpu = self.get_offloaded_param(model) self.outer_optimizer = torch.optim.SGD( self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True @@ -70,17 +73,17 @@ def sync_pseudo_gradient(self, model: nn.Module): """ self._logger.debug("sync pseudo gradient") works = [] - if self.elastic_device_mesh.global_pg is not None: - for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices - param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) - - # gloo does not support AVG - param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() - work = dist.all_reduce( - param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True - ) - works.append(work) + # TODO: This assumes all params require grad, which is used by the offload + for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): + # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices + param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) + + # gloo does 
not support AVG + param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() + work = dist.all_reduce( + param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True + ) + works.append(work) for work in works: work.wait() @@ -92,31 +95,41 @@ def sync_inner_model(self, model: nn.Module): self._logger.debug("sync inner model") for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): param.data.copy_(param_offloaded.data) - - for param in model.parameters(): - dist.broadcast(param.data, src=0, group=self.elastic_device_mesh.local_pg) def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: """ Offload the model parameters to cpu """ + unique_id = self.elastic_device_mesh.unique_id offloaded_params = [] + os.makedirs("/dev/shm/zeroband", exist_ok=True) - for param in model.parameters(): + for param_name, param in model.named_parameters(): if param.requires_grad: - offloaded_param = param.data.detach().clone().to("cpu") - offloaded_param.requires_grad = True + storage = torch.UntypedStorage.from_file(f"/dev/shm/zeroband/{unique_id}-{param_name}", True, param.data.untyped_storage().size()) + offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu") + offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride()) + if self.world_info.rank == 0: + # TODO: Can we async or split the copy among gpus probs overkill? + offloaded_param.copy_(param.data) + offloaded_param.requires_grad = False # TODO: check if we need to set this to True offloaded_params.append(offloaded_param) + dist.barrier() return offloaded_params def step(self, model: nn.Module): """ Step the optimizer """ - self.sync_pseudo_gradient(model) - if self.outer_optimizer is not None: - self.outer_optimizer.step() - self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + if self.world_info.rank == 0: + self.sync_pseudo_gradient(model) + if self.outer_optimizer is not None: + self.outer_optimizer.step() + self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + dist.barrier() self.sync_inner_model(model) + + def __del__(self): + shutil.rmtree("/dev/shm/zeroband", ignore_errors=True) From 9e79c043c6e9c7758960011bb1438c16acf61158 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 02:57:54 +0000 Subject: [PATCH 07/32] fix: allow none diloco to work with fake pg --- src/zeroband/comms.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index abb57d50..5f82d423 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -91,11 +91,15 @@ def __init__(self): self.world_info = get_world_info() # Initialize global process group - self._init_unique_id() - if self.world_info.rank == 0: - self.global_pg = self._init_global_pg() - else: - self.global_pg = FakeProcessGroup(self.world_info.rank, self.world_info.world_size) + self.global_pg = FakeProcessGroup(self.world_info.rank, 1) + if "GLOBAL_RANK" in os.environ: + self._init_unique_id() + if self.world_info.rank == 0: + self.global_pg = self._init_global_pg() + #from torch.distributed.distributed_c10d import _world + #global_rank = int(os.environ["GLOBAL_RANK"]) + #_world.pg_group_ranks[self.global_pg] = {i: global_rank for i in range(self.world_info.world_size)} + #_world.pg_map[self.global_pg] = "gloo", self.global_store # Initialize local process group dist.init_process_group(backend="cpu:gloo,cuda:nccl") From 
fa5698050b5d4a52b9e210b4306d6a025a80afce Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 02:58:23 +0000 Subject: [PATCH 08/32] simulate multi node diloco script --- scripts/simulate_multi_node_diloco.sh | 69 +++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100755 scripts/simulate_multi_node_diloco.sh diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh new file mode 100755 index 00000000..858c1805 --- /dev/null +++ b/scripts/simulate_multi_node_diloco.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# +# simulate multi nodes on one gpu. start N torchrun on X gpu locally. +# example how to run ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/normal.toml + +# Function to get CUDA devices based on the number of GPUs and index +function get_cuda_devices() { + local num_gpu=$1 + local index=$2 + local start_gpu=$((num_gpu * index)) + local end_gpu=$((start_gpu + num_gpu - 1)) + + if [ "$num_gpu" -eq 1 ]; then + echo $start_gpu + else + echo $(seq -s ',' $start_gpu $end_gpu) + fi +} + +# Array to store PIDs of child processes +child_pids=() + +# Function to kill all child processes +cleanup() { + echo "Cleaning up child processes..." + local killed=0 + for pid in "${child_pids[@]}"; do + if kill -TERM "$pid" 2>/dev/null; then + ((killed++)) + fi + done + wait + echo "All child processes terminated. Killed $killed processes." + exit +} + +# Check if at least three arguments were passed +if [ "$#" -lt 3 ]; then + echo "Usage: $0 [additional_python_args]" + exit 1 +fi + + +N=$1 # Set N from the first argument +NUM_GPU=$2 +shift 2 # Remove the first three arguments so $@ contains only additional Python arguments + +# Register the cleanup function to be called on SIGINT (Ctrl+C) +trap cleanup SIGINT + + +mkdir -p logs + +export GLOBAL_ADDR=localhost +export GLOBAL_PORT=10000 +export GLOBAL_WORLD_SIZE=$N + +for i in $(seq 0 $(($N - 1 ))) +do + > logs/log$i + TORCH_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 & + child_pids+=($!) +done + +tail -f logs/log0 & +child_pids+=($!) + +wait From 3077262a9c73209fa5f59d65fca842dc272799c9 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 02:59:04 +0000 Subject: [PATCH 09/32] docs: update docs --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 25861da1..134dd1a2 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ZeroBand is a production ready codebase for decentralized training of LLM -## developlment +## Developlment install uv @@ -40,22 +40,22 @@ run your code using uv run ... 
```

-## quick check
+## Quick check

To check that everything is working you can do

```bash
-ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
+ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
```

-## run diloco
+## Run diloco

To run diloco locally you can use the helper script `scripts/simulate_multi_node_diloco.sh`

:note: you need 4 gpus to run the following command

```bash
-ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
+ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
```

if you have only two gpus

From e8332c4d42f8a846e687cf696e49f2ca18a10769 Mon Sep 17 00:00:00 2001
From: Jackmin801
Date: Fri, 27 Sep 2024 05:13:49 +0000
Subject: [PATCH 10/32] remove prints

---
 src/zeroband/train.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 4a616277..464ab235 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -113,11 +113,12 @@ def train(config: Config):
    )

    from torch.distributed.distributed_c10d import BroadcastOptions
+
    elastic_device_mesh = ElasticDeviceMesh()
    if world_info.rank == 0:
        for param in model.parameters():
            # TODO: Kinda ugly but somethings wrong with the world registration
-            #dist.broadcast(param.data, src=0, group=elastic_device_mesh.global_pg)
+            # dist.broadcast(param.data, src=0, group=elastic_device_mesh.global_pg)
            opts = BroadcastOptions()
            opts.rootRank = 0
            opts.rootTensor = 0
            elastic_device_mesh.global_pg.broadcast([param.data], opts)
@@ -126,7 +127,6 @@ def train(config: Config):
    for param in model.parameters():
        dist.broadcast(param.data, src=0, group=elastic_device_mesh.local_pg)
-    print(f"[Rank {world_info.rank}] {get_module_signature(model)}")
    model = FSDP(
        model,
        sharding_strategy=sharding_strategy,
@@ -179,8 +179,6 @@ def train(config: Config):
        logger.info(f"outer_step step: {outer_step}")

        for inner_step in range(num_inner_steps):
-            with FSDP.summon_full_params(model):
-                print(f"[Rank {world_info.rank}] {get_module_signature(model)}")
            loss_batch = 0

            for grad_acc_step in range(gradient_accumulation_steps):
@@ -244,9 +244,9 @@

        if config.diloco is not None:
            with FSDP.summon_full_params(model):
-                print(f"[Rank {world_info.rank}] pre diloco step {get_module_signature(model)}")
+                logger.debug("Pre diloco step %s", get_module_signature(model))
                diloco.step(model)
-                print(f"[Rank {world_info.rank}] post diloco step {get_module_signature(model)}")
+                logger.debug("Post diloco step %s", get_module_signature(model))

        outer_step += 1

From 4938bb4ad40d3f90103d16cb37c44b960b757cd5 Mon Sep 17 00:00:00 2001
From: Jackmin801
Date: Fri, 27 Sep 2024 05:16:18 +0000
Subject: [PATCH 11/32] ruff lint

---
 src/zeroband/comms.py   | 49 ++++++++++++++++++++++++-----------------
 src/zeroband/diloco.py  | 10 +++++----
 src/zeroband/testing.py | 10 ++++-----
 3 files changed, 40 insertions(+), 29 deletions(-)

diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py
index 5f82d423..a9a570f7 100644
--- a/src/zeroband/comms.py
+++ b/src/zeroband/comms.py
@@ -11,8 +11,9 @@
TCPSTORE_TIMEOUT = timedelta(seconds=10)
-MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit
-MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a
single reinit + def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str: while True: @@ -26,6 +27,7 @@ def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str: raise e time.sleep(0.1) + def _queue_join(store: dist.Store, unique_id: str): for i in range(MAX_JOINERS): joiner_id = store.get(f"joiner_{i}").decode("utf-8") @@ -36,6 +38,7 @@ def _queue_join(store: dist.Store, unique_id: str): else: raise RuntimeError("Too many joiners") + def _queue_leave(store: dist.Store, unique_id: str): for i in range(MAX_LEAVERS): leaver_id = store.get(f"leaver_{i}").decode("utf-8") @@ -46,6 +49,7 @@ def _queue_leave(store: dist.Store, unique_id: str): else: raise RuntimeError("Too many leavers") + def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]: joiners = [] leavers = [] @@ -62,17 +66,19 @@ def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]: print(f"Joiners: {joiners}, Leavers: {leavers}") return joiners, leavers + def _clear_joiners_and_leavers(store: dist.Store): store.set("joiner_0", "null") store.set("leaver_0", "null") + class ElasticDeviceMesh: """A class to manage the process groups for elastic training without restarts. - + The way it works is rank 0 coordinates the joining and leaving of nodes. Rank 0 manages the status to coordinate the creation and recreation of the process groups. When a node wants to join, rank 0 will setup the store so that all nodes know the new world size and their respective ranks. - + Store keys used: - status: "init", "running", "reinit" - world_size: The current world size @@ -96,15 +102,17 @@ def __init__(self): self._init_unique_id() if self.world_info.rank == 0: self.global_pg = self._init_global_pg() - #from torch.distributed.distributed_c10d import _world - #global_rank = int(os.environ["GLOBAL_RANK"]) - #_world.pg_group_ranks[self.global_pg] = {i: global_rank for i in range(self.world_info.world_size)} - #_world.pg_map[self.global_pg] = "gloo", self.global_store + # from torch.distributed.distributed_c10d import _world + # global_rank = int(os.environ["GLOBAL_RANK"]) + # _world.pg_group_ranks[self.global_pg] = {i: global_rank for i in range(self.world_info.world_size)} + # _world.pg_map[self.global_pg] = "gloo", self.global_store # Initialize local process group dist.init_process_group(backend="cpu:gloo,cuda:nccl") self._device_mesh = init_device_mesh( - "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("internode", "intranode") + "cuda", + (self.world_info.nnodes, self.world_info.local_world_size), + mesh_dim_names=("internode", "intranode"), ) self.local_pg = self._device_mesh.get_group("intranode") @@ -112,10 +120,10 @@ def __init__(self): self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") else: self._logger.debug(f"local pg world : {self.local_pg.size()}") - + def __del__(self): dist.destroy_process_group() - + def _init_global_pg(self) -> dist.Store: global_addr = os.environ["GLOBAL_ADDR"] global_port = int(os.environ["GLOBAL_PORT"]) @@ -138,7 +146,7 @@ def _init_global_pg(self) -> dist.Store: status = "init" else: status = _wait_for_status(store) - + if status == "init": # First time initialization self.mesh_count = 0 @@ -178,18 +186,17 @@ def _init_unique_id(self): return if self.local_rank == 0: self.unique_id = str(uuid.uuid4()) - with open('/tmp/torch_unique_id', 'w') as f: + with open("/tmp/torch_unique_id", "w") as f: f.write(self.unique_id) else: while True: try: - with 
open('/tmp/torch_unique_id', 'r') as f: + with open("/tmp/torch_unique_id", "r") as f: self.unique_id = f.read() break except FileNotFoundError: time.sleep(0.1) - def _resolve_world(self): """Set the new world size and ranks for all nodes.""" # Find joiners and leavers @@ -204,12 +211,12 @@ def _resolve_world(self): for i, rank in enumerate(live_ranks): self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size)) new_world_size = len(live_ranks) * self.local_world_size - + # Give joiners new ranks for joiner_id in joiners: self.global_store.set(f"rank_{joiner_id}", str(new_world_size)) new_world_size += self.local_world_size - + # Update world_size self.global_store.set("world_size", str(new_world_size)) self.global_store.set("mesh_count", str(self.mesh_count + 1)) @@ -224,7 +231,7 @@ def maybe_reinit_device_mesh(self): status = self.global_store.get("status").decode("utf-8") if status == "running": return - + print("Reinitializing device mesh") dist.destroy_process_group() print("Destroyed process group") @@ -240,8 +247,10 @@ def maybe_reinit_device_mesh(self): self.world_size = int(self.global_store.get("world_size").decode("utf-8")) self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8")) self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) - dist.init_process_group(backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size) - + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size + ) + if self.rank == 0: _clear_joiners_and_leavers(self.global_store) self.global_store.set("status", "running") diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 5ac0b6a0..374f734a 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -9,12 +9,12 @@ from torch.distributed.fsdp import ShardingStrategy import torch.distributed as dist -from torch.testing._internal.distributed.fake_pg import FakeProcessGroup class DilocoConfig(BaseConfig): outer_lr: float = 0.7 inner_steps: int + class Diloco: """ This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852. @@ -106,13 +106,15 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: for param_name, param in model.named_parameters(): if param.requires_grad: - storage = torch.UntypedStorage.from_file(f"/dev/shm/zeroband/{unique_id}-{param_name}", True, param.data.untyped_storage().size()) + storage = torch.UntypedStorage.from_file( + f"/dev/shm/zeroband/{unique_id}-{param_name}", True, param.data.untyped_storage().size() + ) offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu") offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride()) if self.world_info.rank == 0: # TODO: Can we async or split the copy among gpus probs overkill? 
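                    # UntypedStorage.from_file(..., shared=True) memory-maps the /dev/shm file, so every
                    # local rank that opens the same path sees one shared CPU copy of the parameter;
                    # rank 0 fills it here and the dist.barrier() after the loop keeps the other ranks
                    # from reading it before the copy has landed.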
offloaded_param.copy_(param.data) - offloaded_param.requires_grad = False # TODO: check if we need to set this to True + offloaded_param.requires_grad = False # TODO: check if we need to set this to True offloaded_params.append(offloaded_param) dist.barrier() @@ -130,6 +132,6 @@ def step(self, model: nn.Module): dist.barrier() self.sync_inner_model(model) - + def __del__(self): shutil.rmtree("/dev/shm/zeroband", ignore_errors=True) diff --git a/src/zeroband/testing.py b/src/zeroband/testing.py index 0e9f5285..684072e3 100644 --- a/src/zeroband/testing.py +++ b/src/zeroband/testing.py @@ -3,6 +3,7 @@ TENSOR_SIG_SAMPLE_SIZE = 100 + def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: """ Get the tensor signature @@ -14,11 +15,12 @@ def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: else: step_size = a.numel() // TENSOR_SIG_SAMPLE_SIZE b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,)) - element_str = ''.join([f'{x:.3e}' for x in b]) + element_str = "".join([f"{x:.3e}" for x in b]) element_hash = hashlib.md5(element_str.encode("utf-8")).hexdigest() return f"{a.dtype}{a.shape}{a.stride()}<{element_hash}>" -def get_module_signature(module: torch.nn.Module, compress: bool=True) -> str: + +def get_module_signature(module: torch.nn.Module, compress: bool = True) -> str: """ Get the module signature """ @@ -26,6 +28,4 @@ def get_module_signature(module: torch.nn.Module, compress: bool=True) -> str: if compress: return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest() else: - return '\n'.join(f"{name}: {sig}" for name, sig in state_dict_sig.items()) - - \ No newline at end of file + return "\n".join(f"{name}: {sig}" for name, sig in state_dict_sig.items()) From bcc1d7aab961f83708806d3ecf983869bdea5c6f Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 03:17:51 +0800 Subject: [PATCH 12/32] move global info to world info and fix unique id --- src/zeroband/comms.py | 26 ++++++++++---------------- src/zeroband/diloco.py | 4 ++-- src/zeroband/train.py | 5 +---- src/zeroband/utils/world_info.py | 4 ++++ 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index a9a570f7..b43e2d6f 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -125,20 +125,15 @@ def __del__(self): dist.destroy_process_group() def _init_global_pg(self) -> dist.Store: - global_addr = os.environ["GLOBAL_ADDR"] - global_port = int(os.environ["GLOBAL_PORT"]) - global_world_size = int(os.environ["GLOBAL_WORLD_SIZE"]) - global_rank = int(os.environ["GLOBAL_RANK"]) - store = dist.TCPStore( - host_name=global_addr, - port=global_port, + host_name=self.world_info.global_addr, + port=self.world_info.global_port, timeout=TCPSTORE_TIMEOUT, - is_master=(global_rank == 0), + is_master=(self.world_info.global_rank == 0), ) # Initialize store - if global_rank == 0: + if self.world_info.global_rank == 0: store.set("mesh_count", "0") store.set("joiner_0", "null") store.set("leaver_0", "null") @@ -151,28 +146,27 @@ def _init_global_pg(self) -> dist.Store: # First time initialization self.mesh_count = 0 self.prefix_store = dist.PrefixStore("mesh_0", store) - pg = dist.ProcessGroupGloo(self.prefix_store, global_rank, global_world_size, TCPSTORE_TIMEOUT) - if global_rank == 0: + pg = dist.ProcessGroupGloo(self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT) + if self.world_info.global_rank == 0: store.set("status", "running") - 
store.set(f"rank_{self.unique_id}", str(global_rank)) + store.set(f"rank_{self.unique_id}", str(self.world_info.global_rank)) elif status == "running": # Node wants to join _queue_join(store, self.unique_id) _wait_for_status(store, "reinit") # Get assigned rank - global_rank = int(store.get(f"rank_{self.unique_id}").decode("utf-8")) + self.world_info.global_rank = int(store.get(f"rank_{self.unique_id}").decode("utf-8")) # Get updated world_size - global_world_size = int(store.get("world_size").decode("utf-8")) + self.world_info.global_world_size = int(store.get("world_size").decode("utf-8")) self.mesh_count = int(store.get("mesh_count").decode("utf-8")) self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", store) - pg = dist.ProcessGroupGloo(self.prefix_store, global_rank, global_world_size, TCPSTORE_TIMEOUT) + pg = dist.ProcessGroupGloo(self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT) else: # TODO: Could be in "reinit" status raise RuntimeError(f"Unknown status {status}") # Setting instance variables self.global_store = store - self.global_rank = global_rank self.leaving = False return pg diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 739f4d8a..0f74d43c 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -46,11 +46,11 @@ def __init__( config: DilocoConfig, model: nn.Module, fsdp_sharding_strategy: ShardingStrategy, - global_pg: dist.ProcessGroup, + elastic_device_mesh: ElasticDeviceMesh, ): self.config = config self.fsdp_sharding_strategy = fsdp_sharding_strategy - self.global_pg = global_pg + self.elastic_device_mesh = elastic_device_mesh self._logger = get_logger() self.world_info = get_world_info() diff --git a/src/zeroband/train.py b/src/zeroband/train.py index eab02755..ddbf8104 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -154,10 +154,7 @@ def train(config: Config): ) if config.diloco is not None: - if world_info.local_world_size == 1: - raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug") - - diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg) + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) scheduler = get_cosine_schedule_with_warmup( inner_optimizer, diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index efe30a1a..50aa3d48 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -18,6 +18,10 @@ def __init__(self): self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) self.nnodes = self.world_size // self.local_world_size + self.global_addr = os.environ.get("GLOBAL_ADDR", "") + self.global_port = int(os.environ.get("GLOBAL_PORT", -1)) + self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", -1)) + self.global_rank = int(os.environ.get("GLOBAL_RANK", -1)) def get_world_info() -> WorldInfo: """ From 70a2a82645a40a5b0f38c1a96dc63a6164c28fc3 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 20:05:18 +0000 Subject: [PATCH 13/32] fixes from merge --- src/zeroband/comms.py | 8 ++++++-- src/zeroband/train.py | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index b43e2d6f..e4b3409a 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -146,7 +146,9 @@ def _init_global_pg(self) -> dist.Store: # First time initialization self.mesh_count = 0 self.prefix_store = dist.PrefixStore("mesh_0", store) - pg 
= dist.ProcessGroupGloo(self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT) + pg = dist.ProcessGroupGloo( + self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT + ) if self.world_info.global_rank == 0: store.set("status", "running") store.set(f"rank_{self.unique_id}", str(self.world_info.global_rank)) @@ -160,7 +162,9 @@ def _init_global_pg(self) -> dist.Store: self.world_info.global_world_size = int(store.get("world_size").decode("utf-8")) self.mesh_count = int(store.get("mesh_count").decode("utf-8")) self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", store) - pg = dist.ProcessGroupGloo(self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT) + pg = dist.ProcessGroupGloo( + self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT + ) else: # TODO: Could be in "reinit" status raise RuntimeError(f"Unknown status {status}") diff --git a/src/zeroband/train.py b/src/zeroband/train.py index ddbf8104..7d4d064e 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -154,7 +154,8 @@ def train(config: Config): ) if config.diloco is not None: - diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) + with FSDP.summon_full_params(model): + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) scheduler = get_cosine_schedule_with_warmup( inner_optimizer, From 970e1f597078390606b58fb8ad5911e80593e960 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 20:16:02 +0000 Subject: [PATCH 14/32] move unique id to world info --- src/zeroband/comms.py | 40 +++++--------------------------- src/zeroband/diloco.py | 9 +++---- src/zeroband/utils/world_info.py | 2 ++ 3 files changed, 13 insertions(+), 38 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index e4b3409a..44a44058 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -2,11 +2,9 @@ from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger import torch.distributed as dist -import os from datetime import timedelta import time from typing import List, Tuple, Optional -import uuid from torch.testing._internal.distributed.fake_pg import FakeProcessGroup @@ -98,14 +96,9 @@ def __init__(self): # Initialize global process group self.global_pg = FakeProcessGroup(self.world_info.rank, 1) - if "GLOBAL_RANK" in os.environ: - self._init_unique_id() + if self.world_info.global_world_size > 1: if self.world_info.rank == 0: self.global_pg = self._init_global_pg() - # from torch.distributed.distributed_c10d import _world - # global_rank = int(os.environ["GLOBAL_RANK"]) - # _world.pg_group_ranks[self.global_pg] = {i: global_rank for i in range(self.world_info.world_size)} - # _world.pg_map[self.global_pg] = "gloo", self.global_store # Initialize local process group dist.init_process_group(backend="cpu:gloo,cuda:nccl") @@ -151,13 +144,13 @@ def _init_global_pg(self) -> dist.Store: ) if self.world_info.global_rank == 0: store.set("status", "running") - store.set(f"rank_{self.unique_id}", str(self.world_info.global_rank)) + store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank)) elif status == "running": # Node wants to join - _queue_join(store, self.unique_id) + _queue_join(store, self.world_info.global_unique_id) _wait_for_status(store, "reinit") # Get assigned rank - self.world_info.global_rank = 
int(store.get(f"rank_{self.unique_id}").decode("utf-8")) + self.world_info.global_rank = int(store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8")) # Get updated world_size self.world_info.global_world_size = int(store.get("world_size").decode("utf-8")) self.mesh_count = int(store.get("mesh_count").decode("utf-8")) @@ -174,27 +167,6 @@ def _init_global_pg(self) -> dist.Store: self.leaving = False return pg - def _init_unique_id(self): - """Initialize a unique ID for the node. - If TORCH_UNIQUE_ID is set, use that. - Otherwise, local rank 0 generates an ID and broadcasts to other nodes. - """ - if "TORCH_UNIQUE_ID" in os.environ: - self.unique_id = os.environ["TORCH_UNIQUE_ID"] - return - if self.local_rank == 0: - self.unique_id = str(uuid.uuid4()) - with open("/tmp/torch_unique_id", "w") as f: - f.write(self.unique_id) - else: - while True: - try: - with open("/tmp/torch_unique_id", "r") as f: - self.unique_id = f.read() - break - except FileNotFoundError: - time.sleep(0.1) - def _resolve_world(self): """Set the new world size and ranks for all nodes.""" # Find joiners and leavers @@ -238,7 +210,7 @@ def maybe_reinit_device_mesh(self): return # Check if we got remapped - prev_uuid_rank = int(self.global_store.get(f"rank_{self.unique_id}").decode("utf-8")) + prev_uuid_rank = int(self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8")) new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8")) self.rank = new_uuid_rank + self.local_rank @@ -254,6 +226,6 @@ def maybe_reinit_device_mesh(self): self.global_store.set("status", "running") # Update rank if needed (otherwise, the next remap will do the lookup incorrectly) if self.local_rank == 0 and new_uuid_rank != prev_uuid_rank: - self.global_store.set(f"rank_{self.unique_id}", str(new_uuid_rank)) + self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(new_uuid_rank)) # Reinitialize sub process groups self.world_rank = self.rank // self.local_world_size diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 0f74d43c..5798d565 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -99,14 +99,15 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: """ Offload the model parameters to cpu """ - unique_id = self.elastic_device_mesh.unique_id offloaded_params = [] - os.makedirs("/dev/shm/zeroband", exist_ok=True) + os.makedirs(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", exist_ok=True) for param_name, param in model.named_parameters(): if param.requires_grad: storage = torch.UntypedStorage.from_file( - f"/dev/shm/zeroband/{unique_id}-{param_name}", True, param.data.untyped_storage().size() + f"/dev/shm/zeroband/{self.world_info.global_unique_id}/{param_name}", + True, + param.data.untyped_storage().size(), ) offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu") offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride()) @@ -133,4 +134,4 @@ def step(self, model: nn.Module): self.sync_inner_model(model) def __del__(self): - shutil.rmtree("/dev/shm/zeroband", ignore_errors=True) + shutil.rmtree(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", ignore_errors=True) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 50aa3d48..7261a7b6 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -18,11 +18,13 @@ def __init__(self): self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) self.nnodes = 
self.world_size // self.local_world_size + self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", "") self.global_addr = os.environ.get("GLOBAL_ADDR", "") self.global_port = int(os.environ.get("GLOBAL_PORT", -1)) self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", -1)) self.global_rank = int(os.environ.get("GLOBAL_RANK", -1)) + def get_world_info() -> WorldInfo: """ Return a WorldInfo singleton. From 9c22e20e3992ec84e2570497c6ae30b153dcf3a3 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 20:16:16 +0000 Subject: [PATCH 15/32] update command in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 134dd1a2..a53ed1f6 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroba if you have only two gpus ```bash -ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml +ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml ``` One gpu is not supported at the moment because of a fsdp bug in our implementation. From 84b72974dfe6ee7069b32f11b54bc5db98d1eb42 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 20:22:14 +0000 Subject: [PATCH 16/32] remove broadcasts at init --- src/zeroband/train.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 7d4d064e..9821a47f 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -118,20 +118,7 @@ def train(config: Config): config.data.seq_length, ) - from torch.distributed.distributed_c10d import BroadcastOptions - elastic_device_mesh = ElasticDeviceMesh() - if world_info.rank == 0: - for param in model.parameters(): - # TODO: Kinda ugly but somethings wrong with the world registration - # dist.broadcast(param.data, src=0, group=elastic_device_mesh.global_pg) - opts = BroadcastOptions() - opts.rootRank = 0 - opts.rootTensor = 0 - elastic_device_mesh.global_pg.broadcast([param.data], opts) - - for param in model.parameters(): - dist.broadcast(param.data, src=0, group=elastic_device_mesh.local_pg) model = FSDP( model, From 7345420847627974e62a88b18cf541f9abdf18aa Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 20:44:10 +0000 Subject: [PATCH 17/32] move summon full params to diloco class --- src/zeroband/data.py | 3 ++- src/zeroband/diloco.py | 32 +++++++++++++++++++------------- src/zeroband/train.py | 10 ++-------- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 61a1a986..8e1be6c6 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -69,7 +69,7 @@ def get_dataloader( ds = load_dataset("allenai/c4", "en", streaming=True) def tokenize_function(data): - outputs = tokenizer(data["text"], truncation=True, max_length=seq_length) + outputs = tokenizer(data["text"], truncation=True, max_length=seq_length, padding="max_length") return outputs tokenized_datasets = ds.map( @@ -78,6 +78,7 @@ def tokenize_function(data): train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100) + print(train_dataset, flush=True) return DataLoader( train_dataset, diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 5798d565..68c09156 100644 --- 
a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -8,6 +8,8 @@ from zeroband.comms import ElasticDeviceMesh from torch.distributed.fsdp import ShardingStrategy import torch.distributed as dist +from zeroband.testing import get_module_signature +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP class DilocoConfig(BaseConfig): @@ -61,11 +63,12 @@ def __init__( self._init_offloaded_optimizer(model=model) def _init_offloaded_optimizer(self, model: nn.Module): - self.param_list_cpu = self.get_offloaded_param(model) - self.outer_optimizer = torch.optim.SGD( - self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True - ) - self._logger.debug("offload model to cpu") + with FSDP.summon_full_params(model): + self.param_list_cpu = self.get_offloaded_param(model) + self.outer_optimizer = torch.optim.SGD( + self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True + ) + self._logger.debug("offload model to cpu") def sync_pseudo_gradient(self, model: nn.Module): """ @@ -124,14 +127,17 @@ def step(self, model: nn.Module): """ Step the optimizer """ - if self.world_info.rank == 0: - self.sync_pseudo_gradient(model) - if self.outer_optimizer is not None: - self.outer_optimizer.step() - self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this - - dist.barrier() - self.sync_inner_model(model) + with FSDP.summon_full_params(model): + self._logger.debug("Pre diloco step %s", get_module_signature(model)) + if self.world_info.rank == 0: + self.sync_pseudo_gradient(model) + if self.outer_optimizer is not None: + self.outer_optimizer.step() + self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + + dist.barrier() + self.sync_inner_model(model) + self._logger.debug("Post meow diloco step %s", get_module_signature(model)) def __del__(self): shutil.rmtree(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", ignore_errors=True) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 9821a47f..d431fb06 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -27,8 +27,6 @@ from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger -from zeroband.testing import get_module_signature - class DataConfig(BaseConfig): seq_length: int = 1024 @@ -141,8 +139,7 @@ def train(config: Config): ) if config.diloco is not None: - with FSDP.summon_full_params(model): - diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) scheduler = get_cosine_schedule_with_warmup( inner_optimizer, @@ -231,10 +228,7 @@ def train(config: Config): logger.info(log) if config.diloco is not None: - with FSDP.summon_full_params(model): - logger.debug("Pre diloco step %s", get_module_signature(model)) - diloco.step(model) - logger.debug("Post diloco step %s", get_module_signature(model)) + diloco.step(model) outer_step += 1 From a02c724c059ef38f5cb73dcf75a5a334c65dd3a7 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 23:30:12 +0000 Subject: [PATCH 18/32] fix data split --- src/zeroband/train.py | 8 +++++--- src/zeroband/utils/world_info.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index d431fb06..46082593 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -85,8 +85,8 @@ def train(config: Config): train_dataloader = get_dataloader( tokenizer=tokenizer, - world_size=world_info.world_size, - 
rank=world_info.rank, + world_size=world_info.world_size * world_info.global_world_size, + rank=world_info.rank + world_info.global_rank * world_info.global_world_size, seq_length=config.data.seq_length, batch_size=config.train.micro_bs, num_workers=config.data.num_workers, @@ -95,7 +95,9 @@ def train(config: Config): model, model_config = get_model( config.name_model, config.type_model, - vocab_size=tokenizer.vocab_size if config.name_model != "debugmodel" else TEST_VOCAB_SIZE, + vocab_size=tokenizer.vocab_size + if config.name_model != "debugmodel" or not config.data.fake + else TEST_VOCAB_SIZE, ) if config.train.log_model_hash: diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 7261a7b6..2d0803e8 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -21,8 +21,8 @@ def __init__(self): self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", "") self.global_addr = os.environ.get("GLOBAL_ADDR", "") self.global_port = int(os.environ.get("GLOBAL_PORT", -1)) - self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", -1)) - self.global_rank = int(os.environ.get("GLOBAL_RANK", -1)) + self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 1)) + self.global_rank = int(os.environ.get("GLOBAL_RANK", 0)) def get_world_info() -> WorldInfo: From ccddfe47d790396d5e4d58326daa74851691b7ed Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 23:33:17 +0000 Subject: [PATCH 19/32] move testing to utils --- src/zeroband/diloco.py | 2 +- src/zeroband/utils/__init__.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 68c09156..f42397af 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -3,12 +3,12 @@ from pydantic_config import BaseConfig import torch from torch import nn +from zeroband.utils import get_module_signature from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger from zeroband.comms import ElasticDeviceMesh from torch.distributed.fsdp import ShardingStrategy import torch.distributed as dist -from zeroband.testing import get_module_signature from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py index 80abef69..d29d4b1e 100644 --- a/src/zeroband/utils/__init__.py +++ b/src/zeroband/utils/__init__.py @@ -105,3 +105,33 @@ def get_model_hash(model: torch.nn.Module) -> str: # Compute SHA256 hash return hashlib.sha256(param_bytes).hexdigest() + + +TENSOR_SIG_SAMPLE_SIZE = 100 + + +def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: + """ + Get the tensor signature + """ + while isinstance(a, torch.nn.Parameter): + a = a.data + if a.numel() < TENSOR_SIG_SAMPLE_SIZE: + b = a.as_strided(size=(a.numel(),), stride=(1,)) + else: + step_size = a.numel() // TENSOR_SIG_SAMPLE_SIZE + b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,)) + element_str = "".join([f"{x:.3e}" for x in b]) + element_hash = hashlib.md5(element_str.encode("utf-8")).hexdigest() + return f"{a.dtype}{a.shape}{a.stride()}<{element_hash}>" + + +def get_module_signature(module: torch.nn.Module, compress: bool = True) -> str: + """ + Get the module signature + """ + state_dict_sig = {name: get_tensor_signature(param) for name, param in module.named_parameters()} + if compress: + return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest() + else: + return 
"\n".join(f"{name}: {sig}" for name, sig in state_dict_sig.items()) From ed8d6e81b406e971a16144d704819a1cf38b85d5 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 23:36:59 +0000 Subject: [PATCH 20/32] document offloading logic --- src/zeroband/diloco.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index f42397af..1a09bfa2 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -102,6 +102,9 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: """ Offload the model parameters to cpu """ + # The change here makes processes which are part of the same FSDP replica group (which are assumed to be on the same node with the same /dev/shm) use the same underlying storage for the offloaded_param. + # All the processes use the same shared memory file to create a storage for each parameter. Only rank 0 will do the copy. + # A barrier is added to ensure that after the function completes, the parameters are all offloaded. Otherwise, the non 0 ranks might access uninitialized memory. offloaded_params = [] os.makedirs(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", exist_ok=True) From d5b6e2f60015154377b9679f3a0046a10a21c622 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Fri, 27 Sep 2024 23:43:29 +0000 Subject: [PATCH 21/32] add envs to readme --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a53ed1f6..424fc304 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ZeroBand is a production ready codebase for decentralized training of LLM -## Developlment +## Development install uv @@ -71,8 +71,15 @@ One gpu is not supported at the moment because of a fsdp bug in our implementati You need a machine with a least two gpus to run the full test suite. Some test must be run from the root directory. - ```bash uv run pytest ``` +## Environment variables +| Environment Variable | Description | Default Value | +|-----------------------|--------------------------------------------------|---------------| +| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `""` (empty string) | +| `GLOBAL_ADDR` | IP Address of the global store | `""` (empty string) | +| `GLOBAL_PORT` | Port number of the global store. | `-1` | +| `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` | +| `GLOBAL_RANK` | Rank of the process in the global process group. 
| `0` | From a890e0db469de71d3ae487402fb2ab3235e91ece Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 00:30:37 +0000 Subject: [PATCH 22/32] repre for worldinfo --- src/zeroband/utils/world_info.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 2d0803e8..d08d8bad 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -24,6 +24,9 @@ def __init__(self): self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 1)) self.global_rank = int(os.environ.get("GLOBAL_RANK", 0)) + def __repr__(self): + return f"WorldInfo(world_size={self.world_size}, rank={self.rank}, local_rank={self.local_rank}, local_world_size={self.local_world_size}, nnodes={self.nnodes}, global_unique_id={self.global_unique_id}, global_addr={self.global_addr}, global_port={self.global_port}, global_world_size={self.global_world_size}, global_rank={self.global_rank})" + def get_world_info() -> WorldInfo: """ From 63bc0f67e982fcc44f09498428f9f0ee93af9577 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 00:31:16 +0000 Subject: [PATCH 23/32] revert to global pg --- src/zeroband/diloco.py | 17 +++++------------ src/zeroband/train.py | 2 +- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 1a09bfa2..9d9c12fc 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -6,7 +6,6 @@ from zeroband.utils import get_module_signature from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger -from zeroband.comms import ElasticDeviceMesh from torch.distributed.fsdp import ShardingStrategy import torch.distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -48,11 +47,11 @@ def __init__( config: DilocoConfig, model: nn.Module, fsdp_sharding_strategy: ShardingStrategy, - elastic_device_mesh: ElasticDeviceMesh, + global_pg: dist.ProcessGroup, ): self.config = config self.fsdp_sharding_strategy = fsdp_sharding_strategy - self.elastic_device_mesh = elastic_device_mesh + self.global_pg = global_pg self._logger = get_logger() self.world_info = get_world_info() @@ -75,19 +74,13 @@ def sync_pseudo_gradient(self, model: nn.Module): Sync the pseudo gradient from the local process group to the global process group """ self._logger.debug("sync pseudo gradient") - works = [] # TODO: This assumes all params require grad, which is used by the offload for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) # gloo does not support AVG - param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() - work = dist.all_reduce( - param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg, async_op=True - ) - works.append(work) - for work in works: - work.wait() + param_offloaded.grad = param_offloaded.grad / self.global_pg.size() + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg) def sync_inner_model(self, model: nn.Module): """ @@ -117,7 +110,7 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: ) offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu") offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride()) - if self.world_info.rank == 0: + if self.world_info.local_rank == 0: # TODO: Can we async or split the copy among 
gpus probs overkill? offloaded_param.copy_(param.data) offloaded_param.requires_grad = False # TODO: check if we need to set this to True diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 46082593..c13d3c63 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -141,7 +141,7 @@ def train(config: Config): ) if config.diloco is not None: - diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg) scheduler = get_cosine_schedule_with_warmup( inner_optimizer, From fca81652c056d15ca79d39a4361f4c4a489e3a06 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 00:33:16 +0000 Subject: [PATCH 24/32] set unique id in tests --- tests/test_dist/conftest.py | 9 ++++++--- tests/test_dist/test_diloco.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_dist/conftest.py b/tests/test_dist/conftest.py index 2f01829e..dd9df4d6 100644 --- a/tests/test_dist/conftest.py +++ b/tests/test_dist/conftest.py @@ -45,14 +45,17 @@ def random_available_port(): @pytest.fixture() def dist_environment() -> callable: @contextmanager - def dist_environment(random_available_port, local_rank=0, world_size=1, local_world_size=1): + def dist_environment( + random_available_port, rank=0, local_rank=0, world_size=1, local_world_size=1, global_unique_id="" + ): with mock.patch.dict( os.environ, { - "LOCAL_RANK": str(local_rank), + "GLOBAL_UNIQUE_ID": global_unique_id, + "RANK": str(rank), "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(local_rank), "LOCAL_WORLD_SIZE": str(local_world_size), - "RANK": str(local_rank), "MASTER_ADDR": "localhost", "MASTER_PORT": str(random_available_port), "ZERO_BAND_LOG_LEVEL": "DEBUG", diff --git a/tests/test_dist/test_diloco.py b/tests/test_dist/test_diloco.py index 43fbdcc9..c9a90e94 100644 --- a/tests/test_dist/test_diloco.py +++ b/tests/test_dist/test_diloco.py @@ -22,7 +22,7 @@ def test_diloco_all_reduce(world_size, random_available_port, dist_environment): """ def all_reduce(rank: int, world_size: int): - with dist_environment(random_available_port, local_rank=rank, world_size=world_size): + with dist_environment(random_available_port, rank=rank, world_size=world_size, global_unique_id=str(rank)): diloco_config = DilocoConfig(inner_steps=10) model = torch.nn.Linear(10, 10) From ceb96fcaa114cccfce1f44ac3db63ba5bbd7c3fe Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 00:37:30 +0000 Subject: [PATCH 25/32] fix: nccl cannot all reduce same device --- tests/test_dist/test_all_reduce.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dist/test_all_reduce.py b/tests/test_dist/test_all_reduce.py index 28133070..2ee020f4 100644 --- a/tests/test_dist/test_all_reduce.py +++ b/tests/test_dist/test_all_reduce.py @@ -16,8 +16,8 @@ @pytest.mark.parametrize("world_size", [2]) def test_all_reduce(world_size, random_available_port, dist_environment): def all_reduce(rank: int, world_size: int): - with dist_environment(random_available_port, local_rank=rank, world_size=world_size): - data = (rank + 1) * torch.ones(10, 10).to("cuda") + with dist_environment(random_available_port, rank=rank, world_size=world_size): + data = (rank + 1) * torch.ones(10, 10).to(f"cuda:{rank}") dist.all_reduce(data, op=dist.ReduceOp.SUM) assert data.mean() == sum([i + 1 for i in range(world_size)]) From 4b454716c4f25308abb0936cae967b5e60ace1a2 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sun, 29 Sep 2024 
03:34:09 +0800 Subject: [PATCH 26/32] use get module signature instead of model hash --- src/zeroband/train.py | 4 ++-- src/zeroband/utils/__init__.py | 14 -------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index c13d3c63..46b6d8a0 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -20,7 +20,7 @@ from zeroband.diloco import Diloco, DilocoConfig from zeroband.comms import ElasticDeviceMesh -from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy +from zeroband.utils import PerfCounter, get_module_signature, get_sharding_strategy from zeroband.utils.monitor import WandbMonitor, DummyMonitor from zeroband.data import TEST_VOCAB_SIZE, get_dataloader from zeroband.models.llama import get_model @@ -102,7 +102,7 @@ def train(config: Config): if config.train.log_model_hash: # Compute SHA256 hash - logger.info(f"Model hash: {get_model_hash(model)}") + logger.info(f"Model hash: {get_module_signature(model)}") model = model.to(world_info.local_rank) logger.debug("model loaded") diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py index d29d4b1e..14486b23 100644 --- a/src/zeroband/utils/__init__.py +++ b/src/zeroband/utils/__init__.py @@ -93,20 +93,6 @@ def get_tokens_per_second(self) -> float | None: return sum(self.tokens) / (self.times[-1] - self.times[0]) -def get_model_hash(model: torch.nn.Module) -> str: - """ - Get the hash of the model parameters. - """ - # Concatenate all model parameters into a single tensor - all_params = torch.cat([p.data.view(-1) for p in model.parameters()]) - - # Convert the tensor to a byte string - param_bytes = all_params.cpu().numpy().tobytes() - - # Compute SHA256 hash - return hashlib.sha256(param_bytes).hexdigest() - - TENSOR_SIG_SAMPLE_SIZE = 100 From fa8d3dddd56d0dedda396c3b24871fd8e95745eb Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sun, 29 Sep 2024 03:37:55 +0800 Subject: [PATCH 27/32] change default global unique id to none --- src/zeroband/utils/world_info.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index d08d8bad..9b73f328 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -18,9 +18,9 @@ def __init__(self): self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) self.nnodes = self.world_size // self.local_world_size - self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", "") - self.global_addr = os.environ.get("GLOBAL_ADDR", "") - self.global_port = int(os.environ.get("GLOBAL_PORT", -1)) + self.global_unique_id = os.environ.get("GLOBAL_UNIQUE_ID", None) + self.global_addr = os.environ.get("GLOBAL_ADDR", None) + self.global_port = int(os.environ.get("GLOBAL_PORT")) if "GLOBAL_PORT" in os.environ else None self.global_world_size = int(os.environ.get("GLOBAL_WORLD_SIZE", 1)) self.global_rank = int(os.environ.get("GLOBAL_RANK", 0)) From 73800d98223c73270e9a6a6d68373e17632d61de Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sun, 29 Sep 2024 03:38:42 +0800 Subject: [PATCH 28/32] revert data changes --- src/zeroband/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zeroband/data.py b/src/zeroband/data.py index 8e1be6c6..61a1a986 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -69,7 +69,7 @@ def get_dataloader( ds = load_dataset("allenai/c4", "en", streaming=True) def tokenize_function(data): - outputs = tokenizer(data["text"], 
truncation=True, max_length=seq_length, padding="max_length") + outputs = tokenizer(data["text"], truncation=True, max_length=seq_length) return outputs tokenized_datasets = ds.map( @@ -78,7 +78,6 @@ def tokenize_function(data): train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100) - print(train_dataset, flush=True) return DataLoader( train_dataset, From e64eb2d754c1ffd4c00691ce7c112bf24b163dd0 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sun, 29 Sep 2024 04:06:36 +0800 Subject: [PATCH 29/32] make /dev/shm/zeroband a constant and some fixes --- README.md | 6 +++--- scripts/simulate_multi_node_diloco.sh | 2 +- src/zeroband/diloco.py | 11 +++++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 424fc304..42116a54 100644 --- a/README.md +++ b/README.md @@ -78,8 +78,8 @@ uv run pytest ## Environment variables | Environment Variable | Description | Default Value | |-----------------------|--------------------------------------------------|---------------| -| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `""` (empty string) | -| `GLOBAL_ADDR` | IP Address of the global store | `""` (empty string) | -| `GLOBAL_PORT` | Port number of the global store. | `-1` | +| `GLOBAL_UNIQUE_ID` | Unique identifier worker in global store. | `None` | +| `GLOBAL_ADDR` | IP Address of the global store | `None` | +| `GLOBAL_PORT` | Port number of the global store. | `None` | | `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` | | `GLOBAL_RANK` | Rank of the process in the global process group. | `0` | diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh index 858c1805..cbbd8737 100755 --- a/scripts/simulate_multi_node_diloco.sh +++ b/scripts/simulate_multi_node_diloco.sh @@ -59,7 +59,7 @@ export GLOBAL_WORLD_SIZE=$N for i in $(seq 0 $(($N - 1 ))) do > logs/log$i - TORCH_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 & + GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 & child_pids+=($!) done diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 9d9c12fc..73fc23dd 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -1,5 +1,5 @@ -import os import shutil +import os from pydantic_config import BaseConfig import torch from torch import nn @@ -16,6 +16,9 @@ class DilocoConfig(BaseConfig): inner_steps: int +SHARED_MEMORY_PATH = "/dev/shm/zeroband" + + class Diloco: """ This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852. @@ -99,12 +102,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: # All the processes use the same shared memory file to create a storage for each parameter. Only rank 0 will do the copy. # A barrier is added to ensure that after the function completes, the parameters are all offloaded. Otherwise, the non 0 ranks might access uninitialized memory. 
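    # --- illustrative aside, not part of the patch --------------------------------
    # A minimal, self-contained sketch of the /dev/shm mechanism the comments above
    # rely on: processes on the same node that map the same file through
    # torch.UntypedStorage.from_file share one underlying buffer, so only one of
    # them has to copy from GPU and the others simply read after the barrier the
    # comment above describes. The path and helper name below are illustrative and
    # do not appear in the repository.
    #
    # import os
    # import torch
    #
    # def open_shared_cpu_tensor(path: str, like: torch.Tensor) -> torch.Tensor:
    #     os.makedirs(os.path.dirname(path), exist_ok=True)
    #     nbytes = like.untyped_storage().size()
    #     storage = torch.UntypedStorage.from_file(path, True, nbytes)  # shared mmap of the file
    #     t = torch.empty(0, dtype=like.dtype, device="cpu")
    #     t.set_(storage)  # attach the mapped bytes to the tensor, no copy
    #     return t.as_strided(size=like.size(), stride=like.stride())
    #
    # # pseudo-usage on one node: local rank 0 writes, every other rank only maps and reads
    # #   shared = open_shared_cpu_tensor("/dev/shm/zeroband_demo/w0", param)
    # #   if local_rank == 0:
    # #       shared.copy_(param.data)
    # #   torch.distributed.barrier()  # readers wait for the writer
    # -------------------------------------------------------------------------------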
offloaded_params = [] - os.makedirs(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", exist_ok=True) + os.makedirs(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", exist_ok=True) for param_name, param in model.named_parameters(): if param.requires_grad: storage = torch.UntypedStorage.from_file( - f"/dev/shm/zeroband/{self.world_info.global_unique_id}/{param_name}", + f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}/{param_name}", True, param.data.untyped_storage().size(), ) @@ -136,4 +139,4 @@ def step(self, model: nn.Module): self._logger.debug("Post meow diloco step %s", get_module_signature(model)) def __del__(self): - shutil.rmtree(f"/dev/shm/zeroband/{self.world_info.global_unique_id}", ignore_errors=True) + shutil.rmtree(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", ignore_errors=True) From d41de80a1064a9f553315c9f86e88e0cb35c474d Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sun, 29 Sep 2024 04:53:52 +0800 Subject: [PATCH 30/32] revert shm offload --- src/zeroband/diloco.py | 64 +++++++++++------------------------------- 1 file changed, 17 insertions(+), 47 deletions(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 73fc23dd..f2c34e3d 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -1,14 +1,10 @@ -import shutil -import os from pydantic_config import BaseConfig import torch from torch import nn -from zeroband.utils import get_module_signature from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger from torch.distributed.fsdp import ShardingStrategy import torch.distributed as dist -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP class DilocoConfig(BaseConfig): @@ -16,9 +12,6 @@ class DilocoConfig(BaseConfig): inner_steps: int -SHARED_MEMORY_PATH = "/dev/shm/zeroband" - - class Diloco: """ This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852. 
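# --- illustrative aside, not part of the patch --------------------------------
# A toy, single-process sketch of the outer update the hunks below return to:
# keep a CPU copy of the weights, let each worker run its inner steps, form the
# pseudo-gradient as (cpu copy - inner weights), average it across workers (an
# all-reduce SUM divided by the group size, since gloo has no AVG), then take one
# Nesterov SGD step on the CPU copy. The worker count and tensor shape are made up.
#
# import torch
#
# outer_weights = torch.zeros(4)  # the CPU ("offloaded") copy every worker starts from
# inner_weights = [outer_weights + 0.1 * torch.randn(4) for _ in range(2)]  # after inner steps
#
# pseudo_grads = [outer_weights - w for w in inner_weights]
# avg_pseudo_grad = torch.stack(pseudo_grads).mean(dim=0)  # stand-in for the all-reduce
#
# param = torch.nn.Parameter(outer_weights.clone())
# param.grad = avg_pseudo_grad
# outer_opt = torch.optim.SGD([param], lr=0.7, momentum=0.9, nesterov=True)  # outer_lr from DilocoConfig
# outer_opt.step()
# # each worker would now copy param.data back into its model (sync_inner_model)
# -------------------------------------------------------------------------------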
@@ -64,26 +57,25 @@ def __init__( self._init_offloaded_optimizer(model=model) - def _init_offloaded_optimizer(self, model: nn.Module): - with FSDP.summon_full_params(model): - self.param_list_cpu = self.get_offloaded_param(model) - self.outer_optimizer = torch.optim.SGD( - self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True - ) - self._logger.debug("offload model to cpu") + def _init_offloaded_optimizer(self, model): + self.param_list_cpu = self.get_offloaded_param(model) + self.outer_optimizer = torch.optim.SGD( + self.param_list_cpu, lr=self.config.outer_lr, momentum=0.9, nesterov=True + ) + self._logger.debug("offload model to cpu") def sync_pseudo_gradient(self, model: nn.Module): """ Sync the pseudo gradient from the local process group to the global process group """ self._logger.debug("sync pseudo gradient") - # TODO: This assumes all params require grad, which is used by the offload for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) # gloo does not support AVG param_offloaded.grad = param_offloaded.grad / self.global_pg.size() dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg) + # todo async here def sync_inner_model(self, model: nn.Module): """ @@ -92,51 +84,29 @@ def sync_inner_model(self, model: nn.Module): self._logger.debug("sync inner model") for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): - param.data.copy_(param_offloaded.data) + param.data.copy_(param_offloaded.data) # todo: use copy_ here def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: """ Offload the model parameters to cpu """ - # The change here makes processes which are part of the same FSDP replica group (which are assumed to be on the same node with the same /dev/shm) use the same underlying storage for the offloaded_param. - # All the processes use the same shared memory file to create a storage for each parameter. Only rank 0 will do the copy. - # A barrier is added to ensure that after the function completes, the parameters are all offloaded. Otherwise, the non 0 ranks might access uninitialized memory. offloaded_params = [] - os.makedirs(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", exist_ok=True) - for param_name, param in model.named_parameters(): + for param in model.parameters(): if param.requires_grad: - storage = torch.UntypedStorage.from_file( - f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}/{param_name}", - True, - param.data.untyped_storage().size(), - ) - offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu") - offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride()) - if self.world_info.local_rank == 0: - # TODO: Can we async or split the copy among gpus probs overkill? 
- offloaded_param.copy_(param.data) - offloaded_param.requires_grad = False # TODO: check if we need to set this to True + offloaded_param = param.data.detach().clone().to("cpu") + offloaded_param.requires_grad = True offloaded_params.append(offloaded_param) - dist.barrier() return offloaded_params def step(self, model: nn.Module): """ Step the optimizer """ - with FSDP.summon_full_params(model): - self._logger.debug("Pre diloco step %s", get_module_signature(model)) - if self.world_info.rank == 0: - self.sync_pseudo_gradient(model) - if self.outer_optimizer is not None: - self.outer_optimizer.step() - self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this - - dist.barrier() - self.sync_inner_model(model) - self._logger.debug("Post meow diloco step %s", get_module_signature(model)) - - def __del__(self): - shutil.rmtree(f"{SHARED_MEMORY_PATH}/{self.world_info.global_unique_id}", ignore_errors=True) + self.sync_pseudo_gradient(model) + if self.outer_optimizer is not None: + self.outer_optimizer.step() + self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + + self.sync_inner_model(model) From 9c3940105bac910c94255677f9a3647fade47e14 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 22:24:54 +0000 Subject: [PATCH 31/32] fix: non zero rank need to reduce too --- src/zeroband/comms.py | 5 ++--- src/zeroband/diloco.py | 2 ++ src/zeroband/train.py | 6 ++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index 44a44058..6a7fdded 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -97,8 +97,7 @@ def __init__(self): # Initialize global process group self.global_pg = FakeProcessGroup(self.world_info.rank, 1) if self.world_info.global_world_size > 1: - if self.world_info.rank == 0: - self.global_pg = self._init_global_pg() + self.global_pg = self._init_global_pg() # Initialize local process group dist.init_process_group(backend="cpu:gloo,cuda:nccl") @@ -120,7 +119,7 @@ def __del__(self): def _init_global_pg(self) -> dist.Store: store = dist.TCPStore( host_name=self.world_info.global_addr, - port=self.world_info.global_port, + port=self.world_info.global_port + self.world_info.rank, timeout=TCPSTORE_TIMEOUT, is_master=(self.world_info.global_rank == 0), ) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index f2c34e3d..46a61cbb 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -70,6 +70,8 @@ def sync_pseudo_gradient(self, model: nn.Module): """ self._logger.debug("sync pseudo gradient") for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): + if param.shape[0] == 0: + continue param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) # gloo does not support AVG diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 46b6d8a0..4ff5134d 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -230,7 +230,13 @@ def train(config: Config): logger.info(log) if config.diloco is not None: + if config.train.log_model_hash: + with FSDP.summon_full_params(model): + logger.debug("Pre diloco model: %s", get_module_signature(model)) diloco.step(model) + if config.train.log_model_hash: + with FSDP.summon_full_params(model): + logger.debug("Post diloco model: %s", get_module_signature(model)) outer_step += 1 From e21a048176dbddec3dfa38433b95e1beb54444f2 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 28 Sep 2024 22:28:33 +0000 Subject: [PATCH 32/32] remove testing --- 
src/zeroband/testing.py | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 src/zeroband/testing.py diff --git a/src/zeroband/testing.py b/src/zeroband/testing.py deleted file mode 100644 index 684072e3..00000000 --- a/src/zeroband/testing.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import hashlib - -TENSOR_SIG_SAMPLE_SIZE = 100 - - -def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str: - """ - Get the tensor signature - """ - while isinstance(a, torch.nn.Parameter): - a = a.data - if a.numel() < TENSOR_SIG_SAMPLE_SIZE: - b = a.as_strided(size=(a.numel(),), stride=(1,)) - else: - step_size = a.numel() // TENSOR_SIG_SAMPLE_SIZE - b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,)) - element_str = "".join([f"{x:.3e}" for x in b]) - element_hash = hashlib.md5(element_str.encode("utf-8")).hexdigest() - return f"{a.dtype}{a.shape}{a.stride()}<{element_hash}>" - - -def get_module_signature(module: torch.nn.Module, compress: bool = True) -> str: - """ - Get the module signature - """ - state_dict_sig = {name: get_tensor_signature(param) for name, param in module.named_parameters()} - if compress: - return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest() - else: - return "\n".join(f"{name}: {sig}" for name, sig in state_dict_sig.items())
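
A closing note on the debugging helpers this series ends with: get_module_signature now lives in zeroband.utils and the standalone src/zeroband/testing.py copy is deleted above. The sketch below shows one way such a signature can be used to confirm that every DiLoCo worker leaves an outer step with identical weights, which is what the debug logging around diloco.step(model) checks by eye. Only the helper name and its behaviour come from the patches; the assert_replicas_in_sync wrapper and the all-gather pattern are illustrative, not part of the series, and an FSDP-wrapped model would first need FSDP.summon_full_params(model), as the patches do around their own logging.

import torch
import torch.distributed as dist

from zeroband.utils import get_module_signature


def assert_replicas_in_sync(model: torch.nn.Module, global_pg: dist.ProcessGroup) -> None:
    # Each rank hashes a sample of its parameters, the digests are gathered on
    # every rank, and any mismatch means the outer synchronization diverged.
    signature = get_module_signature(model)
    signatures = [None] * global_pg.size()
    dist.all_gather_object(signatures, signature, group=global_pg)
    if len(set(signatures)) != 1:
        raise RuntimeError(f"DiLoCo replicas out of sync: {signatures}")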