Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed initialisation issue erasing loaded model. #20

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions base/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def load(self, sess):
print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
self.saver.restore(sess, latest_checkpoint)
print("Model loaded")
else:
print("NO model loaded. Training from beginning")

# just initialize a tensorflow variable to use it as epoch counter
def init_cur_epoch(self):
Expand Down
6 changes: 5 additions & 1 deletion base/base_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class BaseTrain:
def __init__(self, sess, model, data, config, logger):
def __init__(self, sess, model, data, config, logger, load=False):
self.model = model
self.logger = logger
self.config = config
Expand All @@ -11,6 +11,10 @@ def __init__(self, sess, model, data, config, logger):
self.init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
self.sess.run(self.init)

# Load the model after initialization to ensure that the loaded values are kept
if load:
self.model.load(self.sess)

def train(self):
for cur_epoch in range(self.model.cur_epoch_tensor.eval(self.sess), self.config.num_epochs + 1, 1):
self.train_epoch()
Expand Down
4 changes: 1 addition & 3 deletions mains/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ def main():
sess = tf.Session()
# create an instance of the model you want
model = ExampleModel(config)
#load model if exists
model.load(sess)
# create your data generator
data = DataGenerator(config)
# create tensorboard logger
logger = Logger(sess, config)
# create trainer and pass all the previous components to it
trainer = ExampleTrainer(sess, model, data, config, logger)
trainer = ExampleTrainer(sess, model, data, config, logger, load=True)

# here you train your model
trainer.train()
Expand Down
4 changes: 2 additions & 2 deletions trainers/example_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class ExampleTrainer(BaseTrain):
def __init__(self, sess, model, data, config,logger):
super(ExampleTrainer, self).__init__(sess, model, data, config,logger)
def __init__(self, sess, model, data, config,logger, **kwargs):
super(ExampleTrainer, self).__init__(sess, model, data, config,logger, **kwargs)

def train_epoch(self):
loop = tqdm(range(self.config.num_iter_per_epoch))
Expand Down