-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathmultiproc.py
33 lines (28 loc) · 899 Bytes
/
multiproc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import sys
import subprocess
argslist = list(sys.argv)[1:]
world_size = torch.cuda.device_count()
if '--world-size' in argslist:
argslist[argslist.index('--world-size') + 1] = str(world_size)
else:
argslist.append('--world-size')
argslist.append(str(world_size))
workers = []
for i in range(world_size):
if '--rank' in argslist:
argslist[argslist.index('--rank') + 1] = str(i)
else:
argslist.append('--rank')
argslist.append(str(i))
if '--gpu-rank' in argslist:
argslist[argslist.index('--gpu-rank') + 1] = str(i)
else:
argslist.append('--gpu-rank')
argslist.append(str(i))
stdout = None if i == 0 else open("GPU_" + str(i) + ".log", "w")
print(argslist)
p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout, stderr=stdout)
workers.append(p)
for p in workers:
p.wait()