-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathsoft_n_cut_loss.py
62 lines (49 loc) · 2.64 KB
/
soft_n_cut_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
r"""
loss.py
--------
Implementations of loss functions used for training W-Net CNN models.
"""
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import grey_opening
from utils.filter import gaussian_kernel
class NCutLoss2D(nn.Module):
    r"""Implementation of the continuous N-Cut loss, as in:
    'W-Net: A Deep Model for Fully Unsupervised Image Segmentation', by Xia, Kulis (2017)"""
    def __init__(self, radius: int = 5, sigma_1: float = 4, sigma_2: float = 10):
        r"""
        :param radius: Radius of the spatial interaction term
        :param sigma_1: Standard deviation of the spatial Gaussian interaction
        :param sigma_2: Standard deviation of the pixel value Gaussian interaction
        """
        super(NCutLoss2D, self).__init__()
        self.radius = radius
        self.sigma_1 = sigma_1  # Spatial standard deviation
        self.sigma_2 = sigma_2  # Pixel value standard deviation
    def forward(self, labels: Tensor, inputs: Tensor) -> Tensor:
        r"""Computes the continuous N-Cut loss, given a set of class probabilities (labels) and raw images (inputs).
        Small modifications have been made here for efficiency -- specifically, we compute the pixel-wise weights
        relative to the class-wide average, rather than for every individual pixel.
        :param labels: Predicted class probabilities, shape (B, K, H, W) -- K inferred from dim 1
        :param inputs: Raw images, shape (B, C, H, W); spatial dims assumed to match labels -- TODO confirm
        :return: Continuous N-Cut loss (scalar tensor)
        """
        num_classes = labels.shape[1]
        # Spatial Gaussian filter, shared by every class iteration.
        kernel = gaussian_kernel(radius=self.radius, sigma=self.sigma_1, device=labels.device.type)
        loss = 0
        for k in range(num_classes):
            # Soft assignment of each pixel to class k, kept as a (B, 1, H, W) map.
            class_probs = labels[:, k].unsqueeze(1)
            # Probability-weighted mean pixel value for this class; the 1e-5 epsilon
            # guards against division by zero when a class has ~zero total probability.
            class_mean = torch.mean(inputs * class_probs, dim=(2, 3), keepdim=True) / \
                torch.add(torch.mean(class_probs, dim=(2, 3), keepdim=True), 1e-5)
            # Squared distance of each pixel from the class mean, summed over channels.
            diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1)
            # BUG FIX: the paper's pixel-value affinity is exp(-||F(i) - F(j)||^2 / sigma_I^2).
            # `diff` is already a squared distance, so the previous diff.pow(2) raised the
            # distance to the fourth power; use `diff` directly instead.
            weights = torch.exp(diff.mul(-1 / self.sigma_2 ** 2))
            # assoc(A_k, A_k) and assoc(A_k, V), realised with a Gaussian spatial convolution.
            numerator = torch.sum(class_probs * F.conv2d(class_probs * weights, kernel, padding=self.radius))
            denominator = torch.sum(class_probs * F.conv2d(weights, kernel, padding=self.radius))
            # The previous nn.L1Loss()(x, zeros) on a scalar is just |x|; computed directly
            # here (also avoids constructing a new loss module every iteration).
            loss = loss + torch.abs(numerator / torch.add(denominator, 1e-6))
        # N-cut(K) = K - sum_k assoc(A_k, A_k) / assoc(A_k, V)
        return num_classes - loss