-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
91 lines (62 loc) · 1.96 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
from torch.autograd import Variable
class ContentLoss(nn.Module):
    """
    ContentLoss: Computes the "content loss" between
    the feature representations of a white noise image at
    layer_i vs the feature representations of the actual image
    at layer_i.
    Here weight is the content_loss weight (or alpha in the paper).

    The module is transparent in the forward pass: it records the loss in
    ``self.loss`` and returns its input unchanged, so it can be placed
    between layers of a feature extractor.

    Args:
        weight: scalar multiplier applied to both the input and the target
            before the MSE is taken (so the loss scales with ``weight**2``).
        target: feature tensor to match; it is detached so no gradient
            flows into it.
    """

    def __init__(self, weight, target):
        super(ContentLoss, self).__init__()
        # The target is a fixed reference, so cut it out of the autograd graph.
        self.target = target.detach() * weight
        self.weight = weight
        self.criteria = nn.MSELoss()

    def forward(self, inp):
        """Record the weighted MSE against the target; return ``inp`` unchanged."""
        # Invoke the module (not .forward) so registered nn.Module hooks run.
        self.loss = self.criteria(inp * self.weight, self.target)
        self.output = inp
        return self.output

    def backward(self):
        """Backpropagate ``self.loss`` and return it.

        ``forward`` must have been called first (it is what sets ``self.loss``).
        ``retain_graph=True`` keeps the graph alive so this may be called more
        than once on the same forward pass. (The old ``retain_variables``
        keyword no longer exists in PyTorch.)
        """
        self.loss.backward(retain_graph=True)
        return self.loss
class GramMatrix(nn.Module):
    """
    GramMatrix: Calculate the correlation between
    the vectorized feature maps in a conv layer activation.

    An input of shape (bs, ch, h, w) is reshaped to K x N where
    K = bs*ch and N = h*w. The Gram matrix is the inner product of
    these flattened feature maps, normalised by the element count K*N.

    Note: the batch dimension is folded into K, so when bs > 1 the result
    also contains cross-terms between feature maps of different batch items.
    """

    def forward(self, inp):
        """Return the (K, K) Gram matrix of ``inp`` divided by K*N."""
        bs, ch, h, w = inp.size()
        k = bs * ch
        n = h * w
        # reshape (not view) so non-contiguous activations (e.g. from a
        # transpose/permute) are accepted; it only copies when it must.
        features = inp.reshape(k, n)
        gm = torch.mm(features, features.t())
        return gm.div(k * n)
class StyleLoss(nn.Module):
    """
    StyleLoss: Computes the "style loss" between
    the white noise and the target artwork.
    The style loss is calculated as the MSELoss
    between the gram matrices of the random image
    and the artwork.
    Here the weight is style_loss weight (or beta in the paper).

    The module is transparent in the forward pass: it records the loss in
    ``self.loss`` and returns a clone of its input.

    Args:
        weight: scalar multiplier applied to both the input's Gram matrix and
            the target before the MSE is taken.
        target: compared directly with the output of ``GramMatrix``, so it is
            expected to already be a Gram matrix of matching shape; it is
            detached so no gradient flows into it.
    """

    def __init__(self, weight, target):
        super(StyleLoss, self).__init__()
        # The target is a fixed reference, so cut it out of the autograd graph.
        self.target = target.detach() * weight
        self.weight = weight
        self.gram = GramMatrix()
        self.criteria = nn.MSELoss()

    def forward(self, inp):
        """Record the weighted Gram-matrix MSE; return a clone of ``inp``."""
        self.output = inp.clone()
        # Invoke submodules (not .forward) so hooks run, and scale out of place
        # rather than mutating an autograd intermediate with mul_.
        self.gm_inp = self.gram(inp) * self.weight
        self.loss = self.criteria(self.gm_inp, self.target)
        return self.output

    def backward(self):
        """Backpropagate ``self.loss`` and return it.

        ``forward`` must have been called first (it is what sets ``self.loss``).
        ``retain_graph=True`` keeps the graph alive so this may be called more
        than once on the same forward pass. (The old ``retain_variables``
        keyword no longer exists in PyTorch.)
        """
        self.loss.backward(retain_graph=True)
        return self.loss