train.py
# -*- coding: utf-8 -*-
# Implementation of Wang et al 2017: Automatic Brain Tumor Segmentation using Cascaded Anisotropic Convolutional Neural Networks. https://arxiv.org/abs/1709.00382
# Author: Guotai Wang
# Copyright (c) 2017-2018 University College London, United Kingdom. All rights reserved.
# http://cmictig.cs.ucl.ac.uk
#
# Distributed under the BSD-3 licence. Please see the file licence.txt
# This software is not certified for clinical use.
#
from __future__ import absolute_import, print_function
import numpy as np
import random
from scipy import ndimage
import time
import os
import sys
import tensorflow as tf
from tensorflow.contrib.data import Iterator
from tensorflow.contrib.layers.python.layers import regularizers
from niftynet.layer.loss_segmentation import LossFunction
from util.data_loader import *
from util.train_test_func import *
from util.parse_config import parse_config
from util.MSNet import MSNet

class NetFactory(object):
    @staticmethod
    def create(name):
        if name == 'MSNet':
            return MSNet
        # add your own networks here
        print('unsupported network:', name)
        exit()
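# A minimal sketch of how another architecture could be registered above
# (HighRes3DNet is a hypothetical example class, not imported in this file):
#     if name == 'HighRes3DNet':
#         return HighRes3DNet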

def train(config_file):
    # 1, load configuration parameters
    config = parse_config(config_file)
    config_data  = config['data']
    config_net   = config['network']
    config_train = config['training']

    random.seed(config_train.get('random_seed', 1))
    assert(config_data['with_ground_truth'])

    net_type   = config_net['net_type']
    net_name   = config_net['net_name']
    class_num  = config_net['class_num']
    batch_size = config_data.get('batch_size', 5)
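    # A minimal sketch of the expected config file, with hypothetical values;
    # the exact keys and value formats are defined by util/parse_config.py:
    #     [data]
    #     with_ground_truth = True
    #     batch_size        = 5
    #     data_shape        = [19, 180, 160, 4]
    #     label_shape       = [11, 180, 160, 1]
    #     [network]
    #     net_type  = MSNet
    #     net_name  = MSNet_WT32
    #     class_num = 2
    #     [training]
    #     learning_rate      = 1e-3
    #     decay              = 1e-7
    #     maximal_iteration  = 20000
    #     snapshot_iteration = 5000
    #     test_iteration     = 100
    #     test_step          = 10
    #     model_save_prefix  = model17/msnet_wt32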

    # 2, construct graph
    full_data_shape  = [batch_size] + config_data['data_shape']
    full_label_shape = [batch_size] + config_data['label_shape']
    x = tf.placeholder(tf.float32, shape = full_data_shape)   # input images
    w = tf.placeholder(tf.float32, shape = full_label_shape)  # voxel-wise weight map
    y = tf.placeholder(tf.int64,   shape = full_label_shape)  # ground-truth labels

    w_regularizer = regularizers.l2_regularizer(config_train.get('decay', 1e-7))
    b_regularizer = regularizers.l2_regularizer(config_train.get('decay', 1e-7))
    net_class = NetFactory.create(net_type)
    net = net_class(num_classes   = class_num,
                    w_regularizer = w_regularizer,
                    b_regularizer = b_regularizer,
                    name          = net_name)
    net.set_params(config_net)
    predicty = net(x, is_training = True)
    proby    = tf.nn.softmax(predicty)  # class probabilities (not used below)

    # NiftyNet's LossFunction defaults to a Dice-based loss, hence the
    # 'dice' naming in the evaluation loop below
    loss_func = LossFunction(n_class = class_num)
    loss = loss_func(predicty, y, weight_map = w)
    print('size of predicty:', predicty)

    # 3, initialize session and saver
    lr = config_train.get('learning_rate', 1e-3)
    opt_step = tf.train.AdamOptimizer(lr).minimize(loss)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    dataloader = DataLoader(config_data)
    dataloader.load_data()

    # 4, start to train
    loss_file = config_train['model_save_prefix'] + "_loss.txt"
    start_it  = config_train.get('start_iteration', 0)
    if start_it > 0:
        saver.restore(sess, config_train['model_pre_trained'])
    loss_list = []
    for n in range(start_it, config_train['maximal_iteration']):
        train_pair = dataloader.get_subimage_batch()
        tempx = train_pair['images']
        tempw = train_pair['weights']
        tempy = train_pair['labels']
        opt_step.run(session = sess, feed_dict = {x: tempx, w: tempw, y: tempy})

        # periodically evaluate the loss on freshly sampled batches and log it
        if n % config_train['test_iteration'] == 0:
            batch_dice_list = []
            for step in range(config_train['test_step']):
                train_pair = dataloader.get_subimage_batch()
                tempx = train_pair['images']
                tempw = train_pair['weights']
                tempy = train_pair['labels']
                dice = loss.eval(feed_dict = {x: tempx, w: tempw, y: tempy})
                batch_dice_list.append(dice)
            batch_dice = np.asarray(batch_dice_list, np.float32).mean()
            t = time.strftime('%X %x %Z')
            print(t, 'n', n, 'loss', batch_dice)
            loss_list.append(batch_dice)
            np.savetxt(loss_file, np.asarray(loss_list))

        # save a snapshot of the model every snapshot_iteration steps
        if (n + 1) % config_train['snapshot_iteration'] == 0:
            saver.save(sess, config_train['model_save_prefix'] + "_{0:}.ckpt".format(n + 1))
    sess.close()

if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('This script takes one argument, the configuration file. e.g.')
        print('    python train.py config17/train_wt_ax.txt')
        exit()
    config_file = str(sys.argv[1])
    assert(os.path.isfile(config_file))
    train(config_file)
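
# To resume training from a saved snapshot, set start_iteration > 0 in the
# [training] section of the config and point model_pre_trained at the
# checkpoint to restore, e.g. (hypothetical values):
#     start_iteration   = 10000
#     model_pre_trained = model17/msnet_wt32_10000.ckpt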