Commit 211e58e

STASH

Jackmin801 committed Nov 9, 2024
1 parent 6586714 commit 211e58e
Showing 7 changed files with 183 additions and 6 deletions.
9 changes: 8 additions & 1 deletion configs/150M/3090.toml
@@ -7,7 +7,14 @@ micro_bs = 16 # change this base on the gpu
 reshard_after_forward = true
 
 [optim]
-batch_size = 512
+batch_size = 64
+warmup_steps = 1000
+total_steps = 88_000
+lr = 4e-4
+
+[data]
+fake = true
+
 [diloco]
 inner_steps = 20

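Assuming batch_size here is the global batch per optimizer step and micro_bs the per-GPU micro-batch (the field names and the "change this base on the gpu" comment suggest as much), a two-GPU node now accumulates 64 / (16 × 2) = 2 micro-steps per step instead of the 16 that batch_size = 512 required; together with fake = true in [data], this reads like a config pared down for local debugging rather than a real run.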
86 changes: 86 additions & 0 deletions meow.py
@@ -0,0 +1,86 @@
from typing import List
import os
import time

import torch
import torch.distributed as dist
import toposolve

rank = int(os.environ["RANK"])


def rprint(*args):
    """Print prefixed with this process's rank, as a single write so ranks don't interleave."""
    print(f"[Rank {rank}] {' '.join(map(str, args))}\n", end="")


BENCH_TENSOR_SIZE = 1_000_000  # 1M float32 elements, i.e. ~4 MB per probe


class EDM:
    """Measure pairwise latency between all peers over Gloo, then solve a TSP on the
    ping matrix to find the minimum-latency ring order."""

    def __init__(self):
        master_addr = os.environ["MASTER_ADDR"]
        master_port = int(os.environ["MASTER_PORT"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        self.rank = rank
        self.world_size = world_size

        rprint("Creating Store")
        self.global_store = dist.TCPStore(
            host_name=master_addr,
            port=master_port + 1,  # the store binds MASTER_PORT + 1, not MASTER_PORT itself
            is_master=(rank == 0),
            world_size=world_size,
        )

        rprint("Store created. Creating ProcessGroupGloo")
        self.global_pg = dist.distributed_c10d.ProcessGroupGloo(self.global_store, rank, world_size)
        rprint("ProcessGroupGloo created")

        self._measure_connectivity()
        if rank == 0:
            pings = self.get_pings()
            print(*pings, sep="\n")
            min_dist, path = toposolve.TSPSolver().solve_tsp(pings)
            print(f"Min distance: {min_dist}")
            print(f"Path: {path}")

    def get_pings(self) -> List[List[int]]:
        """Read the full ping matrix back from the store; the diagonal keeps a large sentinel."""
        pings = [[1_000_000_000] * self.world_size for _ in range(self.world_size)]
        for i in range(self.world_size):
            for j in range(self.world_size):
                if i == j:
                    continue
                pings[i][j] = int(self.global_store.get(f"ping_{i}_{j}"))
        return pings

    def _measure_connectivity(self):
        # Post receives from all other peers up front so every send finds a matching recv.
        # Tags encode the (sender, receiver) pair: sender i -> receiver j tags with
        # i * world_size + j on both ends.
        recv_work = []
        tensor = torch.ones(BENCH_TENSOR_SIZE, dtype=torch.float32)
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(self.global_pg.recv([tensor], i, self.rank + self.world_size * i))

        # Ping all other peers and publish each timing to the store
        for i in range(self.world_size):
            if i == self.rank:
                continue
            time_taken = self._ping_peer(i)
            self.global_store.set(f"ping_{self.rank}_{i}", str(time_taken))

        # Wait for all recv operations to complete
        for work in recv_work:
            work.wait()

    def _ping_peer(self, peer_rank: int) -> int:
        """Ping a peer and return the time taken in microseconds."""
        tensor = torch.ones(BENCH_TENSOR_SIZE, dtype=torch.float32)
        start_time = time.perf_counter()
        self.global_pg.send([tensor], peer_rank, self.rank * self.world_size + peer_rank).wait()
        end_time = time.perf_counter()
        return int((end_time - start_time) * 1e6)


def main():
    EDM()


if __name__ == "__main__":
    main()
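Since EDM reads everything it needs from RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, the script can be smoke-tested on one host without torchrun. A minimal launcher sketch (an illustration, not part of the commit; recall the TCPStore inside EDM binds MASTER_PORT + 1):

import os
import subprocess

WORLD_SIZE = 2

procs = [
    subprocess.Popen(
        ["python", "meow.py"],
        env=dict(
            os.environ,
            MASTER_ADDR="localhost",
            MASTER_PORT="29500",  # the store itself will bind 29501
            RANK=str(rank),
            WORLD_SIZE=str(WORLD_SIZE),
        ),
    )
    for rank in range(WORLD_SIZE)
]

for p in procs:
    p.wait()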
18 changes: 18 additions & 0 deletions meow0.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=0,1
export GLOBAL_UNIQUE_ID=0
export GLOBAL_RANK=0

export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
18 changes: 18 additions & 0 deletions meow1.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
#export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=4,5
export GLOBAL_UNIQUE_ID=1
export GLOBAL_RANK=100

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
18 changes: 18 additions & 0 deletions meow2.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export GLOBAL_ADDR=localhost
export GLOBAL_PORT=1234
#export GLOBAL_WORLD_SIZE=2

export CUDA_VISIBLE_DEVICES=2,3
export GLOBAL_UNIQUE_ID=2
export GLOBAL_RANK=100

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:1000$GLOBAL_UNIQUE_ID \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
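The three scripts differ only in the GPU pair they claim (0,1 / 4,5 / 2,3) and in GLOBAL_UNIQUE_ID, which selects rendezvous ports 10000 through 10002 via the 1000$GLOBAL_UNIQUE_ID expansion. meow0.sh pins GLOO_SOCKET_IFNAME=tailscale0 and GLOBAL_RANK=0, while the other two comment those out and set GLOBAL_RANK=100, which reads like a placeholder for nodes that join the elastic world later rather than bootstrap it.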
17 changes: 12 additions & 5 deletions src/zeroband/comms.py
@@ -270,14 +270,21 @@ def _optimize_ring_ranks(self):
         self._measure_connectivity()
         self.global_pg.barrier()
         self._logger.debug(f"Time taken to measure connectivity: {time.perf_counter() - start_time}")
-        start_time = time.perf_counter()
-        self._logger.debug("Calculating TSP")
         if self._global_leader:
+            start_time = time.perf_counter()
+            self._logger.debug("Calculating TSP")
             pings = self.get_pings()
             min_dist, path = toposolve.TSPSolver().solve_tsp(pings)
-            print(f"Min distance: {min_dist}")
-            print(f"Path: {path}")
-            self._logger.debug(f"Time taken to calculate TSP: {time.perf_counter() - start_time}")
+            self._logger.debug(f"Min distance: {min_dist}")
+            self._logger.debug(f"Path: {path}")
+            for i, rank in enumerate(path[1:-1]):
+                self.global_store.set(f"rank_map_{rank}", str(i))
+            self._logger.debug(f"Time taken to calculate TSP: {time.perf_counter() - start_time}")
+
+        # Update world_size
+        self.global_store.set("mesh_count", str(self.mesh_count + 1))
+        # Set status to "reinit"
+        self.global_store.set("status", "reinit")

def _resolve_world(self, admit_joiners: bool = False) -> bool:
"""Set the new world size and ranks for all nodes if there are joiners or dead nodes. Else, do nothing.
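The hunk above moves the TSP timing into the leader branch, demotes the prints to debug logs, and, most importantly, has the leader publish the optimized ring order through the store as rank_map_{rank} keys before bumping mesh_count and setting status to "reinit", presumably so peers pick up their new positions when they reinitialize. A standalone sketch of the leader-side step, using the same toposolve calls as the diff (the 4-node latencies are invented, and the closed-tour shape of path is inferred from the path[1:-1] slice):

import toposolve

INF = 1_000_000_000  # same diagonal sentinel as get_pings()

# Hypothetical ping matrix, microseconds; pings[i][j] is rank i -> rank j latency.
pings = [
    [INF, 120, 4500, 130],
    [110, INF, 140, 4800],
    [4700, 150, INF, 125],
    [135, 4900, 115, INF],
]

min_dist, path = toposolve.TSPSolver().solve_tsp(pings)
print(f"Min distance: {min_dist}")
print(f"Path: {path}")  # e.g. a closed tour like [0, 3, 2, 1, 0]

# As in the diff, rank 0 (the leader) keeps its slot; the tour's interior
# entries become the new ring positions that followers read back.
for i, rank in enumerate(path[1:-1]):
    print(f"rank_map_{rank} -> {i}")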
23 changes: 23 additions & 0 deletions uv.lock

Some generated files are not rendered by default.
