-
Notifications
You must be signed in to change notification settings - Fork 447
/
Copy pathmain_train.py
146 lines (117 loc) · 6.53 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# --------------------------------------------------------------
# SNIPER: Efficient Multi-Scale Training
# Licensed under The Apache-2.0 License [see LICENSE for details]
# Training Module
# by Mahyar Najibi and Bharat Singh
# --------------------------------------------------------------
import init
import os
from iterators.MNIteratorE2E import MNIteratorE2E
from symbols.faster import *
from configs.faster.default_configs import config, update_config, update_config_from_list
import mxnet as mx
from train_utils import metric
from train_utils.utils import get_optim_params, get_fixed_param_names, create_logger, load_param
from iterators.PrefetchingIter import PrefetchingIter
from data_utils.load_data import load_proposal_roidb, merge_roidb, filter_roidb
from bbox.bbox_regression import add_bbox_regression_targets
import argparse
def parser(argv=None):
    """Parse the command-line arguments of the SNIPER training module.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``cfg``, ``display``, ``momentum``,
        ``save_prefix`` and ``set_cfg_list`` attributes.
    """
    arg_parser = argparse.ArgumentParser('SNIPER training module')
    arg_parser.add_argument('--cfg', dest='cfg', help='Path to the config file',
                            default='configs/faster/sniper_res101_e2e.yml', type=str)
    # The value is used as the Speedometer logging frequency, which is counted
    # in batches (not epochs).
    arg_parser.add_argument('--display', dest='display', help='Number of batches between displaying loss info',
                            default=100, type=int)
    arg_parser.add_argument('--momentum', dest='momentum', help='BN momentum', default=0.995, type=float)
    arg_parser.add_argument('--save_prefix', dest='save_prefix', help='Prefix used for snapshotting the network',
                            default='SNIPER', type=str)
    # REMAINDER: every token after --set is collected and forwarded as
    # config overrides, so --set must be the last option on the command line.
    arg_parser.add_argument('--set', dest='set_cfg_list', help='Set the configuration fields from command line',
                            default=None, nargs=argparse.REMAINDER)
    return arg_parser.parse_args(argv)
if __name__ == '__main__':
    args = parser()
    update_config(args.cfg)
    if args.set_cfg_list:
        update_config_from_list(args.set_cfg_list)

    # One MXNet GPU context per comma-separated id in config.gpus; the total
    # batch is BATCH_IMAGES images per GPU.
    context = [mx.gpu(int(gpu)) for gpu in config.gpus.split(',')]
    nGPUs = len(context)
    batch_size = nGPUs * config.TRAIN.BATCH_IMAGES

    # makedirs (unlike mkdir) also creates missing parent directories, so a
    # nested output_path such as "output/<experiment>" works on a fresh checkout.
    if not os.path.isdir(config.output_path):
        os.makedirs(config.output_path)

    # The following only makes sure the code reproduces the paper's results
    # after the default scale settings were changed to be resolution-based;
    # the new scale settings should lead to similar results.
    if config.dataset.dataset == 'coco' and config.dataset.NUM_CLASSES == 81:
        # Restore the scales used in the paper for reproducibility.
        config.TRAIN.SCALES = (3.0, 1.667, 512.0)

    # Create the roidb: load one per '+'-separated image set, then merge and filter.
    image_sets = config.dataset.image_set.split('+')
    roidbs = [load_proposal_roidb(config.dataset.dataset, image_set, config.dataset.root_path,
                                  config.dataset.dataset_path,
                                  proposal=config.dataset.proposal, append_gt=True, flip=config.TRAIN.FLIP,
                                  result_path=config.output_path,
                                  proposal_path=config.proposal_path, load_mask=config.TRAIN.WITH_MASK,
                                  only_gt=not config.TRAIN.USE_NEG_CHIPS)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # The returned means/stds are handed to the checkpoint callback further below.
    bbox_means, bbox_stds = add_bbox_regression_targets(roidb, config)

    print('Creating Iterator with {} Images'.format(len(roidb)))
    train_iter = MNIteratorE2E(roidb=roidb, config=config, batch_size=batch_size, nGPUs=nGPUs,
                               threads=config.TRAIN.NUM_THREAD, pad_rois_to=400)
    print('The Iterator has {} samples!'.format(len(train_iter)))

    # Creating the Logger
    logger, output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)

    print('Initializing the model...')
    # NOTE(review): config.symbol comes from the YAML config / --set overrides and is
    # passed to eval(). It presumably names a module star-imported from symbols.faster
    # that defines a class of the same name. Never run this with an untrusted config.
    sym_inst = eval('{}.{}'.format(config.symbol, config.symbol))(n_proposals=400, momentum=args.momentum)
    sym = sym_inst.get_symbol_rpn(config) if config.TRAIN.ONLY_PROPOSAL else sym_inst.get_symbol_rcnn(config)
    # Get the list of fixed (non-trainable) parameters.
    fixed_param_names = get_fixed_param_names(config.network.FIXED_PARAMS, sym)

    # Creating the module
    mod = mx.mod.Module(symbol=sym,
                        context=context,
                        data_names=[k[0] for k in train_iter.provide_data_single],
                        label_names=[k[0] for k in train_iter.provide_label_single],
                        fixed_param_names=fixed_param_names,
                        logger=logger)
    shape_dict = dict(train_iter.provide_data_single + train_iter.provide_label_single)
    sym_inst.infer_shape(shape_dict)
    arg_params, aux_params = load_param(config.network.pretrained, config.network.pretrained_epoch, convert=True)

    # Initialize the heads that match the chosen symbol (RPN-only or full R-CNN).
    if config.TRAIN.ONLY_PROPOSAL:
        sym_inst.init_weight_rpn(config, arg_params, aux_params)
    else:
        sym_inst.init_weight_rcnn(config, arg_params, aux_params)

    # Creating the metrics
    eval_metric = metric.RPNAccMetric()
    cls_metric = metric.RPNLogLossMetric()
    bbox_metric = metric.RPNL1LossMetric()
    rceval_metric = metric.RCNNAccMetric(config)
    rccls_metric = metric.RCNNLogLossMetric(config)
    rcbbox_metric = metric.RCNNL1LossCRCNNMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # RPN metrics are always reported; R-CNN, AutoFocus and mask metrics only
    # when the corresponding branches are trained.
    eval_metrics.add(eval_metric)
    eval_metrics.add(cls_metric)
    eval_metrics.add(bbox_metric)
    if not config.TRAIN.ONLY_PROPOSAL:
        eval_metrics.add(rceval_metric)
        eval_metrics.add(rccls_metric)
        eval_metrics.add(rcbbox_metric)
        if config.TRAIN.AUTO_FOCUS:
            auto_focus_eval_metric = metric.AutoFocusLogLossMetric()
            auto_focus_acc_metric = metric.AutoFocusAccMetric()
            eval_metrics.add(auto_focus_acc_metric)
            eval_metrics.add(auto_focus_eval_metric)
        if config.TRAIN.WITH_MASK:
            mask_metric = metric.MaskLogLossMetric(config)
            eval_metrics.add(mask_metric)

    optimizer_params = get_optim_params(config, len(train_iter), batch_size)
    print('Optimizer params: {}'.format(optimizer_params))

    # Checkpointing
    prefix = os.path.join(output_path, args.save_prefix)
    # Speedometer logs throughput/metrics every `args.display` batches.
    batch_end_callback = mx.callback.Speedometer(batch_size, args.display)
    # Save the module (with optimizer states) after every epoch, and run the
    # symbol module's own checkpoint callback with the bbox regression statistics.
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          eval('{}.checkpoint_callback'.format(config.symbol))(sym_inst.get_bbox_param_names(),
                                                                               prefix, bbox_means, bbox_stds)]

    # Wrap the iterator only after the module has been built from its
    # provide_*_single descriptors above.
    train_iter = PrefetchingIter(train_iter)
    mod.fit(train_iter, optimizer='sgd', optimizer_params=optimizer_params,
            eval_metric=eval_metrics, num_epoch=config.TRAIN.end_epoch, kvstore=config.default.kvstore,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback, arg_params=arg_params, aux_params=aux_params)