from __future__ import division, print_function, absolute_import
import os
import sys
import json
import copy
import argparse
import time
import functools
import inspect

import numpy as np
import tensorflow as tf

from model import preprocessing as prep
from model import instance_model
from model.dataset_utils import dataset_func
from utils import DATA_LEN_IMAGENET_FULL, tuple_get_one
from param_setter import get_lr_from_boundary
import config

def get_config():
    cfg = config.Config()

    cfg.add('batch_size', type=int, default=128,
            help='Training batch size')
    cfg.add('test_batch_size', type=int, default=64,
            help='Testing batch size')
    cfg.add('init_lr', type=float, default=0.01,
            help='Initial learning rate')
    cfg.add('gpu', type=str, required=True,
            help='Value for CUDA_VISIBLE_DEVICES')
    cfg.add('weight_decay', type=float, default=1e-4,
            help='Weight decay')
    cfg.add('image_dir', type=str, required=True,
            help='Directory containing the dataset')
    cfg.add('q_cap', type=int, default=102400,
            help='Shuffle queue capacity of tfr data')
    cfg.add('ten_crop', type=bool,
            help='Whether to do ten-crop validation')

    # Loading parameters
    cfg.add('load_exp', type=str, required=True,
            help='The experiment to load from, in the format '
                 '[dbname]/[collname]/[exp_id]')
    cfg.add('load_step', type=int, default=None,
            help='Step number for loading')
    cfg.add('load_port', type=int,
            help='Port number of mongodb for loading (defaults to the '
                 'saving port)')
    cfg.add('resume', type=bool,
            help='Flag for loading from the last step of this exp_id; '
                 'overrides all other loading options')

    # Saving parameters
    cfg.add('port', type=int, required=True,
            help='Port number for mongodb')
    cfg.add('host', type=str, default='localhost',
            help='Host for mongodb')
    cfg.add('save_exp', type=str, required=True,
            help='The [dbname]/[collname]/[exp_id] of this experiment')
    cfg.add('cache_dir', type=str, required=True,
            help='Prefix of the saving directory')
    cfg.add('fre_valid', type=int, default=10009,
            help='Frequency of validation (in steps)')
    cfg.add('fre_metric', type=int, default=1000,
            help='Frequency of saving metrics (in steps)')
    cfg.add('fre_filter', type=int, default=10009,
            help='Frequency of saving filters (in steps)')
    cfg.add('fre_cache_filter', type=int,
            help='Frequency of caching filters (in steps)')

    # Training parameters
    cfg.add('model_type', type=str, default='resnet18',
            help='Model type, resnet or alexnet')
    cfg.add('get_all_layers', type=str, default=None,
            help='Whether to get outputs for all layers')
    cfg.add('lr_boundaries', type=str, default=None,
            help='Learning rate boundaries for 10x drops')
    cfg.add('train_crop', type=str, default='default',
            help='Train crop style')
    cfg.add('num_classes', type=int, default=1000,
            help='Number of classes')
    return cfg

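# Example invocation (hypothetical paths, names, and port, shown for
# illustration only; config.Config is this repo's argparse-style wrapper,
# so flags follow the usual --name value convention):
#   python trans_param_setter.py \
#       --gpu 0,1 \
#       --image_dir /path/to/imagenet_tfrecords \
#       --load_exp mydb/mycoll/pretrain_id \
#       --save_exp mydb/mycoll/transfer_id \
#       --cache_dir /path/to/cache \
#       --port 27017
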
def reg_loss(loss, weight_decay):
    """Add weight decay to the loss."""
    def exclude_batch_norm_and_other_device(name):
        # Only variables under the 'instance' scope contribute to the L2
        # penalty; batch-normalization parameters are excluded.
        return 'batch_normalization' not in name \
                and 'instance' in name
    l2_loss = weight_decay * tf.add_n(
            [tf.nn.l2_loss(tf.cast(v, tf.float32))
             for v in tf.trainable_variables()
             if exclude_batch_norm_and_other_device(v.name)])
    loss_all = tf.add(loss, l2_loss)
    return loss_all

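# A minimal sketch of the quantity reg_loss builds (tf.nn.l2_loss(v)
# computes sum(v ** 2) / 2):
#   loss_all = loss + weight_decay * sum(||v||^2 / 2
#                                        for v in trainable 'instance' vars,
#                                        excluding batch norm)
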
def add_training_params(params, args):
    NUM_BATCHES_PER_EPOCH = DATA_LEN_IMAGENET_FULL / args.batch_size

    # model_params: a function that will build the model
    model_params = {
        'func': instance_model.build_transfer_targets,
        'trainable_scopes': ['instance'],
        'get_all_layers': args.get_all_layers,
        'model_type': args.model_type,
        'num_classes': args.num_classes,
    }
    multi_gpu = len(args.gpu.split(','))
    if multi_gpu > 1:
        model_params['num_gpus'] = multi_gpu
        model_params['devices'] = ['/gpu:%i' % idx for idx in range(multi_gpu)]
    params['model_params'] = model_params

    # train_params: parameters about training data
    process_img_func = prep.resnet_train
    if args.train_crop == 'resnet_crop_flip':
        process_img_func = prep.resnet_crop_flip
    elif args.train_crop == 'alexnet_crop_flip':
        process_img_func = prep.alexnet_crop_flip
    elif args.train_crop == 'validate_crop':
        process_img_func = prep.resnet_validate

    train_data_param = {
        'func': dataset_func,
        'image_dir': args.image_dir,
        'process_img_func': process_img_func,
        'is_train': True,
        'q_cap': args.q_cap,
        'batch_size': args.batch_size}

    def _train_target_func(
            inputs,
            output,
            get_all_layers=None,
            *args,
            **kwargs):
        if not get_all_layers:
            return {'accuracy': output[1]}
        else:
            # output[1] is a dict of per-layer accuracies; dict_values must
            # be wrapped in a list for tf.reduce_mean under Python 3.
            return {'accuracy': tf.reduce_mean(list(output[1].values()))}

    params['train_params'] = {
        'validate_first': False,
        'data_params': train_data_param,
        'queue_params': None,
        'thres_loss': float('Inf'),
        'num_steps': int(2000 * NUM_BATCHES_PER_EPOCH),
        'targets': {
            'func': _train_target_func,
            'get_all_layers': args.get_all_layers,
        },
    }

    # loss_params: parameters to build the loss
    def loss_func(output, *args, **kwargs):
        return output[0]
    params['loss_params'] = {
        'pred_targets': [],
        # compute the l2 penalty once during aggregation rather than
        # separately on each GPU
        'agg_func': reg_loss,
        'agg_func_kwargs': {'weight_decay': args.weight_decay},
        'loss_func': loss_func,
    }

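# How tfutils composes these pieces (sketch, based on tfutils conventions;
# the exact call sequence lives inside tfutils.base and may differ):
#   per_tower_loss = loss_func(model_output)          # -> output[0]
#   final_loss     = agg_func(combined_tower_losses,  # reg_loss adds the L2
#                             **agg_func_kwargs)      #    penalty exactly once
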
def add_validation_params(params, args):
    # validation_params: control the validation
    val_len = 50000
    valid_prep_func = prep.resnet_validate
    if args.ten_crop:
        valid_prep_func = prep.resnet_10crop_validate

    topn_val_data_param = {
        'func': dataset_func,
        'image_dir': args.image_dir,
        'process_img_func': valid_prep_func,
        'is_train': False,
        'q_cap': args.test_batch_size,
        'batch_size': args.test_batch_size}

    def online_agg(agg_res, res, step):
        if agg_res is None:
            agg_res = {k: [] for k in res}
        for k, v in res.items():
            agg_res[k].append(np.mean(v))
        return agg_res

    def valid_perf_func(inputs, output):
        if not args.get_all_layers:
            return {'top1': output}
        else:
            ret_dict = {}
            for key, each_out in output.items():
                ret_dict['top1_{name}'.format(name=key)] = each_out
            return ret_dict

    topn_val_param = {
        'data_params': topn_val_data_param,
        'queue_params': None,
        'targets': {'func': valid_perf_func},
        # int() drops the final partial batch when val_len is not
        # divisible by test_batch_size
        'num_steps': int(val_len / args.test_batch_size),
        'agg_func': lambda x: {k: np.mean(v) for k, v in x.items()},
        'online_agg_func': online_agg,
    }
    params['validation_params'] = {
        'topn': topn_val_param,
    }

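# Aggregation flow during validation (sketch): online_agg is called once per
# step to accumulate a per-batch mean, then agg_func reduces each list to a
# single number, e.g.
#   online_agg(None, {'top1': [1., 0.]}, step=0)  -> {'top1': [0.5]}
#   agg_func({'top1': [0.5, 0.7]})                -> {'top1': 0.6}
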
def add_save_and_load_params(params, args):
    # save_params: defining where to save the models
    db_name, col_name, exp_id = args.save_exp.split('/')
    cache_dir = os.path.join(
            args.cache_dir, 'models',
            db_name, col_name, exp_id)
    params['save_params'] = {
        'host': args.host,  # used for tfutils; defaults to 'localhost'
        'port': args.port,  # used for tfutils
        'dbname': db_name,
        'collname': col_name,
        'exp_id': exp_id,
        'do_save': True,
        'save_initial_filters': True,
        'save_metrics_freq': args.fre_metric,
        'save_valid_freq': args.fre_valid,
        'save_filters_freq': args.fre_filter,
        'cache_filters_freq': args.fre_cache_filter or args.fre_filter,
        'cache_dir': cache_dir,
    }

    # load_params: defining where to load, if needed
    if args.resume or args.load_exp is None:
        load_exp = args.save_exp
    else:
        load_exp = args.load_exp
    load_dbname, load_collname, load_exp_id = load_exp.split('/')

    if args.resume or args.load_step is None:
        load_query = None
    else:
        load_query = {
            'exp_id': load_exp_id,
            'saved_filters': True,
            'step': args.load_step,
        }
    params['load_params'] = {
        'host': args.host,  # used for tfutils; defaults to 'localhost'
        'port': args.load_port or args.port,  # used for tfutils
        'dbname': load_dbname,
        'collname': load_collname,
        'exp_id': load_exp_id,
        'do_restore': True,
        'query': load_query,
    }

def add_optimization_params(params, args):
    # learning_rate_params: build the learning rate schedule; the rate stays
    # at init_lr unless lr_boundaries specifies 10x drops
    NUM_BATCHES_PER_EPOCH = DATA_LEN_IMAGENET_FULL / args.batch_size
    params['learning_rate_params'] = {
        'func': get_lr_from_boundary,
        'init_lr': args.init_lr,
        'NUM_BATCHES_PER_EPOCH': NUM_BATCHES_PER_EPOCH,
        'boundaries': args.lr_boundaries,
    }

    # optimizer_params
    params['optimizer_params'] = {
        'optimizer': tf.train.MomentumOptimizer,
        'momentum': .9,
    }

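# get_lr_from_boundary is defined in param_setter; a boundary schedule with
# 10x drops is typically equivalent to the following (illustrative sketch
# only, with hypothetical boundary steps b0 and b1):
#   lr = tf.train.piecewise_constant(
#       global_step, boundaries=[b0, b1],
#       values=[init_lr, init_lr / 10, init_lr / 100])
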
def get_params_from_args(args):
    params = {
        'skip_check': True,
        'log_device_placement': False,
    }
    add_training_params(params, args)
    add_save_and_load_params(params, args)
    add_optimization_params(params, args)
    add_validation_params(params, args)
    return params

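# Minimal driver sketch (an assumption, not part of the original file):
# tfutils' base.train_from_params is the usual consumer of a params dict
# shaped like the one built above, and config.Config is assumed to expose
# an argparse-style parse_args().
#
# if __name__ == '__main__':
#     from tfutils import base
#     args = get_config().parse_args()
#     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#     params = get_params_from_args(args)
#     base.train_from_params(**params)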