forked from DingXiaoH/DiverseBranchBlock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathacb.py
82 lines (73 loc) · 4.36 KB
/
acb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch.nn as nn
class CropLayer(nn.Module):
    """Trim the same number of rows/columns from both borders of a feature map.

    ``crop_set`` holds two non-positive integers ``(rows, cols)``. For example
    ``(-1, 0)`` drops the first and last rows, and ``(0, -1)`` drops the first
    and last columns. ``forward`` indexes its input with four subscripts, the
    last two being the row and column axes.
    """

    def __init__(self, crop_set):
        super(CropLayer, self).__init__()
        # Store the crop widths as non-negative counts.
        self.rows_to_crop = -crop_set[0]
        self.cols_to_crop = -crop_set[1]
        assert self.rows_to_crop >= 0
        assert self.cols_to_crop >= 0

    def forward(self, input):
        """Return ``input`` with the configured border rows/columns removed."""
        n_rows = self.rows_to_crop
        n_cols = self.cols_to_crop
        if n_rows == 0 and n_cols == 0:
            # Nothing to trim: hand back the very same tensor object.
            return input

        everything = slice(None)
        row_span = slice(n_rows, -n_rows)
        col_span = slice(n_cols, -n_cols)
        # A ``0:-0`` slice would be empty, so an axis that is not cropped must
        # be replaced by a full slice.
        if n_rows > 0 and n_cols == 0:
            col_span = everything
        elif n_rows == 0 and n_cols > 0:
            row_span = everything
        return input[:, :, row_span, col_span]
class ACBlock(nn.Module):
    """Asymmetric convolution block.

    In training form (``deploy=False``) the input runs through three parallel
    branches -- a square ``k x k`` convolution, a vertical ``k x 1`` convolution
    and a horizontal ``1 x k`` convolution, each bias-free and followed by its
    own ``BatchNorm2d`` -- and the three results are summed. In deploy form
    (``deploy=True``) the block holds one biased ``k x k`` convolution
    (``fused_conv``) instead. In both forms ``nonlinear`` is applied last.

    Args:
        in_channels, out_channels, stride, groups, padding_mode: forwarded
            unchanged to every ``nn.Conv2d`` in the block.
        kernel_size: integer side length ``k`` of the square kernel.
        padding: padding of the square convolution. Must be an integer when
            ``deploy=False`` because the branch paddings are derived from it.
        dilation: integer or ``(height, width)`` pair, forwarded to every conv.
        deploy: build the single-convolution form instead of the three branches.
        use_affine: ``affine`` flag of the three ``BatchNorm2d`` layers.
        nonlinear: module applied to the block output; ``None`` means identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 padding_mode='zeros', deploy=False, use_affine=True, nonlinear=None):
        super(ACBlock, self).__init__()
        self.nonlinear = nn.Identity() if nonlinear is None else nonlinear
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=(kernel_size, kernel_size), stride=stride,
                                        padding=padding, dilation=dilation, groups=groups, bias=True,
                                        padding_mode=padding_mode)
        else:
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            if isinstance(dilation, int):
                dilation_h = dilation_w = dilation
            else:
                dilation_h, dilation_w = dilation

            # The square kernel reaches dilation * (k // 2) pixels away from its
            # centre. A 1-wide kernel reaches 0 pixels along its thin axis, so on
            # that axis the branch needs this much less padding than the square
            # conv; a negative value means the input must be cropped instead.
            half_kernel = kernel_size // 2
            ver_col_offset = padding - dilation_w * half_kernel
            hor_row_offset = padding - dilation_h * half_kernel

            # Along its long axis each branch always uses the same padding as
            # the square conv; only the thin axis is padded less or cropped.
            if ver_col_offset >= 0:
                self.ver_conv_crop_layer = nn.Identity()
                ver_conv_padding = (padding, ver_col_offset)
            else:
                self.ver_conv_crop_layer = CropLayer(crop_set=(0, ver_col_offset))
                ver_conv_padding = (padding, 0)
            if hor_row_offset >= 0:
                self.hor_conv_crop_layer = nn.Identity()
                hor_conv_padding = (hor_row_offset, padding)
            else:
                self.hor_conv_crop_layer = CropLayer(crop_set=(hor_row_offset, 0))
                hor_conv_padding = (0, padding)

            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(kernel_size, 1), stride=stride,
                                      padding=ver_conv_padding, dilation=dilation, groups=groups, bias=False,
                                      padding_mode=padding_mode)
            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(1, kernel_size), stride=stride,
                                      padding=hor_conv_padding, dilation=dilation, groups=groups, bias=False,
                                      padding_mode=padding_mode)
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

    def forward(self, input):
        """Run the block on ``input`` and return the activated result."""
        if self.deploy:
            # Apply the activation here too, so the deploy form computes the
            # same function shape as the training form.
            return self.nonlinear(self.fused_conv(input))

        square_outputs = self.square_bn(self.square_conv(input))
        vertical_outputs = self.ver_bn(self.ver_conv(self.ver_conv_crop_layer(input)))
        horizontal_outputs = self.hor_bn(self.hor_conv(self.hor_conv_crop_layer(input)))
        return self.nonlinear(square_outputs + vertical_outputs + horizontal_outputs)