-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinear_train.py
91 lines (73 loc) · 3 KB
/
linear_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from lpl.model import LPLVGG11
import torch
import torch.nn.functional as F
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor
import numpy as np
import argparse
from pathlib import Path
# ---- CLI arguments -------------------------------------------------------
parser = argparse.ArgumentParser(description='Train linear decoders from all VGG layers.')
# required=True: without it a missing --name only crashed later, opaquely,
# inside torch.load(None).
parser.add_argument('--name', type=str, required=True, help='The model to be tested')
parser.add_argument('--device', type=str, default='cuda', help='Device (cuda, cpu)')
parser.add_argument('--avgpool', action='store_true',
                    help='Apply global average pooling to the layer before decoding.')
args = parser.parse_args()

# ---- Frozen backbone -----------------------------------------------------
model = LPLVGG11()
model.load_state_dict(torch.load(args.name))
exp_name = Path(args.name).stem
device = torch.device(args.device)
model.to(device)

# ---- Data ----------------------------------------------------------------
ds = STL10(root='../datasets/', transform=ToTensor(), split='train')
ds_test = STL10(root='../datasets/', transform=ToTensor(), split='test')
dl = torch.utils.data.DataLoader(ds, batch_size=800)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=800)

# ---- Linear-decoder input size per probed layer --------------------------
AVGPOOL = args.avgpool
SIZE_MUL = 9  # 9 for STL, 1 for CIFAR (not tested!)
channel_n = np.asarray(model.C_SIZES)
maps_sizes_cifar = np.asarray(model.HW_SIZES)
# With avg-pooling the decoder only sees the channel dimension; otherwise it
# sees the full flattened feature map (channels * H * W, scaled for STL).
OUT_SIZES = channel_n if AVGPOOL else channel_n * maps_sizes_cifar ** 2 * SIZE_MUL
print(max(OUT_SIZES))

# ---- Report file (header row; one data row per probed layer) -------------
# Create the output directory up front instead of crashing on open().
Path("reports").mkdir(exist_ok=True)
report_name = f"reports/{exp_name}_avgpool{AVGPOOL}.txt"
with open(report_name, "w") as rep:
    rep.write("layer,accuracy\n")
# ---- Linear probing: train one linear decoder per probed layer -----------
for i, layer in enumerate(model.TEST_LAYERS):
    # Frozen feature extractor up to (and including) the probed layer.
    # NOTE(review): consider submodel.eval() here if the backbone contains
    # BatchNorm/Dropout — left as-is to preserve the original behavior.
    submodel = model.model[:layer + 1]
    print(submodel)

    linear = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(OUT_SIZES[i], 10).to(device)
    )
    optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(20):
        # -- train the linear decoder (the backbone is never optimized:
        #    the optimizer only holds linear.parameters()) --
        for images, labels in dl:
            images = images.to(device)
            labels = labels.to(device)
            # no_grad: the original built an autograd graph through the
            # frozen backbone, wasting memory and compute for gradients
            # that no optimizer ever consumed.
            with torch.no_grad():
                representation = submodel(images)
                if AVGPOOL:
                    representation = F.adaptive_avg_pool2d(representation, (1, 1))
            out = linear(representation)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # -- evaluate on the test split --
        # Count correct predictions instead of averaging per-batch accuracy:
        # the per-batch mean is biased whenever the last batch is smaller.
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in dl_test:
                images = images.to(device)
                labels = labels.to(device)
                representation = submodel(images)
                if AVGPOOL:
                    representation = F.adaptive_avg_pool2d(representation, (1, 1))
                out = linear(representation)
                pred = torch.argmax(out, dim=1)
                correct += (pred == labels).sum().item()
                total += labels.numel()
        print(f"Layer {layer}, epoch {epoch}")
        accuracy = correct / total
        print("Accuracy", accuracy)

    # One CSV row per layer (matches the "layer,accuracy" header) with the
    # final-epoch test accuracy.
    with open(report_name, "a") as rep:
        rep.write(f"{layer},{accuracy}\n")