Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
RoyalSkye committed Nov 9, 2022
1 parent fdb346e commit 38f55f8
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 323 deletions.
20 changes: 18 additions & 2 deletions POMO/TSP/TSPTester.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

import os
import pickle
from logging import getLogger
from torch.optim import Adam as Optimizer

Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(self,
self.model = Model(**self.model_params)
self.env = Env(**self.env_params) # we assume instances in the test/fine-tune dataset have the same problem size.
self.optimizer = Optimizer(self.model.parameters(), **self.tester_params['fine_tune_params']['optimizer'])
self.score_list, self.aug_score_list, self.gap_list, self.aug_gap_list = [], [], [], []

# load dataset
self.test_data = load_dataset(tester_params['test_set_path'])[: self.tester_params['test_episodes']]
Expand Down Expand Up @@ -90,7 +92,13 @@ def run(self):
# test the model on test dataset
self._test()

def _test(self):
# save results to file
# with open(os.path.join(self.result_folder, 'result_lists.pkl'), 'wb') as f:
with open(os.path.split(self.tester_params['test_set_path'])[-1], 'wb') as f:
result = {"score_list": self.score_list, "aug_score_list": self.aug_score_list, "gap_list": self.gap_list, "aug_gap_list": self.aug_gap_list}
pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)

def _test(self, store_res=True):

self.time_estimator.reset()
score_AM, gap_AM = AverageMeter(), AverageMeter()
Expand All @@ -112,6 +120,12 @@ def _test(self):
gap_AM.update(sum(gap)/batch_size, batch_size)
aug_gap_AM.update(sum(aug_gap)/batch_size, batch_size)

if store_res:
self.score_list += all_score.tolist()
self.aug_score_list += all_aug_score.tolist()
self.gap_list += gap
self.aug_gap_list += aug_gap

############################
# Logs
############################
Expand All @@ -125,6 +139,8 @@ def _test(self):
self.logger.info(" *** Test Done *** ")
self.logger.info(" NO-AUG SCORE: {:.4f}, Gap: {:.4f} ".format(score_AM.avg, gap_AM.avg))
self.logger.info(" AUGMENTATION SCORE: {:.4f}, Gap: {:.4f} ".format(aug_score_AM.avg, aug_gap_AM.avg))
print("{:.4f} ({:.4f}%)".format(score_AM.avg, gap_AM.avg))
print("{:.4f} ({:.4f}%)".format(aug_score_AM.avg, aug_gap_AM.avg))

return score_AM.avg, aug_score_AM.avg, gap_AM.avg, aug_gap_AM.avg

Expand Down Expand Up @@ -173,7 +189,7 @@ def _fine_tune_and_test(self):
fine_tune_episode = self.fine_tune_params['fine_tune_episodes']
assert len(self.fine_tune_data) == fine_tune_episode, "the number of fine-tune instances does not match!"
score_list, aug_score_list, gap_list, aug_gap_list = [], [], [], []
score, aug_score, gap, aug_gap = self._test()
score, aug_score, gap, aug_gap = self._test(store_res=False)
score_list.append(score); aug_score_list.append(aug_score)
gap_list.append(gap); aug_gap_list.append(aug_gap)

Expand Down
237 changes: 0 additions & 237 deletions POMO/TSP/TSPTrainer.py

This file was deleted.

Loading

0 comments on commit 38f55f8

Please sign in to comment.