fix bugs
RoyalSkye committed Sep 7, 2022
1 parent 95de9e5 commit 755887c
Showing 8 changed files with 385 additions and 38 deletions.
15 changes: 10 additions & 5 deletions POMO/TSP/TSPTester.py
@@ -45,8 +45,12 @@ def __init__(self,
self.optimizer = Optimizer(self.model.parameters(), **self.tester_params['fine_tune_params']['optimizer'])

# load dataset
self.test_data = load_dataset(tester_params['test_set_path'])
self.fine_tune_data = load_dataset(self.fine_tune_params['fine_tune_set_path']) if self.fine_tune_set_path['enable'] else None
self.test_data = load_dataset(tester_params['test_set_path'])[: self.tester_params['test_episodes']]
if self.fine_tune_params['enable']:
start = tester_params['test_episodes'] if self.tester_params['test_set_path'] == self.fine_tune_params['fine_tune_set_path'] else 0
self.fine_tune_data = load_dataset(self.fine_tune_params['fine_tune_set_path'])[start: start+self.fine_tune_params['fine_tune_episodes']]
else:
self.fine_tune_data = None

# Restore
model_load = tester_params['model_load']
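The added branch above lets one dataset file back both the test split and the fine-tune split. A minimal sketch (not part of the commit) of the slicing it performs, with made-up sizes and a plain list standing in for load_dataset:

test_episodes, fine_tune_episodes = 1000, 200
dataset = list(range(1200))                      # stand-in for load_dataset(path)
test_data = dataset[:test_episodes]
start = test_episodes                            # 0 if the two paths differ
fine_tune_data = dataset[start:start + fine_tune_episodes]
assert not set(test_data) & set(fine_tune_data)  # the two splits are disjoint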
@@ -146,6 +150,7 @@ def _fine_tune_and_test(self):
aug_score_list.append(aug_score)

for k in range(self.fine_tune_params['k']):
self.logger.info("Start fine-tune step {}".format(k+1))
episode = 0
while episode < fine_tune_episode:
remaining = fine_tune_episode - episode
@@ -183,14 +188,14 @@ def _fine_tune_one_batch(self, fine_tune_data):
prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

# Loss
aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size).transpose(1, 0, 2).view(batch_size, -1)
aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size).permute(1, 0, 2).view(batch_size, -1)
# shape: (batch, augmentation * pomo)
advantage = aug_reward - aug_reward.float().mean(dim=1, keepdims=True)
# shape: (batch, augmentation * pomo)
log_prob = prob_list.log().sum(dim=2).reshape(aug_factor, batch_size, self.env.pomo_size).transpose(1, 0, 2).view(batch_size, -1)
log_prob = prob_list.log().sum(dim=2).reshape(aug_factor, batch_size, self.env.pomo_size).permute(1, 0, 2).view(batch_size, -1)
# size = (batch, augmentation * pomo)
loss = -advantage * log_prob # Minus Sign: To Increase REWARD
# shape: (batch, augmentation * pomo)
loss_mean = loss.mean()

# Score
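The switch from transpose(1, 0, 2) to permute(1, 0, 2) is the substance of this hunk: torch.Tensor.transpose accepts exactly two dimension indices, so calling it with three fails, while permute takes a full ordering of the dimensions. A self-contained sketch of the reshaping and the shared-baseline advantage with made-up sizes (it uses .reshape() after .permute() because .view() requires a contiguous tensor):

import torch

aug_factor, batch_size, pomo_size = 8, 4, 20
reward = torch.randn(aug_factor * batch_size, pomo_size)       # stand-in rewards
# group the augmented copies of each instance together
aug_reward = reward.reshape(aug_factor, batch_size, pomo_size).permute(1, 0, 2).reshape(batch_size, -1)
# shape: (batch, augmentation * pomo)
advantage = aug_reward - aug_reward.mean(dim=1, keepdim=True)  # shared baseline per instance
# reward.reshape(...).transpose(1, 0, 2) would error: transpose swaps only two dims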
39 changes: 37 additions & 2 deletions POMO/TSP/TSPTrainer.py
@@ -10,6 +10,7 @@
from TSProblemDef import get_random_problems, generate_task_set

from utils.utils import *
from utils.functions import load_dataset


class TSPTrainer:
@@ -76,10 +77,14 @@ def run(self):
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)

# Val
if epoch % self.trainer_params['val_interval'] == 0:
self._fast_val()

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
epoch, self.trainer_params['epochs'], elapsed_time_str, remain_time_str))
self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
epoch, self.trainer_params['epochs'], epoch/self.trainer_params['epochs']*100, elapsed_time_str, remain_time_str))

all_done = (epoch == self.trainer_params['epochs'])
model_save_interval = self.trainer_params['logging']['model_save_interval']
@@ -199,3 +204,33 @@ def _train_one_batch(self, data, env):
loss_mean.backward()
self.optimizer.step()
return score_mean.item(), loss_mean.item()

def _fast_val(self):
val_path = "../../data/TSP/tsp100_tsplib.pkl"
val_episodes = 5000
aug_factor = 1
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})

self.model.eval()
batch_size = data.size(0)
with torch.no_grad():
env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
reset_state, _, _ = env.reset()
self.model.pre_forward(reset_state)

state, reward, done = env.pre_step()
while not done:
selected, _ = self.model(state)
# shape: (batch, pomo)
state, reward, done = env.step(selected)

# Return
aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size)
# shape: (augmentation, batch, pomo)

max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo
# shape: (augmentation, batch)
no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value

print(">> validation results: {}".format(no_aug_score.item()))
70 changes: 54 additions & 16 deletions POMO/TSP/TSPTrainer_Meta.py
@@ -12,9 +12,13 @@
from TSProblemDef import get_random_problems, generate_task_set

from utils.utils import *
from utils.functions import load_dataset


class TSPTrainer:
"""
Implementation of POMO with Reptile.
"""
def __init__(self,
env_params,
model_params,
@@ -69,27 +73,31 @@ def __init__(self,
self.time_estimator = TimeEstimator()

def run(self):
"""
1. Sample B training tasks from task distribution P(T)
2. for each of task T_i, do reptile -> \theta_i
3. update meta-model \theta_0
"""

self.time_estimator.reset(self.start_epoch)
for epoch in range(self.start_epoch, self.meta_params['epochs']+1):
self.logger.info('=================================================================')

# Train
if self.meta_params['meta_method'] == 'reptile':
if self.meta_params['meta_method'] == 'maml':
pass
elif self.meta_params['meta_method'] == 'fomaml':
pass
elif self.meta_params['meta_method'] == 'reptile':
train_score, train_loss = self._train_one_epoch(epoch)
else:
raise NotImplementedError
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)

# Val
if epoch % self.trainer_params['val_interval'] == 0:
self._fast_val()

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
epoch, self.meta_params['epochs'], elapsed_time_str, remain_time_str))
self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
epoch, self.meta_params['epochs'], epoch/self.meta_params['epochs']*100, elapsed_time_str, remain_time_str))

all_done = (epoch == self.meta_params['epochs'])
model_save_interval = self.trainer_params['logging']['model_save_interval']
@@ -127,7 +135,11 @@ def run(self):
util_print_log_array(self.logger, self.result_log)

def _train_one_epoch(self, epoch):

"""
1. Sample B training tasks from task distribution P(T)
2. for a batch of tasks T_i, do reptile -> \theta_i
3. update meta-model -> \theta_0
"""
score_AM = AverageMeter()
loss_AM = AverageMeter()

@@ -137,10 +149,11 @@ def _train_one_epoch(self, epoch):
fast_weights = []

for i in range(self.meta_params['B']):
task_params = random.sample(self.task_set, 1)[0] # uniform sample a task
task_params = random.sample(self.task_set, 1)[0] # uniformly sample a task
task_model = copy.deepcopy(self.meta_model)
optimizer = Optimizer(task_model.parameters(), **self.optimizer_params['optimizer'])
for batch_id in range(self.meta_params['k']):

for step in range(self.meta_params['k']):
# generate task-specific data
if self.meta_params['data_type'] == 'distribution':
assert len(task_params) == 2
@@ -153,6 +166,7 @@
data = get_random_problems(batch_size, problem_size=task_params[0], num_modes=task_params[1], cdist=task_params[-1], distribution='gaussian_mixture')
else:
raise NotImplementedError

env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
avg_score, avg_loss = self._train_one_batch(task_model, data, optimizer, Env(**env_params))
score_AM.update(avg_score, batch_size)
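For the gaussian_mixture branch above, a generic sketch of what mixture-based instance generation usually looks like; the real get_random_problems in TSProblemDef may scale and parameterize things differently, so treat this as an assumption rather than the repository's code:

import torch

def gaussian_mixture_instance(problem_size, num_modes, cdist):
    centers = torch.rand(num_modes, 2) * cdist                    # mode centers
    assignment = torch.randint(num_modes, (problem_size,))        # node-to-mode assignment
    coords = centers[assignment] + torch.randn(problem_size, 2)   # Gaussian noise around each center
    # min-max normalize the coordinates back into the unit square
    coords = (coords - coords.min(0).values) / (coords.max(0).values - coords.min(0).values)
    return coords  # shape: (problem_size, 2)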
@@ -208,11 +222,35 @@ def _train_one_batch(self, task_model, data, optimizer, env):

return score_mean.item(), loss_mean.item()

def _fast_val(self, task_model, data, env):
"""
TODO: a simple implementation of fast evaluation at the end of each meta training iteration.
"""
return 0, 0
def _fast_val(self):
val_path = "../../data/TSP/tsp100_tsplib.pkl"
val_episodes = 5000
aug_factor = 1
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})

self.meta_model.eval()
batch_size = data.size(0)
with torch.no_grad():
env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
reset_state, _, _ = env.reset()
self.meta_model.pre_forward(reset_state)

state, reward, done = env.pre_step()
while not done:
selected, _ = self.meta_model(state)
# shape: (batch, pomo)
state, reward, done = env.step(selected)

# Return
aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size)
# shape: (augmentation, batch, pomo)

max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo
# shape: (augmentation, batch)
no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value

print(">> validation results: {}".format(no_aug_score.item()))

def _alpha_scheduler(self, iter):
self.alpha *= self.meta_params['alpha_decay']
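_alpha_scheduler applies a plain exponential decay to the Reptile step size on every call. A quick illustration with an assumed alpha_decay of 0.99 (not a value taken from the config):

alpha, alpha_decay = 1.0, 0.99
for _ in range(100):
    alpha *= alpha_decay
print(round(alpha, 3))   # 0.366, i.e. 0.99 ** 100 of the initial step size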