-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
140 lines (117 loc) · 5.55 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import logging
import os
from datetime import datetime
from gln.common.cmd_args import cmd_args as gln_args
from models.gln_model.gln_predictor import GLNPredictor
from models.localretro_model import localretro_parser
from models.localretro_model.localretro_predictor import LocalRetroPredictor
from models.neuralsym_model import neuralsym_parser
from models.neuralsym_model.neuralsym_predictor import NeuralSymPredictor
from models.retrocomposer_model import retrocomposer_parser
from models.retrocomposer_model.retrocomposer_predictor import RetroComposerPredictor
from models.retroxpert_model.retroxpert_predictor import RetroXpertPredictor
from models.transformer_model.transformer_predictor import TransformerPredictor
from onmt import opts as onmt_opts
from onmt.bin.translate import _get_parser as transformer_parser
from rdkit import RDLogger
from utils import misc
def get_predict_parser():
    """Build the generic command-line parser shared by every prediction model.

    The parser is created with ``conflict_handler="resolve"`` so that options
    registered on it later (``predict_main`` adds model-specific options to
    this same parser) replace any earlier option with the same flag instead
    of raising an error.

    Returns:
        argparse.ArgumentParser: parser whose ``prog`` is ``"predict.py"``,
        with a ``--test_all_ckpts`` switch plus string options that all
        default to ``""``.
    """
    predict_parser = argparse.ArgumentParser(
        prog="predict.py",  # TODO: this is a hardcode
        conflict_handler="resolve",
    )
    predict_parser.add_argument(
        "--test_all_ckpts",
        help="whether to test all checkpoints",
        action="store_true",
    )

    # (flag, help text) pairs; registration order is kept because it
    # determines the order options are listed in --help output.
    string_options = (
        ("--model_name", "model name"),
        ("--data_name", "name of dataset, for easier reference"),
        ("--log_file", "log file"),
        ("--config_file", "model config file (optional)"),
        ("--train_file", "train SMILES file"),
        ("--val_file", "validation SMILES files"),
        ("--test_file", "test SMILES files"),
        ("--processed_data_path", "output path for processed data"),
        ("--model_path", "model output path"),
        ("--test_output_path", "test output path"),
    )
    for flag, help_text in string_options:
        predict_parser.add_argument(flag, help=help_text, type=str, default="")

    return predict_parser
def _setup_gln(args, predict_parser):
    """Prepare GLN; return ``(model_args, raw_data_files, predictor_cls)``.

    GLN takes its options from the module-level ``gln_args`` object imported
    from ``gln.common.cmd_args``. Only ``test_all_ckpts`` is overwritten from
    the runtime ``args``, in place, so that shared object is mutated.
    ``predict_parser`` is not used.
    """
    gln_args.test_all_ckpts = args.test_all_ckpts
    raw_data_files = [args.train_file, args.val_file, args.test_file]
    return gln_args, raw_data_files, GLNPredictor


def _setup_localretro(args, predict_parser):
    """Register LocalRetro options on ``predict_parser`` and re-parse the CLI."""
    localretro_parser.add_model_opts(predict_parser)
    localretro_parser.add_train_opts(predict_parser)
    localretro_parser.add_predict_opts(predict_parser)
    model_args, _unknown = predict_parser.parse_known_args()
    return model_args, [args.test_file], LocalRetroPredictor


def _setup_transformer(args, predict_parser):
    """Parse OpenNMT translate options with onmt's own parser.

    Adapted from ``onmt.bin.translate.main()``. ``predict_parser`` is not used;
    the config and log file paths are copied over from the runtime ``args``.
    """
    parser = transformer_parser()
    opt, _unknown = parser.parse_known_args()
    opt.config = args.config_file
    opt.log_file = args.log_file
    return opt, [], TransformerPredictor


def _setup_retroxpert(args, predict_parser):
    """Register OpenNMT config/translate options on ``predict_parser`` and re-parse."""
    onmt_opts.config_opts(predict_parser)
    onmt_opts.translate_opts(predict_parser)
    model_args, _unknown = predict_parser.parse_known_args()
    # update runtime args
    model_args.config = args.config_file
    model_args.log_file = args.log_file
    return model_args, [], RetroXpertPredictor


def _setup_retrocomposer(args, predict_parser):
    """Register RetroComposer options on ``predict_parser`` and re-parse the CLI."""
    retrocomposer_parser.add_model_opts(predict_parser)
    retrocomposer_parser.add_train_opts(predict_parser)
    retrocomposer_parser.add_predict_opts(predict_parser)
    model_args, _unknown = predict_parser.parse_known_args()
    return model_args, [], RetroComposerPredictor


def _setup_neuralsym(args, predict_parser):
    """Register NeuralSym options on ``predict_parser`` and re-parse the CLI."""
    neuralsym_parser.add_model_opts(predict_parser)
    neuralsym_parser.add_train_opts(predict_parser)
    neuralsym_parser.add_predict_opts(predict_parser)
    model_args, _unknown = predict_parser.parse_known_args()
    return model_args, [args.test_file], NeuralSymPredictor


# Maps ``--model_name`` to the helper that prepares that model's arguments.
# Every helper returns ``(model_args, raw_data_files, predictor_cls)``.
_MODEL_SETUPS = {
    "gln": _setup_gln,
    "localretro": _setup_localretro,
    "transformer": _setup_transformer,
    "retroxpert": _setup_retroxpert,
    "retrocomposer": _setup_retrocomposer,
    "neuralsym": _setup_neuralsym,
}


def predict_main(args, predict_parser):
    """Run prediction for the model selected by ``args.model_name``.

    Args:
        args: parsed generic options (as built by ``get_predict_parser``).
        predict_parser: the parser that produced ``args``. For most models,
            model-specific options are added to it (mutating it) and the
            command line is re-parsed to obtain ``model_args``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model, or if
            ``args.test_output_path`` is empty.
    """
    misc.log_args(args, message="Logging arguments")

    # Validate before creating directories or mutating parser state, so a bad
    # model name no longer leaves an empty output directory behind.
    setup = _MODEL_SETUPS.get(args.model_name)
    if setup is None:
        raise ValueError(f"Model {args.model_name} not supported!")
    # --test_output_path defaults to "" in get_predict_parser, and
    # os.makedirs("") only fails with an opaque FileNotFoundError.
    if not args.test_output_path:
        raise ValueError("--test_output_path must be a non-empty path")
    os.makedirs(args.test_output_path, exist_ok=True)

    model_args, raw_data_files, predictor_cls = setup(args, predict_parser)

    logging.info("Start predicting")
    predictor = predictor_cls(
        model_name=args.model_name,
        model_args=model_args,
        model_config={},  # no model config is loaded here; always empty
        data_name=args.data_name,
        raw_data_files=raw_data_files,
        processed_data_path=args.processed_data_path,
        model_path=args.model_path,
        test_output_path=args.test_output_path,
    )
    predictor.predict()
if __name__ == "__main__":
    predict_parser = get_predict_parser()
    # Unrecognized flags are tolerated here; model-specific options are
    # registered and parsed later inside predict_main.
    args, _unknown = predict_parser.parse_known_args()

    # logger setup
    RDLogger.DisableLog("rdApp.warning")  # silence RDKit warning-level messages
    log_dir = "./logs/predict"
    os.makedirs(log_dir, exist_ok=True)
    dt = datetime.strftime(datetime.now(), "%y%m%d-%H%Mh")
    # Without --log_file (default "") the name would be just ".<timestamp>",
    # an easily-missed hidden file; fall back to the model name, then "predict".
    log_stem = args.log_file or args.model_name or "predict"
    # The resolved path is written back so downstream consumers of
    # args.log_file (e.g. the transformer/retroxpert setup) see the full path.
    args.log_file = f"{log_dir}/{log_stem}.{dt}"
    misc.setup_logger(args.log_file)

    # predict interface
    predict_main(args, predict_parser)