-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet18_tusimple.py
116 lines (102 loc) · 2.25 KB
/
resnet18_tusimple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Top-level model wrapper; the registry resolves 'Detector' to the model class.
net = {'type': 'Detector'}
# Backbone: ResNet-18 via the project's ResNetWrapper, loaded with pretrained
# weights and without the wrapper's extra output convolution.
backbone = {
    'type': 'ResNetWrapper',
    'resnet': 'resnet18',
    'pretrained': True,
    # No stage swaps its stride for dilation, so full downsampling is kept.
    'replace_stride_with_dilation': [False] * 3,
    'out_conv': False,
}
# Channels and stride of the feature map handed to the head; 512 / 32 are the
# figures for a plain ResNet-18 final stage with no dilation.
featuremap_out_channel = 512
featuremap_out_stride = 32
# Lane representation and LaneATT head settings.
num_points = 72  # points per lane proposal; presumably the head's offset count — confirm
max_lanes = 5
# Output row coordinates 710, 700, ..., 160 (descending); presumably TuSimple's
# h_samples reversed — confirm against the evaluation code.
sample_y = range(710, 150, -10)

heads = {
    'type': 'LaneATT',
    # Cached anchor statistics; presumably used to keep the topk_anchors most
    # frequently matched anchors — verify in the head implementation.
    'anchors_freq_path': '.cache/tusimple_anchors_freq.pt',
    'topk_anchors': 1000,
}

# Post-processing during training: no confidence filter, loose NMS budget.
train_parameters = {
    'conf_threshold': None,
    'nms_thres': 15.,
    'nms_topk': 3000,
}
# Post-processing at test time: keep at most max_lanes detections.
test_parameters = {
    'conf_threshold': 0.2,
    'nms_thres': 45,
    'nms_topk': max_lanes,
}
# Optimiser and learning-rate schedule.
optimizer = {'type': 'Adam', 'lr': 3e-4}

epochs = 100
batch_size = 8
# Schedule length in iterations: whole batches per epoch times the epoch count.
# NOTE(review): 3616 is a hard-coded training-set size; TuSimple trainval is
# commonly cited as 3,626 images — confirm against the loader before reuse.
total_iter = epochs * (3616 // batch_size)
# T_max is counted in iterations, presumably because the scheduler is stepped
# per iteration (see lr_update_by_epoch) — confirm in the runner.
scheduler = {'type': 'CosineAnnealingLR', 'T_max': total_iter}

eval_ep = 1       # evaluation interval, presumably in epochs
save_ep = epochs  # checkpoint interval; equal to epochs, i.e. only at the end
# Source frame size and network input size (input is half the source size).
ori_img_w = 1280
ori_img_h = 720
img_w = 640
img_h = 360
cut_height = 0  # rows cropped from the top of the frame; presumably none — confirm

# Training pipeline: lane-target generation with augmentation, then tensor
# conversion of the image and lane targets.
train_process = [
    {
        'type': 'GenerateLaneLine',
        # Augmentations, presumably applied in this order: random affine
        # jitter, then a horizontal flip with probability 0.5.
        'transforms': (
            {
                'name': 'Affine',
                'parameters': {
                    'translate_px': {'x': (-25, 25), 'y': (-10, 10)},
                    'rotate': (-6, 6),
                    'scale': (0.85, 1.15),
                },
            },
            {
                'name': 'HorizontalFlip',
                'parameters': {'p': 0.5},
            },
        ),
        'wh': (img_w, img_h),
    },
    {'type': 'ToTensor', 'keys': ['img', 'lane_line']},
]

# Evaluation pipeline: same target generation without augmentation; only the
# image is converted to a tensor.
val_process = [
    {'type': 'GenerateLaneLine', 'wh': (img_w, img_h)},
    {'type': 'ToTensor', 'keys': ['img']},
]
# Root directory of the TuSimple dataset, relative to the working directory.
dataset_path = './data/tusimple'
# Ground-truth label file used for evaluation. Built from dataset_path so the
# two settings cannot drift apart when the dataset is relocated.
test_json_file = dataset_path + '/test_label.json'
# Registry name of the dataset class used by every split below.
dataset_type = 'TuSimple'
# Split definitions. Training uses the 'trainval' split with augmentation;
# validation and test both read the 'test' split with the plain pipeline.
dataset = {
    'train': {
        'type': dataset_type,
        'data_root': dataset_path,
        'split': 'trainval',
        'processes': train_process,
    },
    'val': {
        'type': dataset_type,
        'data_root': dataset_path,
        'split': 'test',
        'processes': val_process,
    },
    'test': {
        'type': dataset_type,
        'data_root': dataset_path,
        'split': 'test',
        'processes': val_process,
    },
}
# Runtime settings.
workers = 12                # presumably data-loader worker processes
log_interval = 100          # presumably iterations between log lines
seed = 0
lr_update_by_epoch = False  # presumably: step the LR scheduler per iteration, not per epoch