Commit

STASH

Jackmin801 committed Nov 7, 2024
1 parent a97d564 commit f3c770c
Showing 9 changed files with 202 additions and 7 deletions.
9 changes: 8 additions & 1 deletion configs/150M/3090.toml
@@ -7,7 +7,14 @@ micro_bs = 16 # change this base on the gpu
reshard_after_forward = true

[optim]
-batch_size = 512
+batch_size = 64
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4

+[data]
+fake = true

+[diloco]
+inner_steps = 20

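For orientation only: under the usual gradient-accumulation convention (an assumption here, not something this diff states), the new batch_size interacts with micro_bs and the local GPU count roughly as in this sketch:

# Sketch (assumed convention, not taken from the commit): micro-steps per
# optimizer batch = batch_size / (micro_bs * local GPUs).
batch_size = 64        # [optim] batch_size in the new config
micro_bs = 16          # micro_bs from the hunk context above
local_world_size = 2   # assumption: --nproc_per_node=2 as in the meow*.sh scripts below
grad_accum_steps = batch_size // (micro_bs * local_world_size)
print(grad_accum_steps)  # 2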
18 changes: 18 additions & 0 deletions meow-join.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
#export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=2,3
export GLOBAL_UNIQUE_ID=2
export GLOBAL_RANK=100

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
4 changes: 4 additions & 0 deletions meow.csv
@@ -0,0 +1,4 @@
a,b
0,0
2,1
1,2
67 changes: 67 additions & 0 deletions meow.py
@@ -0,0 +1,67 @@
import os
import torch.distributed as dist
import torch
import time

rank = int(os.environ['RANK'])
def rprint(*args):
    print(f"[Rank {rank}] {' '.join(map(str, args))}\n", end="")

class EDM:
    def __init__(self):
        master_addr = os.environ['MASTER_ADDR']
        master_port = int(os.environ['MASTER_PORT'])
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])

        self.rank = rank
        self.world_size = world_size

        rprint("Creating Store")
        self.global_store = dist.TCPStore(
            host_name=master_addr,
            port=master_port + 1,
            is_master=(rank==0),
            world_size=2,
        )

        rprint("Store created. Creating ProcessGroupGloo")
        self.global_pg = dist.distributed_c10d.ProcessGroupGloo(self.global_store, rank, world_size)
        rprint("ProcessGroupGloo created")

        self.measure_connectivity()

    def measure_connectivity(self):
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            tensor = torch.ones(1_000_000, dtype=torch.float32)
            rprint(f"Recv from peer {i} with tag {self.rank + self.world_size * i}")
            recv_work.append(self.global_pg.recv([tensor], i, self.rank + self.world_size * i))

        self.global_pg.barrier().wait()
        for i in range(self.world_size):
            if i == self.rank:
                continue
            rprint(f"Pinging peer {i}")
            time_taken = self.ping_peer(i)
            rprint(f"Ping to peer {i} took {time_taken} seconds")

        for work in recv_work:
            work.wait()

    def ping_peer(self, peer_rank: int) -> float:
        tensor = torch.ones(1_000_000, dtype=torch.float32)
        start_time = time.perf_counter()
        rprint(f"Send from peer {self.rank} to {peer_rank} with tag {self.rank * self.world_size + peer_rank}")
        self.global_pg.send([tensor], peer_rank, self.rank * self.world_size + peer_rank).wait()
        end_time = time.perf_counter()
        return end_time - start_time

def main():
    edm = EDM()


if __name__ == "__main__":
    main()
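A quick standalone check of the tag scheme above: ping_peer on rank r sends with tag r * world_size + p, while measure_connectivity on rank p pre-posts a recv from r with tag p + world_size * r, which is the same number, so every ping has a matching receive already outstanding.

# Standalone sketch (not part of the commit): confirm send/recv tags pair up.
W = 3  # hypothetical world size
for r in range(W):
    for p in range(W):
        if r == p:
            continue
        send_tag = r * W + p   # tag used by ping_peer on rank r
        recv_tag = p + W * r   # tag used by measure_connectivity on rank p for source r
        assert send_tag == recv_tag
print("all send/recv tags match")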
18 changes: 18 additions & 0 deletions meow0.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
#export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=0,1
export GLOBAL_UNIQUE_ID=0
#export GLOBAL_RANK=0

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
18 changes: 18 additions & 0 deletions meow1.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
#export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=4,5
export GLOBAL_UNIQUE_ID=1
export GLOBAL_RANK=100

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
18 changes: 18 additions & 0 deletions meow2.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
#export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=2,3
export GLOBAL_UNIQUE_ID=2
export GLOBAL_RANK=100

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
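The launcher scripts above differ only in CUDA_VISIBLE_DEVICES and GLOBAL_UNIQUE_ID; the shell expands $GLOBAL_UNIQUE_ID after the literal 1000, so each local worker gets its own torchrun rendezvous port. A small sketch of the endpoints they resolve to, assuming ids 0, 1 and 2 as above:

# Sketch: rendezvous endpoint per GLOBAL_UNIQUE_ID, mirroring
# --rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID in the scripts.
for gid in (0, 1, 2):
    print(f"localhost:1000{gid}")  # localhost:10000, localhost:10001, localhost:10002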
53 changes: 47 additions & 6 deletions src/zeroband/comms.py
@@ -1,3 +1,4 @@
+import torch
import sys
import os
import time
@@ -30,7 +31,14 @@ def _read_topo_file() -> Dict[str, int]:
    for m in mappings:
        assert len(m) == 2
        m[1] = int(m[1])
-    return {m[0]: m[1] for m in mappings}
+    return {m[1]: m[0] for m in mappings}

+def _write_topo_file(topo: Dict[str, int]) -> None:
+    """Write the topology file with the given mappings."""
+    with open(TOPO_FILE, "w") as f:
+        f.write("gid,grank\n")
+        for gid, grank in topo.items():
+            f.write(f"{gid},{grank}\n")

class ElasticDeviceMesh:
    """A class to manage the process groups for elastic training without restarts.
@@ -44,7 +52,6 @@ class ElasticDeviceMesh:
    - world_size: The current world size
    - mesh_count: The version of the mesh
    - rank_{uuid}: The rank of the node with the given uuid
-    - rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave.
    - joiner_{i}: The uuid of the ith joiner. It's a KV implementation of a queue.
    """

@@ -90,7 +97,7 @@ def _init_global_store_and_status(self):
        # Topology
        mappings = _read_topo_file()
        new_world_size = 0
-        for joiner_id, global_rank in mappings.items():
+        for global_rank, joiner_id in mappings.items():
            self.global_store.set(f"rank_{joiner_id}", str(global_rank))
            new_world_size += 1
        self.global_store.set("world_size", str(new_world_size))
@@ -310,15 +317,21 @@ def _resolve_world(self, admit_joiners: bool = False) -> bool:
            return False

        # Remap live ranks to smaller world_size caused by dead nodes
+        self._grank_to_gid = _read_topo_file()
        leaving_ranks = set(dead_nodes)
+        for rank in dead_nodes:
+            del self._grank_to_gid[rank]
        live_ranks = [i for i in range(self.world_info.global_world_size) if i not in leaving_ranks]
        for i, rank in enumerate(live_ranks):
-            self.global_store.set(f"rank_map_{rank}", str(i))
+            self.global_store.set(f"rank_{self._grank_to_gid[rank]}", str(i))
+        for i, rank in enumerate(live_ranks):
+            self._grank_to_gid[i] = self._grank_to_gid[rank]
        new_world_size = len(live_ranks)

        # Give joiners new ranks
        for joiner_id in joiners:
            self.global_store.set(f"rank_{joiner_id}", str(new_world_size))
+            self._grank_to_gid[new_world_size] = joiner_id
            new_world_size += 1

        for i in range(1, new_world_size):
@@ -328,7 +341,36 @@ def _resolve_world(self, admit_joiners: bool = False) -> bool:
        self.global_store.set("mesh_count", str(self.mesh_count + 1))
        # Set status to "reinit"
        self.global_store.set("status", "reinit")

+        _write_topo_file(self._grank_to_gid)
        return True

+    def ping_peer(self, peer_rank: int) -> float:
+        tensor = torch.ones(1_000_000, dtype=torch.float32)
+        start_time = time.perf_counter()
+        self.global_pg.send([tensor], peer_rank).wait()
+        end_time = time.perf_counter()
+        return end_time - start_time

+    def _measure_communication_cost(self):
+        tensor = torch.ones(1_000_000, dtype=torch.float32)
+        rank = self.world_info.global_rank
+        world_size = self.world_info.global_world_size

+        # Measure the time to send the tensor to all other ranks
+        start_time = time.perf_counter()
+        for i in range(world_size):
+            if i != rank:
+                work = self.global_pg.send([tensor], i)
+        end_time = time.perf_counter()

+        # Calculate the cost
+        cost = end_time - start_time

+        # Store the cost in the KV Store
+        self.global_store.set(f"comm_cost_{rank}", str(cost))

+        self._logger.info("Communication cost for rank %d: %f seconds", rank, cost)

    def maybe_reinit_global_pg(self, admit_joiners: bool = False) -> bool:
        """Reinitialize the global_pg if there is a state change.
@@ -373,7 +415,7 @@ def maybe_reinit_global_pg(self, admit_joiners: bool = False) -> bool:
        # Check if we got remapped
        old_global_rank = self.world_info.global_rank
        self.world_info.global_rank = int(
-            self.global_store.get(f"rank_map_{self.world_info.global_rank}").decode("utf-8")
+            self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8")
        )

        self.world_info.global_world_size = int(self.global_store.get("world_size").decode("utf-8"))
@@ -400,7 +442,6 @@ def maybe_reinit_global_pg(self, admit_joiners: bool = False) -> bool:

        # Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
        if old_global_rank != self.world_info.global_rank:
-            self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))
            self.live_recovery.reset()

        self._logger.debug("Reinitialized global_pg done in %s seconds", time.perf_counter() - time_start)
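To make the remap in _resolve_world easier to follow, here is a minimal standalone sketch of the compaction step, assuming dead_nodes holds global ranks and _grank_to_gid maps global rank to node id (the orientation the new _read_topo_file returns):

# Sketch (not repository code): drop dead ranks, then compact survivors to 0..N-1.
grank_to_gid = {0: "a", 1: "b", 2: "c"}  # hypothetical grank -> gid mapping
dead_nodes = [1]                         # assumption: global rank 1 went away

for rank in dead_nodes:
    del grank_to_gid[rank]
live_ranks = [r for r in range(3) if r not in set(dead_nodes)]  # [0, 2]
for new_rank, old_rank in enumerate(live_ranks):
    grank_to_gid[new_rank] = grank_to_gid[old_rank]
new_world_size = len(live_ranks)  # 2; entries at ranks >= new_world_size are stale
print(dict(sorted(grank_to_gid.items())))  # {0: 'a', 1: 'c', 2: 'c'}

A joiner is then handed rank new_world_size (2 here), which overwrites the stale entry, matching the "Give joiners new ranks" block in the diff.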
4 changes: 4 additions & 0 deletions topo.csv
@@ -0,0 +1,4 @@
gid,grank
0,0
2,1
1,1
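For reference, a minimal standalone round-trip of this gid,grank CSV in the orientation the new _read_topo_file uses (global rank -> node id); the helper names and file path here are illustrative, not the repository's:

# Sketch (illustrative names, not the commit's code): read/write a gid,grank CSV
# keyed by global rank, as `return {m[1]: m[0] for m in mappings}` implies.
from typing import Dict

def read_topo(path: str) -> Dict[int, str]:
    with open(path) as f:
        rows = [line.strip().split(",") for line in f.readlines()[1:] if line.strip()]
    return {int(grank): gid for gid, grank in rows}

def write_topo(grank_to_gid: Dict[int, str], path: str) -> None:
    with open(path, "w") as f:
        f.write("gid,grank\n")
        for grank, gid in grank_to_gid.items():
            f.write(f"{gid},{grank}\n")

mapping = {0: "0", 1: "2"}
write_topo(mapping, "topo_sketch.csv")
assert read_topo("topo_sketch.csv") == mapping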
