-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun_exp_n3.py
35 lines (28 loc) · 916 Bytes
/
run_exp_n3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from src.model_run import Model_Run
from src.args import read_args
from src.input_analysis import checkflow
from src.res_analysis import workflow
import os
def make_experiment_dirs(experiment_name, root='./Experiments'):
    """Create the experiment directory and its ``checkpoints`` subdirectory.

    Directories that already exist are left untouched, so calling this
    repeatedly (e.g. when re-running the same experiment) is safe.

    Args:
        experiment_name: Name of the experiment; used as the directory name
            under *root*.
        root: Parent directory holding all experiment directories.

    Returns:
        The path of the experiment directory (``root/experiment_name``).
    """
    exp_path = os.path.join(root, experiment_name)
    # makedirs creates every missing parent, so creating the checkpoints
    # subdirectory also creates exp_path itself. exist_ok=True replaces the
    # original check-then-create pattern, which could race with another
    # process creating the same directory between the check and the call.
    os.makedirs(os.path.join(exp_path, 'checkpoints'), exist_ok=True)
    return exp_path


if __name__ == '__main__':
    args = read_args()
    # Hard-coded experiment configuration: these assignments overwrite
    # whatever read_args() set for the same attributes.
    args.experiment_name = 'GMUC_N3'  # Experiment ID
    args.set_aggregator = 'GMUC'
    args.datapath = './data/MAGA-PLUS-NL27K/NL27K-N3'
    args.eval_every = 250
    args.max_batches = 60000
    args.if_conf = 1
    args.rank_weight = 1.00
    args.ae_weight = 1.00

    # Make the experiment dir (and its checkpoints subdirectory).
    make_experiment_dirs(args.experiment_name)

    # Input dataset analysis (currently disabled).
    # checkflow(args)

    # Model execution.
    model_run = Model_Run(args)
    model_run.train()

    # Result analysis.
    workflow(args)