-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathpartial_asymmetric_loss.py
178 lines (139 loc) · 7.18 KB
/
partial_asymmetric_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch
from torch import nn as nn, Tensor
import os
import pandas as pd
import numpy as np
class PartialSelectiveLoss(nn.Module):
    """Asymmetric selective BCE loss for partially annotated multi-label targets.

    Each target entry is one of three states: 1 (positive), 0 (annotated
    negative) or -1 (un-annotated).  Each state has its own alpha weight and
    asymmetric focusing exponent (gamma); un-annotated entries are further
    re-weighted or masked by ``edit_targets_parital_labels`` according to
    ``args.partial_loss_mode``.
    """

    def __init__(self, args):
        super(PartialSelectiveLoss, self).__init__()
        self.args = args
        self.clip = args.clip                # probability shift applied to negatives (ASL clipping)
        self.gamma_pos = args.gamma_pos      # focusing exponent for positives
        self.gamma_neg = args.gamma_neg      # focusing exponent for annotated negatives
        self.gamma_unann = args.gamma_unann  # focusing exponent for un-annotated entries
        self.alpha_pos = args.alpha_pos
        self.alpha_neg = args.alpha_neg
        self.alpha_unann = args.alpha_unann

        # Intended as a re-usable per-element weight buffer.
        # NOTE(review): forward() never writes the result back to this
        # attribute, so it stays None and the cache is never actually re-used
        # — confirm whether that is intentional.
        self.targets_weights = None

        if args.prior_path is not None:
            print("Prior file was found in given path.")
            # CSV layout: first column = class identifier, second = prior value.
            df = pd.read_csv(args.prior_path)
            self.prior_classes = dict(zip(df.values[:, 0], df.values[:, 1]))
            print("Prior file was loaded successfully. ")

    def forward(self, logits, targets):
        """Return the summed (negated) partial asymmetric loss.

        :param logits: raw class scores, same shape as ``targets``.
        :param targets: per-class labels in {1, 0, -1}.
        """
        # Positive, Negative and Un-annotated indexes
        targets_pos = (targets == 1).float()
        targets_neg = (targets == 0).float()
        targets_unann = (targets == -1).float()

        # Activation
        xs_pos = torch.sigmoid(logits)
        xs_neg = 1.0 - xs_pos

        if self.clip is not None and self.clip > 0:
            # In-place asymmetric probability shifting of the negative side.
            xs_neg.add_(self.clip).clamp_(max=1)

        prior_classes = None
        if hasattr(self, "prior_classes"):
            # NOTE(review): hard-coded .cuda() — assumes GPU training.
            prior_classes = torch.tensor(list(self.prior_classes.values())).cuda()

        targets_weights = self.targets_weights
        targets_weights, xs_neg = edit_targets_parital_labels(self.args, targets, targets_weights, xs_neg,
                                                              prior_classes=prior_classes)

        # Loss calculation: per-state weighted log-likelihoods; clamp avoids log(0).
        # Un-annotated entries are treated as negatives here.
        BCE_pos = self.alpha_pos * targets_pos * torch.log(torch.clamp(xs_pos, min=1e-8))
        BCE_neg = self.alpha_neg * targets_neg * torch.log(torch.clamp(xs_neg, min=1e-8))
        BCE_unann = self.alpha_unann * targets_unann * torch.log(torch.clamp(xs_neg, min=1e-8))
        BCE_loss = BCE_pos + BCE_neg + BCE_unann

        # Adding asymmetric gamma weights (focal-style modulation); computed
        # under no_grad so gradients do not flow through the weights.
        with torch.no_grad():
            asymmetric_w = torch.pow(1 - xs_pos * targets_pos - xs_neg * (targets_neg + targets_unann),
                                     self.gamma_pos * targets_pos + self.gamma_neg * targets_neg +
                                     self.gamma_unann * targets_unann)
        BCE_loss *= asymmetric_w

        # partial labels weights
        BCE_loss *= targets_weights

        return -BCE_loss.sum()
def edit_targets_parital_labels(args, targets, targets_weights, xs_neg, prior_classes=None):
    """Build per-element loss weights for partially annotated targets.

    :param args: namespace providing ``partial_loss_mode`` and, for the
        'selective' mode, ``likelihood_topk`` and ``prior_threshold``.
    :param targets: label tensor in {1, 0, -1}; -1 marks un-annotated entries.
    :param targets_weights: previously allocated weight tensor (may be None);
        re-used in 'selective' mode to avoid re-allocation every batch.
    :param xs_neg: (possibly clipped) negative probabilities, same shape as
        ``targets``.
    :param prior_classes: optional per-class prior tensor ('selective' only).
    :returns: ``(targets_weights, xs_neg)``; ``xs_neg`` is returned unchanged.
    """
    # targets_weights is an internal state of the AsymmetricLoss class; we
    # don't want to re-allocate it every batch.
    if args.partial_loss_mode is None:
        targets_weights = 1.0

    elif args.partial_loss_mode == 'negative':
        # set all unsure targets as negative
        targets_weights = 1.0

    elif args.partial_loss_mode == 'ignore':
        # remove all unsure targets (targets_weights=0)
        # Allocate on the same device as the targets (works on CPU and GPU).
        targets_weights = torch.ones(targets.shape, device=targets.device)
        targets_weights[targets == -1] = 0

    elif args.partial_loss_mode == 'ignore_normalize_classes':
        # remove all unsure targets and normalize by Durand et al.
        # https://arxiv.org/pdf/1902.09720.pdf
        alpha_norm, beta_norm = 1, 1
        targets_weights = torch.ones(targets.shape, device=targets.device)
        n_annotated = 1 + torch.sum(targets != -1, dim=1)  # Add 1 to avoid dividing by zero

        g_norm = alpha_norm * (1 / n_annotated) + beta_norm
        n_classes = targets_weights.shape[1]
        targets_weights *= g_norm.repeat([n_classes, 1]).T
        targets_weights[targets == -1] = 0

    elif args.partial_loss_mode == 'selective':
        # Re-use the cached buffer when its shape still matches the batch.
        if targets_weights is None or targets_weights.shape != targets.shape:
            targets_weights = torch.ones(targets.shape, device=targets.device)
        else:
            targets_weights[:] = 1.0
        # Coerce to int: a float likelihood_topk would break the downstream
        # tensor slice in negative_backprop_fun_jit.
        num_top_k = int(args.likelihood_topk * targets_weights.shape[0])

        xs_neg_prob = xs_neg
        if prior_classes is not None:
            if args.prior_threshold:
                # Zero out classes whose prior exceeds the threshold, but
                # always keep annotated entries (targets != -1).
                idx_ignore = torch.where(prior_classes > args.prior_threshold)[0]
                targets_weights[:, idx_ignore] = 0
                targets_weights += (targets != -1).float()
                targets_weights = targets_weights.bool()

        # Drop the num_top_k un-annotated entries most likely to be positives.
        negative_backprop_fun_jit(targets, xs_neg_prob, targets_weights, num_top_k)

        # set all unsure targets as negative
        # targets[targets == -1] = 0

    return targets_weights, xs_neg
# @torch.jit.script
def negative_backprop_fun_jit(targets: Tensor, xs_neg_prob: Tensor, targets_weights: Tensor, num_top_k: int):
    """Zero (in place) the weights of the ``num_top_k`` un-annotated entries
    with the lowest negative probability, i.e. those most likely to be
    missing positives.  Mutates ``targets_weights`` through a flattened view.
    """
    with torch.no_grad():
        # Flat indices of all un-annotated (-1) entries.
        unann_idx = torch.where(targets.flatten() == -1)[0]
        # Rank those entries by ascending negative probability.
        rank = torch.argsort(xs_neg_prob.flatten()[unann_idx])
        # Writing through flatten() updates the original (contiguous) tensor.
        targets_weights.flatten()[unann_idx[rank[:num_top_k]]] = 0
class ComputePrior:
    """Tracks running average sigmoid predictions per class, separately for
    training and validation, and can dump them to CSV as class priors.
    """

    def __init__(self, classes):
        self.classes = classes
        n_classes = len(self.classes)
        self.sum_pred_train = torch.zeros(n_classes).cuda()
        self.sum_pred_val = torch.zeros(n_classes).cuda()
        self.cnt_samples_train = 0.0
        self.cnt_samples_val = 0.0
        self.avg_pred_train = None
        self.avg_pred_val = None
        self.path_dest = "./outputs"
        self.path_local = "/class_prior/"

    def update(self, logits, training=True):
        """Accumulate a batch of logits into the running train/val averages."""
        with torch.no_grad():
            probs = torch.sigmoid(logits).detach()
            batch_size = probs.shape[0]
            if training:
                self.sum_pred_train += probs.sum(dim=0)
                self.cnt_samples_train += batch_size
                self.avg_pred_train = self.sum_pred_train / self.cnt_samples_train
            else:
                self.sum_pred_val += probs.sum(dim=0)
                self.cnt_samples_val += batch_size
                self.avg_pred_val = self.sum_pred_val / self.cnt_samples_val

    def save_prior(self):
        """Write the running average predictions to CSV files in path_dest."""
        print('Prior (train), first 5 classes: {}'.format(self.avg_pred_train[:5]))

        # Save data frames as csv files
        os.makedirs(self.path_dest, exist_ok=True)

        csv_kwargs = dict(sep=',', header=True, index=False, encoding='utf-8')
        class_names = list(self.classes.values())

        df_train = pd.DataFrame({"Classes": class_names,
                                 "avg_pred": self.avg_pred_train.cpu()})
        df_train.to_csv(path_or_buf=os.path.join(self.path_dest, "train_avg_preds.csv"),
                        **csv_kwargs)

        if self.avg_pred_val is not None:
            df_val = pd.DataFrame({"Classes": class_names,
                                   "avg_pred": self.avg_pred_val.cpu()})
            df_val.to_csv(path_or_buf=os.path.join(self.path_dest, "val_avg_preds.csv"),
                          **csv_kwargs)

    def get_top_freq_classes(self):
        """Print the names of the 10 classes with the highest train prior."""
        n_top = 10
        order = torch.argsort(self.avg_pred_train.cpu(), descending=True)[:n_top]
        top_classes = np.array(list(self.classes.values()))[order]
        print('Prior (train), first {} classes: {}'.format(n_top, top_classes))