- Forward propagation
- Define the hybrid loss function
- Backward propagation
PyTorch-based implementation of https://arxiv.org/pdf/1406.2661.pdf
Semantic segmentation of an image is the task of assigning each pixel to a class. Here this is done with an adversarial approach rather than the conventional CNN or FCN approaches, which are trained with a single loss term.
The network model is as follows:
The two networks are:
- Segmentor
- Discriminator
The segmentor is a conventional ConvNet. It takes as input an image of dimensions H×W×C (H: height, W: width, C: number of channels of the input image) and outputs class predictions for each pixel of the image. The output therefore has shape H×W×S, where S is the number of classes we want to segment the image into. Along the third dimension (S), each pixel holds its probability of belonging to each of the S classes.
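To make the shapes concrete, here is a hypothetical toy segmentor (the layer sizes and the name `TinySegmentor` are assumptions, not this repository's network). PyTorch stores images channels-first, so the H×W×S output described above appears as a tensor of shape (N, S, H, W), with a softmax over the class dimension.

```python
import torch
import torch.nn as nn


class TinySegmentor(nn.Module):
    """Hypothetical toy segmentor: (N, C, H, W) images -> (N, S, H, W)
    per-pixel class probabilities."""

    def __init__(self, in_channels: int = 3, num_classes: int = 5):
        super().__init__()
        self.features = nn.Sequential(
            # a 3x3 kernel with padding=1 keeps H and W unchanged
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # a 1x1 convolution produces one score map per class
        self.classifier = nn.Conv2d(16, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.classifier(self.features(x))
        # softmax over the class dimension: each pixel gets S probabilities summing to 1
        return torch.softmax(logits, dim=1)


seg = TinySegmentor(in_channels=3, num_classes=5)
image = torch.randn(2, 3, 32, 32)  # batch of 2 RGB images, H = W = 32
probs = seg(image)
print(probs.shape)  # torch.Size([2, 5, 32, 32])
```

Keeping the spatial size fixed with padded convolutions is just the simplest choice for a sketch; real segmentors usually downsample and then upsample back to H×W.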
Segmentor's loss term: multiclass cross-entropy (MCE).
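In its standard form, assuming ŷ = s(x) holds the predicted class probabilities and y is the one-hot ground-truth label map, the multiclass cross-entropy sums over all H×W pixels i and all S classes c:

```math
\ell_{mce}(\hat{y}, y) = -\sum_{i=1}^{H \times W} \sum_{c=1}^{S} y_{ic} \ln \hat{y}_{ic}
```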
The discriminator (adversary) takes as input a label map and the corresponding RGB image. The label map is either the ground truth for the image or the one predicted by the segmentor network. Two separate branches first process the label map and the RGB image, each converting its input to 64 channels, and their outputs are then concatenated. This concatenated signal is passed through another stack of convolutional and max-pooling layers, after which a sigmoid activation produces a binary class probability. This tells us whether the label map is genuine (ground truth) or counterfeit (produced by the segmentor).
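The two-branch design can be sketched as below. This is a hypothetical toy adversary: the kernel sizes, the number of trunk layers and the final spatial averaging are assumptions, not taken from this repository. Only the structure follows the description: an image branch and a label-map branch, each producing 64 channels, concatenation, a conv/max-pool stack, then a sigmoid.

```python
import torch
import torch.nn as nn


class TinyDiscriminator(nn.Module):
    """Hypothetical toy adversary: probability near 1 means "ground truth",
    near 0 means "produced by the segmentor"."""

    def __init__(self, num_classes: int = 5, image_channels: int = 3):
        super().__init__()
        # branch 1: label map (S channels) -> 64 channels
        self.label_branch = nn.Sequential(
            nn.Conv2d(num_classes, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
        )
        # branch 2: RGB image -> 64 channels
        self.image_branch = nn.Sequential(
            nn.Conv2d(image_channels, 16, kernel_size=5, padding=2), nn.ReLU(inplace=True),
            nn.Conv2d(16, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
        )
        # shared trunk on the 128-channel concatenation: conv + max-pool stack
        self.trunk = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 1, kernel_size=3, padding=1),
        )

    def forward(self, image: torch.Tensor, label_map: torch.Tensor) -> torch.Tensor:
        feats = torch.cat([self.image_branch(image), self.label_branch(label_map)], dim=1)
        # average the remaining spatial score map into one logit per example
        logit = self.trunk(feats).mean(dim=(1, 2, 3))
        # sigmoid turns the logit into a "ground truth" probability
        return torch.sigmoid(logit)


torch.manual_seed(0)
disc = TinyDiscriminator(num_classes=5)
image = torch.randn(2, 3, 32, 32)
label_map = torch.softmax(torch.randn(2, 5, 32, 32), dim=1)  # stands in for s(x)
p_real = disc(image, label_map)
print(p_real.shape)  # torch.Size([2])
```

A ground-truth map of integer labels `y` with shape (N, H, W) would be fed in the same S-channel layout, for example as `F.one_hot(y, S).permute(0, 3, 1, 2).float()`.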
Adversary's loss term: binary cross-entropy (BCE).
Overall loss term: segmentor loss − λ × adversarial loss, with λ weighting the adversarial term.
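In its standard form, with ẑ the adversary's predicted probability that a label map is ground truth and z ∈ {0, 1} the target (1 for ground truth, 0 for a segmentor output):

```math
\ell_{bce}(\hat{z}, z) = -\left[ z \ln \hat{z} + (1 - z) \ln (1 - \hat{z}) \right]
```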
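Written out, with ℓ_mce for the segmentor's multiclass cross-entropy, ℓ_bce for the adversary's binary cross-entropy, and a(x, y) for the adversary's predicted probability that y is the ground-truth label map of image x (this notation is assumed here, following the usual form of this hybrid objective), the loss over the N training examples is:

```math
\ell(\theta_s, \theta_a) = \sum_{n=1}^{N} \Big( \ell_{mce}\big(s(x_n), y_n\big) - \lambda \Big[ \ell_{bce}\big(a(x_n, y_n), 1\big) + \ell_{bce}\big(a(x_n, s(x_n)), 0\big) \Big] \Big)
```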
- N: number of training examples
- s(x_n): output of the segmentor for input image x_n
- y_n: ground-truth label map for x_n
- θs: parameters of the segmentation model
- θa: parameters of the adversarial model
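The overall loss can be sketched in PyTorch as below. This is a minimal illustration under stated assumptions: `hybrid_loss` and its default `lam=0.1` are hypothetical, the ground truth y_n is given as integer class indices (equivalent to a one-hot map for cross-entropy), and the adversary outputs a(x_n, y_n) and a(x_n, s(x_n)) are passed in precomputed.

```python
import torch
import torch.nn.functional as F


def hybrid_loss(seg_probs, target, d_real, d_fake, lam=0.1):
    """Hypothetical sketch: MCE of the segmentor minus lam times the
    adversary's BCE, both summed over pixels and examples.

    seg_probs: (N, S, H, W) segmentor probabilities s(x_n)
    target:    (N, H, W) integer ground-truth labels y_n
    d_real:    (N,) adversary probability for (x_n, y_n)
    d_fake:    (N,) adversary probability for (x_n, s(x_n))
    """
    # nll_loss on log-probabilities is the multiclass cross-entropy;
    # clamp avoids log(0) for fully confident predictions
    mce = F.nll_loss(torch.log(seg_probs.clamp_min(1e-8)), target, reduction="sum")
    # adversary's BCE: ground-truth pairs labelled 1, segmentor pairs labelled 0
    bce = (F.binary_cross_entropy(d_real, torch.ones_like(d_real), reduction="sum")
           + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake), reduction="sum"))
    return mce - lam * bce


# Toy check: N = 1, a single pixel, S = 4 classes, uniform prediction,
# and an undecided adversary (0.5 for both pairs)
seg_probs = torch.full((1, 4, 1, 1), 0.25)
target = torch.zeros(1, 1, 1, dtype=torch.long)
half = torch.tensor([0.5])
loss = hybrid_loss(seg_probs, target, half, half, lam=0.1)
print(round(loss.item(), 4))  # 1.2477
```

In the toy check, MCE = ln 4 and BCE = 2 ln 2 = ln 4, so the loss is ln 4 − 0.1 · ln 4 ≈ 1.2477.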
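In GANs, the two networks are trained alternately at each step. Here, training the adversarial network amounts to minimizing, with respect to θa, the BCE loss of classifying ground-truth label maps as 1 and segmentor outputs as 0.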
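In symbols, writing a(x, y) for the adversary's predicted probability that y is the ground-truth label map of image x and ℓ_bce for binary cross-entropy (notation assumed here), the adversary solves:

```math
\min_{\theta_a} \; \sum_{n=1}^{N} \Big[ \ell_{bce}\big(a(x_n, y_n), 1\big) + \ell_{bce}\big(a(x_n, s(x_n)), 0\big) \Big]
```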
Training the segmentor is equivalent to minimizing, with respect to θs, the multiclass cross-entropy loss while at the same time degrading the performance of the adversarial network. This encourages the segmentation model to produce segmentation maps that the adversarial model finds hard to distinguish from ground-truth ones. This is the classic minimax game of adversarial networks.
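With the same assumed notation (ℓ_mce for multiclass cross-entropy, ℓ_bce for binary cross-entropy, a(x, y) for the adversary's output), the term ℓ_bce(a(x_n, y_n), 1) does not depend on θs and drops out, leaving the segmentor objective:

```math
\min_{\theta_s} \; \sum_{n=1}^{N} \Big[ \ell_{mce}\big(s(x_n), y_n\big) - \lambda \, \ell_{bce}\big(a(x_n, s(x_n)), 0\big) \Big]
```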
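In the segmentor's training function, we replace the second term (which pushes up the adversary's loss for labelling segmentor outputs as 0) by a term that minimizes the adversary's BCE with target 1 on segmentor outputs; that is, the segmentor is trained to make the adversary classify its label maps as ground truth.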
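With the same assumed notation, this swaps −λ ℓ_bce(a(x_n, s(x_n)), 0) for +λ ℓ_bce(a(x_n, s(x_n)), 1), so the segmentor minimizes:

```math
\min_{\theta_s} \; \sum_{n=1}^{N} \Big[ \ell_{mce}\big(s(x_n), y_n\big) + \lambda \, \ell_{bce}\big(a(x_n, s(x_n)), 1\big) \Big]
```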
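The reason behind this is described in https://arxiv.org/pdf/1406.2661.pdf: early in training, when the adversary can reject the segmentor's outputs with high confidence, the original term saturates and gives weak gradients, while the replacement has the same fixed point but much stronger gradients.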
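One alternating training step might look like the sketch below. It is written under stated assumptions: the tiny `segmentor` and `Adversary` modules, the Adam optimizers, the learning rates and `lam = 0.1` are hypothetical placeholders rather than this repository's code, and the losses are averaged (PyTorch's default) rather than summed. The adversary is updated first on a detached segmentor output; the segmentor is then updated with the non-saturating term (BCE with target 1).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
S = 3  # number of classes (toy value)

# Hypothetical stand-in networks, kept tiny so the sketch runs quickly
segmentor = nn.Sequential(nn.Conv2d(3, S, kernel_size=3, padding=1), nn.Softmax(dim=1))


class Adversary(nn.Module):
    def __init__(self):
        super().__init__()
        # two branches map the image and the label map to 64 channels each
        self.img = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.lbl = nn.Conv2d(S, 64, kernel_size=3, padding=1)
        self.head = nn.Conv2d(128, 1, kernel_size=3, padding=1)

    def forward(self, x, y):
        h = torch.cat([F.relu(self.img(x)), F.relu(self.lbl(y))], dim=1)
        # one probability per example that y is the ground truth for x
        return torch.sigmoid(self.head(h).mean(dim=(1, 2, 3)))


adversary = Adversary()
opt_s = torch.optim.Adam(segmentor.parameters(), lr=1e-3)
opt_a = torch.optim.Adam(adversary.parameters(), lr=1e-3)
lam = 0.1  # hypothetical weight for the adversarial term


def train_step(x, y):
    """One alternating update: adversary first, then segmentor."""
    y_onehot = F.one_hot(y, S).permute(0, 3, 1, 2).float()
    ones = torch.ones(x.size(0))
    zeros = torch.zeros(x.size(0))

    # 1) adversary: minimize BCE (ground truth -> 1, segmentor output -> 0)
    fake = segmentor(x).detach()  # no gradients flow into the segmentor here
    loss_a = (F.binary_cross_entropy(adversary(x, y_onehot), ones)
              + F.binary_cross_entropy(adversary(x, fake), zeros))
    opt_a.zero_grad()
    loss_a.backward()
    opt_a.step()

    # 2) segmentor: MCE plus lam * BCE with target 1 (the replaced second term)
    probs = segmentor(x)
    mce = F.nll_loss(torch.log(probs.clamp_min(1e-8)), y)
    adv = F.binary_cross_entropy(adversary(x, probs), ones)
    loss_s = mce + lam * adv
    opt_s.zero_grad()
    loss_s.backward()  # also fills adversary grads; they are cleared before its next update
    opt_s.step()
    return loss_a.item(), loss_s.item()


x = torch.randn(4, 3, 16, 16)
y = torch.randint(0, S, (4, 16, 16))
for step in range(3):
    la, ls = train_step(x, y)
    print(f"step {step}: adversary {la:.3f}, segmentor {ls:.3f}")
```

Detaching the segmentor output in the adversary step keeps that update from touching θs; in the segmentor step the adversary's gradients are computed but never applied, since only `opt_s` steps.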