
Commit

update code structure
RoyalSkye committed Sep 16, 2022
1 parent 43cdc01 commit 9f3f6c9
Showing 7 changed files with 296 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ pretrained/
 # private files
 utils_plot*
 imgs/
+1.md
44 changes: 28 additions & 16 deletions POMO/TSP/TSPTester.py
@@ -46,6 +46,8 @@ def __init__(self,

         # load dataset
         self.test_data = load_dataset(tester_params['test_set_path'])[: self.tester_params['test_episodes']]
+        opt_sol = load_dataset(tester_params['test_set_opt_sol_path'])[: self.tester_params['test_episodes']]  # [(obj, route), ...]
+        self.opt_sol = [i[0] for i in opt_sol]
         if self.fine_tune_params['enable']:
             start = tester_params['test_episodes'] if self.tester_params['test_set_path'] == self.fine_tune_params['fine_tune_set_path'] else 0
             self.fine_tune_data = load_dataset(self.fine_tune_params['fine_tune_set_path'])[start: start+self.fine_tune_params['fine_tune_episodes']]
@@ -72,19 +74,24 @@ def run(self):
     def _test(self):

         self.time_estimator.reset()
-        score_AM = AverageMeter()
-        aug_score_AM = AverageMeter()
+        score_AM, gap_AM = AverageMeter(), AverageMeter()
+        aug_score_AM, aug_gap_AM = AverageMeter(), AverageMeter()

         test_num_episode = self.tester_params['test_episodes']
         assert len(self.test_data) == test_num_episode, "the number of test instances does not match!"
         episode = 0
         while episode < test_num_episode:
             remaining = test_num_episode - episode
             batch_size = min(self.tester_params['test_batch_size'], remaining)
-            score, aug_score = self._test_one_batch(torch.Tensor(self.test_data[episode:episode + batch_size]))
+            score, aug_score, all_score, all_aug_score = self._test_one_batch(torch.Tensor(self.test_data[episode: episode + batch_size]))
+            opt_sol = self.opt_sol[episode: episode + batch_size]
             score_AM.update(score, batch_size)
             aug_score_AM.update(aug_score, batch_size)
             episode += batch_size
+            gap = [max(all_score[i].item() - opt_sol[i], 0) / opt_sol[i] * 100 for i in range(batch_size)]
+            aug_gap = [max(all_aug_score[i].item() - opt_sol[i], 0) / opt_sol[i] * 100 for i in range(batch_size)]
+            gap_AM.update(sum(gap)/batch_size, batch_size)
+            aug_gap_AM.update(sum(aug_gap)/batch_size, batch_size)

             ############################
             # Logs
@@ -97,10 +104,10 @@ def _test(self):

             if all_done:
                 self.logger.info(" *** Test Done *** ")
-                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
-                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))
+                self.logger.info(" NO-AUG SCORE: {:.4f}, Gap: {:.4f} ".format(score_AM.avg, gap_AM.avg))
+                self.logger.info(" AUGMENTATION SCORE: {:.4f}, Gap: {:.4f} ".format(aug_score_AM.avg, aug_gap_AM.avg))

-        return score_AM.avg, aug_score_AM.avg
+        return score_AM.avg, aug_score_AM.avg, gap_AM.avg, aug_gap_AM.avg

     def _test_one_batch(self, test_data):
         # Augmentation
@@ -130,24 +137,26 @@ def _test_one_batch(self, test_data):

         max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
         # shape: (augmentation, batch)
-        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value
+        no_aug_score = -max_pomo_reward[0, :].float()  # negative sign to make positive value
+        no_aug_score_mean = no_aug_score.mean()

         max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
         # shape: (batch,)
-        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value
+        aug_score = -max_aug_pomo_reward.float()  # negative sign to make positive value
+        aug_score_mean = aug_score.mean()

-        return no_aug_score.item(), aug_score.item()
+        return no_aug_score_mean.item(), aug_score_mean.item(), no_aug_score, aug_score

     def _fine_tune_and_test(self):
         """
         evaluate few-shot generalization: fine-tune k steps on a small fine-tune dataset, test on test dataset after every step
         """
         fine_tune_episode = self.fine_tune_params['fine_tune_episodes']
         assert len(self.fine_tune_data) == fine_tune_episode, "the number of fine-tune instances does not match!"
-        score_list, aug_score_list = [], []
-        score, aug_score = self._test()
-        score_list.append(score)
-        aug_score_list.append(aug_score)
+        score_list, aug_score_list, gap_list, aug_gap_list = [], [], [], []
+        score, aug_score, gap, aug_gap = self._test()
+        score_list.append(score); aug_score_list.append(aug_score)
+        gap_list.append(gap); aug_gap_list.append(aug_gap)

         for k in range(self.fine_tune_params['k']):
             self.logger.info("Start fine-tune step {}".format(k+1))
@@ -157,12 +166,15 @@ def _fine_tune_and_test(self):
                 batch_size = min(self.fine_tune_params['fine_tune_batch_size'], remaining)
                 self._fine_tune_one_batch(torch.Tensor(self.fine_tune_data[episode:episode+batch_size]))
                 episode += batch_size
-            score, aug_score = self._test()
-            score_list.append(score)
-            aug_score_list.append(aug_score)
+            score, aug_score, gap, aug_gap = self._test()
+            score_list.append(score); aug_score_list.append(aug_score)
+            gap_list.append(gap); aug_gap_list.append(aug_gap)

         print(self.tester_params['test_set_path'])
         print("Final score_list: {}".format(score_list))
         print("Final aug_score_list {}".format(aug_score_list))
+        print("Final gap_list: {}".format(gap_list))
+        print("Final aug_gap_list: {}".format(aug_gap_list))

     def _fine_tune_one_batch(self, fine_tune_data):
         # Augmentation
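Note on the gap metric this file introduces: it is the per-instance optimality gap in percent against the loaded reference costs, and the max(·, 0) clips a score that lands below the reference so it cannot report a negative gap. A minimal standalone sketch of the same computation (the function name percent_gaps and the toy numbers are illustrative, not part of the repo):

```python
import torch

def percent_gaps(scores: torch.Tensor, opt_costs: list) -> list:
    """Per-instance optimality gap in percent, clipped at zero."""
    return [max(scores[i].item() - opt_costs[i], 0) / opt_costs[i] * 100
            for i in range(len(opt_costs))]

# A tour of length 5.80 against a reference cost of 5.70 gives
# (5.80 - 5.70) / 5.70 * 100 ≈ 1.754; the second instance beats its
# reference, so its gap is clipped to 0.
print(percent_gaps(torch.tensor([5.80, 5.70]), [5.70, 5.75]))
```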
6 changes: 3 additions & 3 deletions POMO/TSP/TSPTrainer.py
@@ -97,8 +97,8 @@ def run(self):

             if all_done or (epoch % model_save_interval) == 0:
                 # val
-                val_episodes = 256
-                no_aug_score = self._fast_val(copy.deepcopy(self.model), val_episodes=val_episodes)
+                val_episodes = 64
+                no_aug_score = self._fast_val(self.model, val_episodes=val_episodes)
                 print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
                 # save checkpoint
                 self.logger.info("Saving trained_model")
@@ -208,7 +208,7 @@ def _train_one_batch(self, data, env):

         return score_mean, loss_mean

-    def _fast_val(self, model, data=None, val_episodes=256):
+    def _fast_val(self, model, data=None, val_episodes=64):
         aug_factor = 1
         if data is None:
             val_path = "../../data/TSP/tsp50_tsplib.pkl"
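This hunk also drops the copy.deepcopy(self.model) around validation. That is safe under the assumption that _fast_val only runs forward passes and never calls backward() or an optimizer step; the snippet below is a self-contained check of that assumption (the toy module is illustrative):

```python
import copy
import torch
import torch.nn as nn

# Stand-in for the live training model: a forward-only pass under
# torch.no_grad() leaves the parameters untouched, so validating on the
# model directly needs no defensive copy.
model = nn.Linear(4, 1)
before = copy.deepcopy(model.state_dict())

model.eval()
with torch.no_grad():
    _ = model(torch.randn(8, 4))  # stand-in for a validation rollout

after = model.state_dict()
print(all(torch.equal(before[k], after[k]) for k in before))  # True
```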
37 changes: 16 additions & 21 deletions POMO/TSP/TSPTrainer_Meta.py
@@ -78,19 +78,21 @@ def __init__(self,
     def run(self):

         self.time_estimator.reset(self.start_epoch)
-        val_res = []
         for epoch in range(self.start_epoch, self.meta_params['epochs']+1):
             self.logger.info('=================================================================')

             # Train
             train_score, train_loss = self._train_one_epoch(epoch)
             self.result_log.append('train_score', epoch, train_score)
             self.result_log.append('train_loss', epoch, train_loss)
+            # Val
+            _, no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=64)
+            self.result_log.append('val_score', epoch, no_aug_score)

             # Logs & Checkpoint
             elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
-            self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
-                epoch, self.meta_params['epochs'], epoch/self.meta_params['epochs']*100, elapsed_time_str, remain_time_str))
+            self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}], Val Score: {:.4f}".format(
+                epoch, self.meta_params['epochs'], epoch/self.meta_params['epochs']*100, elapsed_time_str, remain_time_str, no_aug_score))

             all_done = (epoch == self.meta_params['epochs'])
             model_save_interval = self.trainer_params['logging']['model_save_interval']
@@ -99,17 +101,11 @@
             if epoch > 1:  # save latest images, every epoch
                 self.logger.info("Saving log_image")
                 image_prefix = '{}/latest'.format(self.result_folder)
-                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
-                                               self.result_log, labels=['train_score'])
-                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
-                                               self.result_log, labels=['train_loss'])
+                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'], self.result_log, labels=['train_score'])
+                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'], self.result_log, labels=['val_score'])
+                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'], self.result_log, labels=['train_loss'])

             if all_done or (epoch % model_save_interval) == 0:
-                # val
-                val_episodes = 256
-                _, no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=val_episodes)
-                val_res.append(no_aug_score)
-                print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
                 # save checkpoint
                 self.logger.info("Saving trained_model")
                 checkpoint_dict = {
@@ -123,10 +119,9 @@

             if all_done or (epoch % img_save_interval) == 0:
                 image_prefix = '{}/img/checkpoint-{}'.format(self.result_folder, epoch)
-                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
-                                               self.result_log, labels=['train_score'])
-                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
-                                               self.result_log, labels=['train_loss'])
+                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'], self.result_log, labels=['train_score'])
+                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'], self.result_log, labels=['val_score'])
+                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'], self.result_log, labels=['train_loss'])

             if all_done:
                 self.logger.info(" *** Training Done *** ")
@@ -237,10 +232,6 @@ def _train_one_batch(self, i, task_model, data, env):
         # shape: (batch, pomo)
         loss_mean = loss.mean()

-        # Score
-        max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
-        score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value
-
         # update model
         create_graph = True if self.meta_params['meta_method'] == 'maml' else False
         if i == 0:
@@ -254,9 +245,13 @@
             )
         task_model.load_state_dict(fast_weights)

+        # Score
+        max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
+        score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value
+
         return score_mean, loss_mean

-    def _fast_val(self, model, data=None, val_episodes=256):
+    def _fast_val(self, model, data=None, val_episodes=64):
         aug_factor = 1
         if data is None:
             val_path = "../../data/TSP/tsp50_tsplib.pkl"
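Two details in this file are worth spelling out: the score computation now runs after the inner-loop update, and create_graph is set only for 'maml'. With create_graph=True the inner gradient step itself stays on the autograd graph, so the outer meta-update can differentiate through it (second-order MAML); first-order variants skip that cost. A minimal sketch of the pattern, assuming the repo's inner loop builds fast weights via torch.autograd.grad (names below are illustrative, not the repo's API):

```python
import torch

def inner_update(loss, params, lr=0.01, second_order=True):
    # create_graph=second_order keeps the inner SGD step differentiable,
    # which is what distinguishes full MAML from its first-order variants.
    grads = torch.autograd.grad(loss, params, create_graph=second_order)
    return [p - lr * g for p, g in zip(params, grads)]

# Toy usage: one differentiable step on a quadratic loss.
w = torch.tensor([1.0, -2.0], requires_grad=True)
fast_w = inner_update((w ** 2).sum(), [w])
print(fast_w[0].requires_grad)  # True: the update stayed in the graph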
(Diffs for the remaining 3 changed files are not shown.)
