plot_collective_by_group_call.py
import config
import sys
import datetime
import time
import torch
import torch.fx
import numpy as np

import hap
import collectives


def run(global_rank, local_rank, max_ratio, queue):
    import torch.distributed as dist
    dist.init_process_group('nccl', rank=global_rank, timeout=datetime.timedelta(hours=2))

    total_length = 4 * 1024  # gathered tensor is 256 x 4096 fp32, i.e. about 4MB in total
    # One rank holds a `max_ratio` share of the data; the remainder is split evenly among the others.
    sharding_lengths = [max_ratio] + [(1 - max_ratio) / (config.world_size - 1)] * (config.world_size - 1)
    sharding_lengths = [x / sum(sharding_lengths) for x in sharding_lengths]
    # Round the fractional shares into integer shard lengths that sum to total_length (in place).
    hap.sharding_round(total_length, sharding_lengths)
    if local_rank == 0:
        print("sharding_lengths:", sharding_lengths)

    tensor = torch.rand(256, sharding_lengths[global_rank]).to(local_rank)

    result_times = []
    last_iter_time = time.time()
    for i in range(config.run_iter):
        collectives.all_gather(tensor, 1, sharding_lengths, global_rank)
        # collectives.all_gather_by_group_call(tensor, 1, sharding_lengths, global_rank)
        # torch.cuda.synchronize()
        dist.barrier()
        if local_rank == 0:
            iter_duration = time.time() - last_iter_time
            result_times.append(iter_duration)
            last_iter_time += iter_duration
            print("iter time: ", iter_duration)
            print("avg±std:", np.mean(result_times[-config.avg_iter:]), np.std(result_times[-config.avg_iter:]))

    if local_rank == 0:
        # Report the mean iteration time over the last `avg_iter` iterations to the parent process.
        queue.put(np.mean(result_times[-config.avg_iter:]))


if __name__ == '__main__':
    ranks = [int(x) for x in sys.argv[1].split(',')]

    # if torch.cuda.device_count() != len(ranks):
    #     print("forget to set CUDA_VISIBLE_DEVICES")
    #     raise SystemExit

    import os
    os.environ['MASTER_ADDR'] = str(config.master_addr)
    os.environ['MASTER_PORT'] = str(config.master_port)
    os.environ['WORLD_SIZE'] = str(config.world_size)

    import torch.multiprocessing as mp
    ctx = mp.get_context('spawn')
    queue = ctx.Queue(1)

    result = []
    # Sweep the largest rank's share from an even split (1 / world_size) towards 1.
    for max_ratio in np.linspace(1 / config.world_size, 1, 100, endpoint=False):
        # Spawn one worker process per rank listed on the command line.
        for local_rank, global_rank in enumerate(ranks):
            ctx.Process(target=run, args=(global_rank, local_rank, max_ratio, queue)).start()

        for p in mp.active_children():
            p.join()

        result.append((max_ratio, queue.get()))
        print(result[-1])

    print(result)
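
# A possible invocation, assuming config.world_size == 4, config.master_addr/master_port
# point at a reachable rendezvous host, and four GPUs are visible on this node
# (the exact rank list and CUDA_VISIBLE_DEVICES setting depend on your setup):
#
#   CUDA_VISIBLE_DEVICES=0,1,2,3 python plot_collective_by_group_call.py 0,1,2,3
#
# The script then prints one (max_ratio, mean_iteration_time) pair per sweep point,
# followed by the full list of results.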