-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgeneral_launch.py
62 lines (49 loc) · 1.99 KB
/
general_launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
### This script requires Python >= 3.7
import argparse
# NOTE(review): in the visible code cloudpickle is referenced only from
# commented-out lines inside launch(); it is kept here as-is.
import cloudpickle
import torch
# Select the 'forkserver' start method; force=True replaces any start method
# that has already been set.  This call sits before the remaining imports --
# NOTE(review): presumably deliberate so later modules see the chosen start
# method; confirm before reordering.
torch.multiprocessing.set_start_method('forkserver', force=True)
from multiprocessing import Process, Queue
from mars.env.import_env import make_env
# The star import puts the agent classes in module scope; launch() resolves
# args.algorithm by name against this namespace.
from mars.rl.agents import *
from mars.rl.agents.multiagent import MultiAgent
from mars.utils.func import multiprocess_conf
from mars.rolloutExperience import rolloutExperience
from mars.updateModel import updateModel
from mars.utils.args_parser import get_args
# NOTE(review): `parser` is not referenced again in the visible code;
# launch() obtains its arguments from get_args() instead.
parser = argparse.ArgumentParser(description='Arguments of the general launching script for MARS.')
def launch():
    """Build the environment and agents, then run rollout and update workers.

    Reads the run configuration via ``get_args()``, creates the environment
    with ``make_env(args)``, instantiates two agents of the class named by
    ``args.algorithm`` and wraps them in a ``MultiAgent``.  It then starts one
    ``rolloutExperience`` process and ``args.num_process`` ``updateModel``
    processes, all sharing a single ``Queue``, and blocks until the worker
    processes have been joined.

    Returns:
        None.
    """
    import time  # local import: only needed for the polling interval below

    args = get_args()
    method = args.marl_method
    multiprocess_conf(args, method)

    ### Create env
    env = make_env(args)
    print(env)

    ### Specify models for each agent
    # NOTE(review): eval() executes whatever expression args.algorithm holds.
    # This is acceptable only while the launch arguments are trusted; consider
    # an explicit name -> class lookup if they can come from untrusted input.
    model1 = eval(args.algorithm)(env, args)
    model2 = eval(args.algorithm)(env, args)
    model = MultiAgent(env, [model1, model2], args)
    print(args)

    # The environment instance is closed here, after the models were built
    # from it and before any worker process is created.
    env.close()

    # transform dictionary to bytes (serialization)
    # args = cloudpickle.dumps(args)
    # env = cloudpickle.dumps(env)  # this only works for single env, not for multiprocess vecenv
    processes = []

    # launch multiple sample rollout processes
    info_queue = Queue()
    num_rollout_processes = 1
    for pro_id in range(num_rollout_processes):
        play_process = Process(
            target=rolloutExperience,
            args=(model, info_queue, args, f"{args.save_id}-{pro_id}"),
        )
        play_process.daemon = True  # sub processes killed when main process finish
        processes.append(play_process)

    # launch update process (single or multiple)
    for pro_id in range(args.num_process):
        update_process = Process(
            target=updateModel,
            args=(model, info_queue, args, f"{args.save_id}-{pro_id}"),
        )
        update_process.daemon = True
        processes.append(update_process)

    for proc in processes:
        proc.start()

    # Wait until at least one worker has exited.  Sleep between polls so the
    # parent process does not spin a CPU core while the workers are running.
    while all(proc.is_alive() for proc in processes):
        time.sleep(0.1)

    for proc in processes:
        proc.join()
# Run launch() only when this file is executed as a script, not on import.
# NOTE(review): the module selects the 'forkserver' start method, under which
# worker processes may import the main module; this guard presumably also
# keeps them from re-running launch() -- confirm.
if __name__ == '__main__':
    launch()