forked from open-mmlab/mmaction2
-
Notifications
You must be signed in to change notification settings - Fork 204
/
Copy pathrecognizer3d.py
119 lines (99 loc) · 3.8 KB
/
recognizer3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from torch import nn
from ..builder import RECOGNIZERS
from .base import BaseRecognizer
@RECOGNIZERS.register_module()
class Recognizer3D(BaseRecognizer):
    """3D recognizer model framework.

    Every entry point first merges the two leading dimensions of ``imgs``
    into one, runs ``self.extract_feat`` (and ``self.neck`` when
    ``self.with_neck`` is set), and then either computes losses, class
    scores or pooled features depending on the method.
    """

    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training.

        Args:
            imgs (torch.Tensor): Input images. The first two dimensions are
                merged into one before feature extraction.
            labels (torch.Tensor): Ground-truth labels. Singleton dimensions
                are squeezed away before the labels are used.
            **kwargs: Forwarded unchanged to ``self.cls_head.loss``.

        Returns:
            dict: The entries returned by ``self.cls_head.loss``, plus the
                auxiliary entries returned by the neck when one is present.
        """
        assert self.with_cls_head
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        # Squeeze once; both the neck and the head consume the same labels.
        gt_labels = labels.squeeze()
        losses = {}

        x = self.extract_feat(imgs)
        if self.with_neck:
            x, loss_aux = self.neck(x, gt_labels)
            losses.update(loss_aux)

        cls_score = self.cls_head(x)
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)

        return losses

    def _do_test(self, imgs):
        """Defines the computation performed at every call when evaluation,
        testing and gradcam.

        Args:
            imgs (torch.Tensor): Input images. ``imgs.shape[0]`` is taken as
                the batch size and ``imgs.shape[1]`` as the number of
                segments (views) per sample; both are merged into a single
                leading dimension before feature extraction.

        Returns:
            torch.Tensor: When ``self.feature_extraction`` is set, features
                pooled by ``AdaptiveAvgPool3d(1)`` and averaged over the
                segments of each sample. Otherwise the output of
                ``self.cls_head`` after ``self.average_clip``.
        """
        batches = imgs.shape[0]
        num_segs = imgs.shape[1]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])

        if self.max_testing_views is not None:
            total_views = imgs.shape[0]
            assert num_segs == total_views, (
                'max_testing_views is only compatible '
                'with batch_size == 1')
            # Run the backbone on at most ``max_testing_views`` views at a
            # time and stitch the partial results back together afterwards.
            step = self.max_testing_views
            feats = []
            for start in range(0, total_views, step):
                x = self.extract_feat(imgs[start:start + step])
                if self.with_neck:
                    x, _ = self.neck(x)
                feats.append(x)
            # should consider the case that feat is a tuple
            if isinstance(feats[0], tuple):
                # Concatenate level by level across all chunks.
                feat = tuple(torch.cat(level) for level in zip(*feats))
            else:
                feat = torch.cat(feats)
        else:
            feat = self.extract_feat(imgs)
            if self.with_neck:
                feat, _ = self.neck(feat)

        if self.feature_extraction:
            # perform spatio-temporal pooling
            avg_pool = nn.AdaptiveAvgPool3d(1)
            if isinstance(feat, tuple):
                # pool every level, then concat them along dimension 1
                feat = torch.cat([avg_pool(x) for x in feat], dim=1)
            else:
                feat = avg_pool(feat)
            # squeeze dimensions
            feat = feat.reshape((batches, num_segs, -1))
            # temporal average pooling
            return feat.mean(dim=1)

        # should have cls_head if not extracting features
        assert self.with_cls_head
        cls_score = self.cls_head(feat)
        cls_score = self.average_clip(cls_score, num_segs)
        return cls_score

    def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluation and
        testing.

        Args:
            imgs (torch.Tensor): Input images, passed to ``_do_test``.

        Returns:
            numpy.ndarray: The result of ``_do_test`` moved to the CPU and
                converted to a NumPy array.
        """
        return self._do_test(imgs).cpu().numpy()

    def forward_dummy(self, imgs, softmax=False):
        """Used for computing network FLOPs.

        See ``tools/analysis/get_flops.py``.

        Args:
            imgs (torch.Tensor): Input images.
            softmax (bool): Whether to apply softmax to the head output.
                Default: False.

        Returns:
            tuple[Tensor]: A 1-tuple holding the class score.
        """
        assert self.with_cls_head
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        x = self.extract_feat(imgs)

        if self.with_neck:
            x, _ = self.neck(x)

        outs = self.cls_head(x)
        if softmax:
            # Pass ``dim`` explicitly: the implicit choice is deprecated and
            # emits a warning. ``dim=1`` is what the implicit rule selects
            # for 2-D scores (presumably (N, num_classes) -- confirm against
            # the configured cls_head).
            outs = nn.functional.softmax(outs, dim=1)
        return (outs, )

    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using gradcam
        utils.

        Args:
            imgs (torch.Tensor): Input images, passed to ``_do_test``.

        Returns:
            torch.Tensor: The result of ``_do_test``, kept as a tensor.
        """
        assert self.with_cls_head
        return self._do_test(imgs)