tf_base1.py
from collections import deque
from datetime import datetime
from importlib import reload
import logging
import os
import pprint as pp
import numpy as np
import tensorflow as tf
from tf_utils import shape
class TFBaseModel(object):

    def __init__(
        self,
        reader,
        batch_size=128,
        num_training_steps=20000,
        learning_rate=.01,
        optimizer='adam',
        grad_clip=5,
        regularization_constant=0.0,
        keep_prob=1.0,
        early_stopping_steps=3000,
        warm_start_init_step=0,
        num_restarts=None,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=100,
        log_interval=20,
        loss_averaging_window=100,
        num_validation_batches=1,
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions'
    ):
        self.reader = reader
        self.batch_size = batch_size
        self.num_training_steps = num_training_steps
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.grad_clip = grad_clip
        self.regularization_constant = regularization_constant
        self.warm_start_init_step = warm_start_init_step
        self.early_stopping_steps = early_stopping_steps
        self.keep_prob_scalar = keep_prob
        self.enable_parameter_averaging = enable_parameter_averaging
        self.num_restarts = num_restarts
        self.min_steps_to_checkpoint = min_steps_to_checkpoint
        self.log_interval = log_interval
        self.num_validation_batches = num_validation_batches
        self.loss_averaging_window = loss_averaging_window

        self.log_dir = log_dir
        self.prediction_dir = prediction_dir
        self.checkpoint_dir = checkpoint_dir
        if self.enable_parameter_averaging:
            self.checkpoint_dir_averaged = checkpoint_dir + '_avg'

        self.init_logging(self.log_dir)
        logging.info('\nnew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))

        self.graph = self.build_graph()
        self.session = tf.Session(graph=self.graph)
        print('built graph')
    def predict(self, chunk_size=2048):
        if not os.path.isdir(self.prediction_dir):
            os.makedirs(self.prediction_dir)

        if hasattr(self, 'prediction_tensors'):
            prediction_dict = {tensor_name: [] for tensor_name in self.prediction_tensors}

            test_generator = self.reader.test_batch_generator(chunk_size)
            for i, test_batch_df in enumerate(test_generator):
                if i % 100 == 0:
                    print(i*chunk_size)

                # each batch maps placeholder names to arrays; only feed names
                # that exist as placeholders on the model
                test_feed_dict = {
                    getattr(self, placeholder_name, None): data
                    for placeholder_name, data in test_batch_df.items() if hasattr(self, placeholder_name)
                }
                if hasattr(self, 'keep_prob'):
                    test_feed_dict.update({self.keep_prob: 1.0})
                if hasattr(self, 'is_training'):
                    test_feed_dict.update({self.is_training: False})

                tensor_names, tf_tensors = zip(*self.prediction_tensors.items())
                np_tensors = self.session.run(
                    fetches=tf_tensors,
                    feed_dict=test_feed_dict
                )
                for tensor_name, tensor in zip(tensor_names, np_tensors):
                    prediction_dict[tensor_name].append(tensor)

            for tensor_name, tensor in prediction_dict.items():
                np_tensor = np.concatenate(tensor, 0)
                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

        if hasattr(self, 'parameter_tensors'):
            for tensor_name, tensor in self.parameter_tensors.items():
                np_tensor = tensor.eval(self.session)
                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)
    def save(self, step, averaged=False):
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not os.path.isdir(checkpoint_dir):
            logging.info('creating checkpoint directory {}'.format(checkpoint_dir))
            os.mkdir(checkpoint_dir)

        model_path = os.path.join(checkpoint_dir, 'model')
        logging.info('saving model to {}'.format(model_path))
        saver.save(self.session, model_path, global_step=step)
    def init_logging(self, log_dir):
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        date_str = datetime.now().strftime('%Y-%m-%d_%H-%M')
        log_file = 'log_{}.txt'.format(date_str)

        # the original logging configuration call was truncated here; minimal
        # assumed reconstruction: reset the root logger and write INFO-level
        # messages to a timestamped file in log_dir
        reload(logging)
        logging.basicConfig(
            filename=os.path.join(log_dir, log_file),
            level=logging.INFO
        )
        logging.getLogger().addHandler(logging.StreamHandler())
    def update_parameters(self, loss):
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate_var = tf.Variable(0.0, trainable=False)

        if self.regularization_constant != 0:
            l2_norm = tf.reduce_sum([tf.sqrt(tf.reduce_sum(tf.square(param))) for param in tf.trainable_variables()])
            loss = loss + self.regularization_constant*l2_norm

        optimizer = self.get_optimizer(self.learning_rate_var)
        grads = optimizer.compute_gradients(loss)
        clipped = [(tf.clip_by_value(g, -self.grad_clip, self.grad_clip), v_) for g, v_ in grads]

        step = optimizer.apply_gradients(clipped, global_step=self.global_step)
        if self.enable_parameter_averaging:
            maintain_averages_op = self.ema.apply(tf.trainable_variables())
            with tf.control_dependencies([step]):
                self.step = tf.group(maintain_averages_op)
        else:
            self.step = step

        logging.info('all parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.global_variables()]))

        logging.info('trainable parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.trainable_variables()]))

        logging.info('trainable parameter count:')
        logging.info(str(np.sum(np.prod(shape(var)) for var in tf.trainable_variables())))
    def get_optimizer(self, learning_rate):
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate)
        elif self.optimizer == 'gd':
            return tf.train.GradientDescentOptimizer(learning_rate)
        elif self.optimizer == 'rms':
            return tf.train.RMSPropOptimizer(learning_rate, decay=0.95, momentum=0.9)
        else:
            assert False, 'optimizer must be adam, gd, or rms'
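

# Illustrative usage sketch (not part of the original module): it shows how a
# subclass might wire build_graph() to the hooks that __init__(), predict(),
# update_parameters(), and save() expect -- placeholders named after the keys
# the reader yields, a prediction_tensors dict, and a self.saver. The reader
# contract and the _ToyReader / _ToyLinearModel / init_op names are assumptions
# made for this example only; build_graph() in real subclasses lives elsewhere.
if __name__ == '__main__':

    class _ToyReader(object):
        """Hypothetical reader: yields dicts keyed by placeholder name."""

        def test_batch_generator(self, chunk_size):
            for _ in range(2):
                yield {'x': np.random.rand(chunk_size, 10).astype(np.float32)}

    class _ToyLinearModel(TFBaseModel):

        def build_graph(self):
            with tf.Graph().as_default() as graph:
                # placeholder name matches the key produced by the reader
                self.x = tf.placeholder(tf.float32, [None, 10])

                w = tf.Variable(tf.zeros([10, 1]))
                preds = tf.squeeze(tf.matmul(self.x, w), axis=1)
                loss = tf.reduce_mean(tf.square(preds))

                # hooks consumed by predict(), save(), and the training loop
                self.prediction_tensors = {'preds': preds}
                self.update_parameters(loss)
                self.saver = tf.train.Saver()

                # created last so it covers global_step / learning_rate_var too
                self.init_op = tf.global_variables_initializer()
                return graph

    model = _ToyLinearModel(reader=_ToyReader(), batch_size=4)
    model.session.run(model.init_op)
    model.predict(chunk_size=8)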