sang_gan.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd


# G(z): generates a fake 3x64x64 image from a 100-dim latent vector.
class generator(nn.Module):
    # initializers
    def __init__(self):
        super(generator, self).__init__()
        # each ConvTranspose2d doubles the spatial size: 1 -> 4 -> 8 -> 16 -> 32 -> 64
        self.deconv1 = nn.ConvTranspose2d(100, 512, 4, 1, 0)
        self.deconv2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 3, 4, 2, 1)

    # weight_init: Xavier init for conv weights, zeros for biases.
    # Note: iterate over the module objects, not the keys of self._modules,
    # otherwise the isinstance check never matches anything.
    def weight_init(self):
        for m in self._modules.values():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)  # xavier_normal_ needs >= 2 dims, so zero the 1-D bias

    # forward method
    def forward(self, inputs):
        x = F.relu(self.deconv1(inputs))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        x = torch.tanh(self.deconv5(x))  # outputs in [-1, 1]
        return x
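
# A hypothetical sanity check (not part of the original script): with a batch
# of latent vectors of shape (N, 100, 1, 1), the generator should return
# images of shape (N, 3, 64, 64):
#   G = generator()
#   z = torch.randn(8, 100, 1, 1)
#   assert G(z).shape == (8, 3, 64, 64)
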
# D(x): WGAN critic; no sigmoid on the output, since WGAN-GP works on raw scores.
class discriminator(nn.Module):
    # initializers
    def __init__(self):
        super(discriminator, self).__init__()
        # each strided Conv2d halves the spatial size: 64 -> 32 -> 16 -> 8
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv4 = nn.Conv2d(256, 1, 4, 1, 0)

    # weight_init: same scheme as the generator.
    def weight_init(self):
        for m in self._modules.values():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    # forward method
    def forward(self, inputs):
        x = F.leaky_relu(self.conv1(inputs), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        x = self.conv4(x)  # (N, 1, 5, 5) map of critic scores for 64x64 inputs
        return x
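
# A hypothetical shape check (not in the original script): for 64x64 inputs the
# critic returns a 5x5 grid of scores rather than a single scalar, so downstream
# losses should take a mean over this map:
#   D = discriminator()
#   x = torch.randn(8, 3, 64, 64)
#   assert D(x).shape == (8, 1, 5, 5)
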
# Initialize CNN weights from N(mean, std); kept as an alternative to the
# Xavier scheme used in weight_init above.
def normal_init(m, mean, std):
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
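
# Hypothetical usage (not in the original script), applying the DCGAN-style
# N(0, 0.02) init to every conv layer of a network via Module.apply:
#   G = generator()
#   G.apply(lambda m: normal_init(m, 0.0, 0.02))
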
# WGAN-GP gradient penalty (Gulrajani et al., 2017): penalize deviations of the
# critic's gradient norm from 1 at random interpolates of real and fake images.
def calculate_gradient_penalty(Discriminator, real_images, fake_images, lambda_gp):
    # one interpolation coefficient per sample, broadcast over C, H, W
    eta = torch.rand(real_images.size(0), 1, 1, 1, device=real_images.device)
    eta = eta.expand_as(real_images)
    interpolated = eta * real_images + (1 - eta) * fake_images
    # mark the interpolates as requiring grad so autograd.grad can reach them
    # (Variable is deprecated; requires_grad_ replaces it)
    interpolated = interpolated.requires_grad_(True)
    # critic scores for the interpolated examples
    prob_interpolated = Discriminator(interpolated)
    # gradients of the scores with respect to the interpolates
    gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                              grad_outputs=torch.ones_like(prob_interpolated),
                              create_graph=True, retain_graph=True)[0]
    # flatten per sample before taking the 2-norm; norm(2, dim=1) on the raw
    # (N, C, H, W) tensor would penalize per-pixel channel norms instead of
    # the per-sample gradient norm
    gradients = gradients.view(gradients.size(0), -1)
    grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return grad_penalty
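
# A minimal single-step training sketch (not part of the original script). The
# hyperparameters (batch size, learning rates, lambda_gp = 10) are illustrative,
# and the random "real" batch stands in for a data loader; the Adam betas
# (0.0, 0.9) follow the WGAN-GP paper and are assumed here.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    G, D = generator().to(device), discriminator().to(device)
    G.weight_init()
    D.weight_init()
    opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))
    opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.0, 0.9))

    real = torch.randn(8, 3, 64, 64, device=device)  # stand-in for a real batch

    # critic update: minimize D(fake) - D(real) plus the gradient penalty
    z = torch.randn(8, 100, 1, 1, device=device)
    fake = G(z).detach()
    d_loss = D(fake).mean() - D(real).mean() \
        + calculate_gradient_penalty(D, real, fake, lambda_gp=10.0)
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # generator update: maximize D(G(z)), i.e. minimize -D(G(z))
    z = torch.randn(8, 100, 1, 1, device=device)
    g_loss = -D(G(z)).mean()
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    print(f"d_loss={d_loss.item():.4f}  g_loss={g_loss.item():.4f}")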