-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgan_ensemble.py
33 lines (25 loc) · 985 Bytes
/
gan_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch import nn
import timm
import os
class GanEnsemble(nn.Module):
    """Average the sigmoid scores of several timm classifiers.

    Each member is built with ``timm.create_model(name, num_classes=...)``.
    At inference the members' logits are passed through a sigmoid and
    averaged across the ensemble.
    """

    def __init__(
        self,
        model_names,
        num_classes=1,
        ckpt_path=None,
        ckpt_files=("resnet50.pt", "swin-tiny.pt", "vit-small.pt"),
    ):
        """Build the ensemble and optionally load per-member weights.

        Args:
            model_names: iterable of timm model names, one per ensemble member.
            num_classes: number of output logits per member.
            ckpt_path: directory containing the checkpoint files. When None,
                no weights are loaded and members keep timm's initialisation.
            ckpt_files: checkpoint file names inside ``ckpt_path``, matched to
                ``model_names`` by position. The default reproduces the
                original resnet50 / swin-tiny / vit-small layout.

        Raises:
            ValueError: if ``model_names`` is empty, or if ``ckpt_path`` is
                given and the number of checkpoint files differs from the
                number of models (a member would otherwise silently keep
                random weights, or a checkpoint would be skipped).
        """
        super().__init__()
        model_names = list(model_names)
        if not model_names:
            raise ValueError("model_names must contain at least one model name")
        self.models = nn.ModuleList(
            timm.create_model(name, num_classes=num_classes) for name in model_names
        )

        if ckpt_path is None:
            # Nothing to load; previously this crashed in os.path.join(None, ...).
            return

        ckpt_files = list(ckpt_files)
        if len(ckpt_files) != len(self.models):
            raise ValueError(
                f"got {len(self.models)} models but {len(ckpt_files)} checkpoint "
                f"files; pass ckpt_files matching model_names"
            )
        for model, file_name in zip(self.models, ckpt_files):
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints. On torch>=1.13 consider weights_only=True.
            state_dict = torch.load(
                os.path.join(ckpt_path, file_name), map_location="cpu"
            )
            model.load_state_dict(state_dict)

    def forward(self, x):
        """Return the ensemble-averaged sigmoid score for input ``x``.

        Each member's logits (assumed ``(batch, num_classes)`` as timm
        classifiers produce — TODO confirm for custom heads) are stacked on a
        new leading ensemble axis, squashed with a sigmoid, and averaged over
        that axis only, so classes are never mixed together.

        Returns:
            ``(batch,)`` when members emit a single logit (matching the
            original single-class output), otherwise ``(batch, num_classes)``.
        """
        logits = torch.stack([model(x) for model in self.models], dim=0)
        probs = torch.sigmoid(logits).mean(dim=0)
        if probs.shape[-1] == 1:
            # Single-logit members: keep the historical (batch,) shape.
            probs = probs.squeeze(-1)
        return probs