-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
84 lines (66 loc) · 2.9 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.models import create_model
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn
from sklearn.metrics import accuracy_score, multilabel_confusion_matrix, classification_report
from src.utils.evaluation import AP_partial, spearman_correlation, showCM
from datasets import CUFED
from options.test_options import TestOptions
# Parse the command-line test options once, at import time. evaluate() and
# main() read their settings (threshold, dataset paths, model_path, ema, ...)
# from this module-level `args`.
args = TestOptions().parse()
def evaluate(model, test_dataset, test_loader, device, threshold=None):
    """Run ``model`` over ``test_loader`` and compute evaluation metrics.

    Args:
        model: Network whose forward pass returns ``(logits, attention)``.
        test_dataset: Dataset behind ``test_loader``. Its ``labels`` are used
            as ground truth, and ``len(event_labels)`` sets the number of
            classes.
        test_loader: Loader yielding ``(feats, _, importance_scores)`` batches.
            It must iterate in dataset order (not shuffled), because logits are
            written into consecutive rows and compared with
            ``test_dataset.labels`` by position.
        device: Device the features are moved to before the forward pass.
        threshold: Probability cut-off for predicting a label as positive.
            Defaults to ``args.threshold`` from the parsed command-line options.

    Returns:
        Tuple ``(map_micro, map_macro, acc, spearman, cms, cr)``:
        - ``map_micro`` and ``map_macro`` are elements 1 and 2 of
          ``AP_partial`` on the raw logits.
        - ``acc`` is the sklearn ``accuracy_score`` of the thresholded
          predictions.
        - ``spearman`` is ``spearman_correlation`` between the sliced attention
          and the importance scores.
        - ``cms`` is the ``multilabel_confusion_matrix``.
        - ``cr`` is the ``classification_report`` text.
    """
    if threshold is None:
        threshold = args.threshold

    model.eval()
    num_samples = len(test_dataset)
    num_classes = len(test_dataset.event_labels)
    scores = torch.zeros((num_samples, num_classes), dtype=torch.float32)
    attentions = []
    importance_chunks = []
    gidx = 0
    with torch.no_grad():
        for feats, _, importance_scores in test_loader:
            logits, attention = model(feats.to(device))
            batch_size = logits.shape[0]
            scores[gidx:gidx + batch_size, :] = logits.cpu()
            gidx += batch_size
            attentions.append(attention)
            importance_chunks.append(importance_scores)

    # nn.Sigmoid has no constructor arguments, so the original nn.Sigmoid(dim=1)
    # raised TypeError. The element-wise torch.sigmoid is what is wanted here.
    probs = torch.sigmoid(scores)
    # Single-step binarization to 0.0/1.0 float32, same values as the old
    # two-pass in-place masking.
    preds = (probs >= threshold).float().numpy()
    scores = scores.numpy()

    attention_tensor = torch.cat(attentions).to(device)
    importance_labels = torch.cat(importance_chunks).to(device)

    y_true = test_dataset.labels
    acc = accuracy_score(y_true, preds)
    cms = multilabel_confusion_matrix(y_true, preds)
    cr = classification_report(y_true, preds)
    # AP is computed on raw logits. Sigmoid is monotonic, so the ranking matches
    # the probabilities.
    map_micro, map_macro = AP_partial(y_true, scores)[1:3]
    # NOTE(review): attention[:, 0, 1:] presumably takes the first token's
    # attention over the remaining positions (one per album image). Confirm
    # against the model's attention layout.
    spearman = spearman_correlation(attention_tensor[:, 0, 1:], importance_labels)
    return map_micro, map_macro, acc, spearman, cms, cr
def main():
    """Load the checkpoint at ``args.model_path`` and report test-set metrics.

    Steps:
    - Build the test dataset and a non-shuffled DataLoader from the parsed
      command-line ``args``. Only ``args.dataset == 'cufed'`` is supported;
      anything else exits with an error message.
    - Restore the model, wrapped in an EMA ``AveragedModel`` when ``args.ema``
      is set, from the checkpoint's ``model_state_dict``.
    - Time a call to :func:`evaluate`.
    - Print the summary metrics and the classification report, then pass the
      confusion matrices to ``showCM``.
    """
    import sys

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.dataset == 'cufed':
        test_dataset = CUFED(root_dir=args.dataset_path, split_dir=args.split_dir, is_train=False,
                             img_size=args.img_size, album_clip_length=args.album_clip_length)
    else:
        # sys.exit rather than the site-provided exit() helper. exit() is meant
        # for the interactive shell and is missing under `python -S`. Both print
        # the message to stderr and exit with status 1.
        sys.exit("Unknown dataset!")
    # shuffle=False keeps batches in dataset order, which evaluate() relies on
    # to align its logits with test_dataset.labels.
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size,
                             num_workers=args.num_workers, shuffle=False)
    if args.verbose:
        print("running on {}".format(device))
        print("test_set={}".format(len(test_dataset)))

    # Setup model.
    # NOTE(review): torch.load may unpickle arbitrary objects, depending on the
    # installed torch version's weights_only default. Only point model_path at
    # trusted checkpoints.
    state = torch.load(args.model_path, map_location='cpu')
    model = create_model(args).to(device)
    if args.ema:
        # The checkpoint must have been saved from an identically wrapped EMA
        # model, since loading below uses strict=True.
        model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))
    model.load_state_dict(state['model_state_dict'], strict=True)
    print('load model from epoch {}'.format(state['epoch']))

    t0 = time.perf_counter()
    map_micro, map_macro, acc, spearman, cms, cr = evaluate(model, test_dataset, test_loader, device)
    t1 = time.perf_counter()
    print("map_micro={} map_macro={} accuracy={} spearman={} dt={:.2f}sec".format(
        map_micro, map_macro, acc * 100, spearman, t1 - t0))
    print(cr)
    showCM(cms)


if __name__ == '__main__':
    main()