evaluate.py
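"""Offline evaluation entry point for smiles_cl.

Evaluates trained checkpoints on the configured datasets, one modality at a
time, and optionally renders a summary plot. Three subcommands are exposed:

  checkpoint      evaluate a single .ckpt file
  run             evaluate all checkpoints under <run_dir>/checkpoints/
  create_summary  regenerate the summary plot for an existing evaluation
"""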
import argparse
import logging
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from easydict import EasyDict as edict
from tqdm import tqdm

from smiles_cl.constants import DEFAULT_EVALUATION_DATASETS, RE_CHECKPOINT
from smiles_cl.evaluation.callbacks import EvaluationCallback
from smiles_cl.evaluation.plotting import create_evaluation_summary_plot
from smiles_cl.types import PathLike
def wrap_autocast(fn):
    """Return fn wrapped so that it executes under CUDA autocast (mixed precision)."""

    def wrapped_fn(*args, **kwargs):
        with torch.autocast("cuda"):
            return fn(*args, **kwargs)

    return wrapped_fn
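# wrap_autocast is not called elsewhere in this script; it is a general helper
# for running arbitrary callables under mixed precision. Illustrative use with
# a hypothetical encoder object:
#
#   encode = wrap_autocast(model.encode)
#   embeddings = encode(batch)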
def save_run_summary(run_dir: PathLike, **kwargs):
    """Render the evaluation summary plot for run_dir and save it as summary.png."""
    run_dir = Path(run_dir)
    summary_fig = create_evaluation_summary_plot(run_dir, **kwargs)
    summary_fig.savefig(run_dir / "summary.png", bbox_inches="tight")
    plt.close(summary_fig)
def evaluate_run(args):
    """Evaluate one or more checkpoints of a single run, one modality at a time."""
    logger = logging.getLogger("smiles_cl")
    logger.setLevel(args.log_level.upper())

    eval_callback = EvaluationCallback(
        datasets=args.datasets,
        modalities=args.modalities,
        batch_size=args.batch_size,
        output_dir=args.output_dir,
        device=args.device,
        log_summary=False,
    )

    # All checkpoints must be .ckpt files from a single run directory.
    checkpoints = list(map(Path, args.checkpoints))
    assert set(ckpt.suffix for ckpt in checkpoints) == {".ckpt"}, set(
        ckpt.suffix for ckpt in checkpoints
    )
    assert (
        len(set(ckpt.parent.parent for ckpt in checkpoints)) == 1
    ), "Found checkpoints belonging to multiple runs"
    run_dir = checkpoints[0].parent.parent

    # Minimal stand-in for the training logger interface: exposes the run id
    # and a no-op log method.
    logger_dummy = edict(
        experiment=edict(id=run_dir.name, log=lambda *args, **kwargs: None)
    )

    it = list(product(args.modalities, checkpoints))
    if args.command == "run":
        it = tqdm(it)
    for modality, checkpoint in it:
        match = RE_CHECKPOINT.match(checkpoint.name)
        if match is None:
            print(f"Skipping checkpoint: {checkpoint.name}")
            continue
        eval_callback.evaluate_modality(
            ckpt_path=str(checkpoint),
            step_id=match.group("step_id"),
            modality=modality,
            logger=logger_dummy,
        )

    if args.create_summary:
        for modality in args.modalities:
            modality_dir = (
                eval_callback.output_dir / logger_dummy.experiment.id / modality
            )
            save_run_summary(modality_dir)
def create_argparser():
    """Build the CLI: shared options plus the checkpoint/run/create_summary subcommands."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--modalities", nargs="+", default=["smiles"])
    parser.add_argument("--datasets", nargs="+", default=DEFAULT_EVALUATION_DATASETS)
    parser.add_argument("--output_dir", default="evaluation")
    parser.add_argument("--create_summary", action="store_true")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--log-level", "-l", default="INFO")

    subparsers = parser.add_subparsers(
        dest="command",
        required=True,
    )

    ckpt_parser = subparsers.add_parser("checkpoint")
    ckpt_parser.add_argument("checkpoint_file")

    run_parser = subparsers.add_parser("run")
    run_parser.add_argument("run_dir")

    summary_parser = subparsers.add_parser("create_summary")
    summary_parser.add_argument("run_dir")
    summary_parser.add_argument("--plot_confidence_intervals", action="store_true")
    summary_parser.add_argument("--only_best_per_dataset", action="store_true")

    return parser
if __name__ == "__main__":
    parser = create_argparser()
    args = parser.parse_args()

    if args.command == "checkpoint":
        # Evaluate a single checkpoint file.
        args.checkpoints = [args.checkpoint_file]
        evaluate_run(args)
    elif args.command == "run":
        # Evaluate every checkpoint found in the run's checkpoints/ directory.
        args.checkpoints = list(Path(args.run_dir).glob("checkpoints/*.ckpt"))
        evaluate_run(args)
    elif args.command == "create_summary":
        # Only regenerate the summary plot for an existing evaluation.
        save_run_summary(
            run_dir=args.run_dir,
            plot_confidence_intervals=args.plot_confidence_intervals,
            only_best_per_dataset=args.only_best_per_dataset,
        )
    else:
        raise ValueError(f"Unknown command: {args.command}")
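# Illustrative invocations (directory names are hypothetical, and checkpoint
# file names must match RE_CHECKPOINT). Note that shared options such as
# --create_summary belong to the parent parser and must precede the subcommand:
#
#   python evaluate.py --create_summary run runs/my_experiment
#   python evaluate.py checkpoint runs/my_experiment/checkpoints/<step>.ckpt
#   python evaluate.py create_summary evaluation/my_experiment/smiles --plot_confidence_intervals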