"""
This is the implementation of following paper:
https://arxiv.org/pdf/1802.05591.pdf
This implementation is based on following code:
https://github.com/Wizaron/instance-segmentation-pytorch
"""
from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import torch
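
# Summary of the loss assembled in _discriminative_loss() below (a plain-language
# reading of the code, following the referenced paper):
#   l_var  : pulls each pixel embedding to within `delta_var` of its cluster mean
#   l_dist : pushes cluster means at least `2 * delta_dist` apart from each other
#   l_reg  : keeps the cluster means close to the origin
#   loss   = alpha * l_var + beta * l_dist + gamma * l_reg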


class DiscriminativeLoss(_Loss):
    """Discriminative loss for instance segmentation embeddings."""

    def __init__(self, delta_var=0.5, delta_dist=1.5,
                 norm=2, alpha=1.0, beta=1.0, gamma=0.001,
                 usegpu=True, size_average=True):
        super(DiscriminativeLoss, self).__init__(size_average)
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.usegpu = usegpu
        assert self.norm in [1, 2]

    def forward(self, input, target, n_clusters):
        # _assert_no_grad(target)
        return self._discriminative_loss(input, target, n_clusters)

    def _discriminative_loss(self, input, target, n_clusters):
        # input:      bs, n_features, height, width     (per-pixel embeddings)
        # target:     bs, max_n_clusters, height, width (binary instance masks)
        # n_clusters: number of valid instances per sample
        bs, n_features, height, width = input.size()
        max_n_clusters = target.size(1)

        # Flatten the spatial dimensions so every pixel is one location.
        input = input.contiguous().view(bs, n_features, height * width)
        target = target.contiguous().view(bs, max_n_clusters, height * width)

        c_means = self._cluster_means(input, target, n_clusters)
        l_var = self._variance_term(input, target, c_means, n_clusters)
        l_dist = self._distance_term(c_means, n_clusters)
        l_reg = self._regularization_term(c_means, n_clusters)

        loss = self.alpha * l_var + self.beta * l_dist + self.gamma * l_reg

        return loss

    def _cluster_means(self, input, target, n_clusters):
        bs, n_features, n_loc = input.size()
        max_n_clusters = target.size(1)

        # bs, n_features, max_n_clusters, n_loc
        input = input.unsqueeze(2).expand(bs, n_features, max_n_clusters, n_loc)
        # bs, 1, max_n_clusters, n_loc
        target = target.unsqueeze(1)
        # bs, n_features, max_n_clusters, n_loc
        input = input * target

        means = []
        for i in range(bs):
            # n_features, n_clusters, n_loc
            input_sample = input[i, :, :n_clusters[i]]
            # 1, n_clusters, n_loc
            target_sample = target[i, :, :n_clusters[i]]
            # n_features, n_clusters
            mean_sample = input_sample.sum(2) / (target_sample.sum(2) + 0.00001)

            # Pad with zero columns so every sample has max_n_clusters means.
            n_pad_clusters = max_n_clusters - n_clusters[i]
            assert n_pad_clusters >= 0
            if n_pad_clusters > 0:
                pad_sample = torch.zeros(n_features, n_pad_clusters)
                pad_sample = Variable(pad_sample)
                if self.usegpu:
                    pad_sample = pad_sample.cuda()
                mean_sample = torch.cat((mean_sample, pad_sample), dim=1)
            means.append(mean_sample)

        # bs, n_features, max_n_clusters
        means = torch.stack(means)

        return means

    def _variance_term(self, input, target, c_means, n_clusters):
        bs, n_features, n_loc = input.size()
        max_n_clusters = target.size(1)

        # bs, n_features, max_n_clusters, n_loc
        c_means = c_means.unsqueeze(3).expand(bs, n_features, max_n_clusters, n_loc)
        # bs, n_features, max_n_clusters, n_loc
        input = input.unsqueeze(2).expand(bs, n_features, max_n_clusters, n_loc)
        # Hinged squared distance of every pixel embedding to its cluster mean.
        # bs, max_n_clusters, n_loc
        var = (torch.clamp(torch.norm((input - c_means), self.norm, 1) -
                           self.delta_var, min=0) ** 2) * target

        var_term = 0
        for i in range(bs):
            # n_clusters, n_loc
            var_sample = var[i, :n_clusters[i]]
            # n_clusters, n_loc
            target_sample = target[i, :n_clusters[i]]

            # n_clusters
            c_var = var_sample.sum(1) / (target_sample.sum(1) + 0.00001)
            var_term += c_var.sum() / int(n_clusters[i])
        var_term /= bs

        return var_term

    def _distance_term(self, c_means, n_clusters):
        bs, n_features, max_n_clusters = c_means.size()

        dist_term = 0
        for i in range(bs):
            if n_clusters[i] <= 1:
                continue

            # n_features, n_clusters
            mean_sample = c_means[i, :, :n_clusters[i]]

            # n_features, n_clusters, n_clusters
            means_a = mean_sample.unsqueeze(2).expand(n_features, n_clusters[i], n_clusters[i])
            means_b = means_a.permute(0, 2, 1)
            diff = means_a - means_b

            # Hinge margin between every pair of distinct cluster means
            # (the diagonal is zeroed out so a mean is not compared to itself).
            margin = 2 * self.delta_dist * (1.0 - torch.eye(n_clusters[i]))
            margin = Variable(margin)
            if self.usegpu:
                margin = margin.cuda()

            c_dist = torch.sum(torch.clamp(margin - torch.norm(diff, self.norm, 0), min=0) ** 2)
            dist_term += c_dist / (2 * n_clusters[i] * (n_clusters[i] - 1))
        dist_term /= bs

        return dist_term

    def _regularization_term(self, c_means, n_clusters):
        bs, n_features, max_n_clusters = c_means.size()

        reg_term = 0
        for i in range(bs):
            # n_features, n_clusters
            mean_sample = c_means[i, :, :n_clusters[i]]
            reg_term += torch.mean(torch.norm(mean_sample, self.norm, 0))
        reg_term /= bs

        return reg_term
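

# --- Minimal usage sketch (not part of the original file) -------------------
# A smoke test showing the tensor shapes forward() expects. All sizes below
# (batch of 2, 4 embedding dimensions, a 16x16 grid, at most 3 instances) are
# made up for illustration; `usegpu=False` keeps the example on the CPU.
if __name__ == '__main__':
    bs, n_features, height, width = 2, 4, 16, 16
    max_n_clusters = 3

    # Predicted per-pixel embeddings: bs, n_features, height, width
    embeddings = torch.randn(bs, n_features, height, width)
    # Binary instance masks padded to max_n_clusters: bs, max_n_clusters, height, width
    masks = torch.randint(0, 2, (bs, max_n_clusters, height, width)).float()
    # Number of valid instances in each sample of the batch
    n_clusters = [3, 2]

    criterion = DiscriminativeLoss(delta_var=0.5, delta_dist=1.5, usegpu=False)
    loss = criterion(embeddings, masks, n_clusters)
    print(loss.item())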