-
Notifications
You must be signed in to change notification settings - Fork 186
/
Copy pathmain.py
68 lines (56 loc) · 2.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import print_function
import argparse
from torch.utils.data import DataLoader
from DBPN.solver import DBPNTrainer
from DRCN.solver import DRCNTrainer
from EDSR.solver import EDSRTrainer
from FSRCNN.solver import FSRCNNTrainer
from SRCNN.solver import SRCNNTrainer
from SRGAN.solver import SRGANTrainer
from SubPixelCNN.solver import SubPixelTrainer
from VDSR.solver import VDSRTrainer
from dataset.data import get_training_set, get_test_set
# ===========================================================
# Training settings
# ===========================================================
def _build_parser():
    """Create and return the command-line parser for this training script."""
    cli = argparse.ArgumentParser(description='PyTorch Super Res Example')

    # Training-loop hyper-parameters.
    cli.add_argument('--batchSize', type=int, default=1,
                     help='training batch size')
    cli.add_argument('--testBatchSize', type=int, default=1,
                     help='testing batch size')
    cli.add_argument('--nEpochs', type=int, default=20,
                     help='number of epochs to train for')
    cli.add_argument('--lr', type=float, default=0.01,
                     help='Learning Rate. Default=0.01')
    # NOTE(review): --seed is not read anywhere in this file; presumably the
    # trainers consume it via the ``args`` namespace -- confirm.
    cli.add_argument('--seed', type=int, default=123,
                     help='random seed to use. Default=123')

    # Model selection and super-resolution setup.
    cli.add_argument('--upscale_factor', '-uf', type=int, default=4,
                     help='super resolution upscale factor')
    cli.add_argument('--model', '-m', type=str, default='srgan',
                     help='choose which model is going to use')
    return cli


# Parsed once, at import time; main() reads the module-level ``args``.
parser = _build_parser()
args = parser.parse_args()
def main():
    """Train the super-resolution model selected on the command line.

    Reads the module-level ``args`` namespace:

    * ``upscale_factor`` selects the training and test datasets.
    * ``batchSize`` and ``testBatchSize`` set the two DataLoaders' batch sizes.
    * ``model`` selects the trainer class.

    The trainer is constructed as ``Trainer(args, training_loader,
    testing_loader)`` and its ``run()`` method is called.

    Raises:
        ValueError: if ``args.model`` is not a supported model name. This is
            checked before any dataset is loaded. ``ValueError`` subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    # Dispatch table: --model value -> trainer class.
    trainers = {
        'sub': SubPixelTrainer,
        'srcnn': SRCNNTrainer,
        'vdsr': VDSRTrainer,
        'edsr': EDSRTrainer,
        'fsrcnn': FSRCNNTrainer,
        'drcn': DRCNTrainer,
        'srgan': SRGANTrainer,
        'dbpn': DBPNTrainer,
    }
    trainer_cls = trainers.get(args.model)
    if trainer_cls is None:
        # Fail fast: previously an unknown name was only reported after both
        # datasets had been loaded (potentially slow).
        raise ValueError(
            "the model does not exist: {!r} (choose from: {})".format(
                args.model, ', '.join(sorted(trainers))))

    # ===========================================================
    # Set train dataset & test dataset
    # ===========================================================
    print('===> Loading datasets')
    train_set = get_training_set(args.upscale_factor)
    test_set = get_test_set(args.upscale_factor)
    # Shuffle only the training data; the test loader keeps dataset order.
    training_data_loader = DataLoader(dataset=train_set, batch_size=args.batchSize, shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set, batch_size=args.testBatchSize, shuffle=False)

    model = trainer_cls(args, training_data_loader, testing_data_loader)
    model.run()
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()