-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvae.py
96 lines (79 loc) · 2.88 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
class VAE(nn.Module):
    """Variational autoencoder for flattened 256x256 RGB images.

    Input/output vectors are of size 256*256*3 = 196608; the latent
    space is a 20-dimensional diagonal Gaussian.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 196608 -> 400 hidden units -> two 20-dim heads.
        self.fc1 = nn.Linear(256 * 256 * 3, 400)
        self.fc21 = nn.Linear(400, 20)  # mean head (mu)
        self.fc22 = nn.Linear(400, 20)  # log-variance head (logvar)
        # Decoder: 20 -> 400 -> 196608.
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 256 * 256 * 3)

    def encode(self, x):
        """Map a batch x of shape (N, 196608) to (mu, logvar), each (N, 20)."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample z = mu + std * eps with eps ~ N(0, I) (reparameterization trick).

        torch.randn_like replaces the deprecated Variable /
        torch.cuda.FloatTensor(...).normal_() pattern: it produces noise on
        the same device and dtype as `std`, so no explicit CUDA branch is
        needed and gradients flow through mu and logvar as before.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent z of shape (N, 20) to pixel space in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        # torch.sigmoid: F.sigmoid is deprecated (removed in recent PyTorch).
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a flattened image batch."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
# model = VAE()
# Sum-reduced MSE between reconstruction and input, used as the likelihood
# term of the VAE loss. reduction='sum' replaces the deprecated
# size_average=False argument (removed in modern PyTorch); the computed
# value is identical.
reconstruction_function = nn.MSELoss(reduction='sum')
def loss_function(recon_x, x, mu, logvar):
    """Compute the VAE loss: sum-reduced MSE reconstruction term + KL divergence.

    recon_x: generated images (same shape as x)
    x: original images
    mu: latent mean, shape (N, latent_dim)
    logvar: latent log variance, shape (N, latent_dim)

    Returns a scalar tensor: MSE(recon_x, x) summed over all elements,
    plus KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    """
    # F.mse_loss with reduction='sum' computes the same value as the old
    # module-level nn.MSELoss(size_average=False) criterion, without the
    # deprecated argument or the hidden global dependency.
    mse = F.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2). Written without chained
    # in-place ops for readability; the result is identical.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + kld
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
# for epoch in range(num_epochs):
# train_loss = 0
# for batch_idx, data in enumerate(dataloader):
# img, _ = data
# img = img.view(img.size(0), -1)
# img = Variable(img)
# if torch.cuda.is_available():
# img = img.cuda()
# optimizer.zero_grad()
# recon_batch, mu, logvar = model(img)
# loss = loss_function(recon_batch, img, mu, logvar)
# loss.backward()
# train_loss += loss.item()
# optimizer.step()
# if batch_idx % 100 == 0:
# print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
# epoch,
# batch_idx * len(img),
# len(dataloader.dataset), 100. * batch_idx / len(dataloader),
# loss.item() / len(img)))
# print('====> Epoch: {} Average loss: {:.4f}'.format(
# epoch, train_loss / len(dataloader.dataset)))
# if epoch % 10 == 0:
# save = to_img(recon_batch.cpu().data)
# save_image(save, './vae_img/image_{}.png'.format(epoch))
# torch.save(model.state_dict(), './vae.pth')