-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresidual_block.py
25 lines (22 loc) · 1.14 KB
/
residual_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Residual block: ``output = x + F(x)`` where F is two 3x3 conv stages.

    Each stage is reflection-pad -> 3x3 conv -> instance norm; a ReLU sits
    between the two stages, and the second stage has no activation so the
    skip connection adds an unclipped residual.
    """

    def __init__(self, channels):
        """
        Args:
            channels: number of input (and output) feature channels; the
                block preserves both channel count and spatial size.
        """
        super().__init__()

        def _stage():
            # Reflection padding (instead of zero padding) gives smooth
            # transitions at image borders and avoids edge artifacts when
            # the 3x3 filter overlaps the boundary.
            #
            # InstanceNorm2d normalizes each sample independently, so the
            # statistics do not depend on batch composition or batch size —
            # unlike BatchNorm, whose mean/variance estimates get noisy for
            # small batches and may not generalize at inference time.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3, padding=0),
                nn.InstanceNorm2d(channels, affine=False),
            ]

        first, second = _stage(), _stage()
        self.layers = nn.Sequential(*first, nn.ReLU(inplace=True), *second)

    def forward(self, x):
        """Apply the residual branch and add the identity skip connection."""
        return x + self.layers(x)