Commit 43cdc01 (parent 253156b): fix bugs
RoyalSkye committed Sep 9, 2022

Showing 4 changed files with 48 additions and 42 deletions.
18 changes: 9 additions & 9 deletions POMO/TSP/TSPTrainer.py
@@ -1,3 +1,4 @@
import copy
import random
import torch
from logging import getLogger
@@ -77,12 +78,6 @@ def run(self):
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)

# Val
if epoch % self.trainer_params['val_interval'] == 0:
val_episodes = 1000
no_aug_score = self._fast_val(self.model, val_episodes=val_episodes)
print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
@@ -101,6 +96,11 @@ def run(self):
self.result_log, labels=['train_loss'])

if all_done or (epoch % model_save_interval) == 0:
# val
val_episodes = 256
no_aug_score = self._fast_val(copy.deepcopy(self.model), val_episodes=val_episodes)
print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
# save checkpoint
self.logger.info("Saving trained_model")
checkpoint_dict = {
'epoch': epoch,
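A note on the hunk above: validation now runs only when a checkpoint is written, uses 256 instances, and operates on copy.deepcopy(self.model) so that switching the copy to eval mode cannot disturb the training loop. A minimal sketch of that pattern (fast_val and save_checkpoint are hypothetical stand-ins for _fast_val and the checkpoint code, not functions from this repository):

```python
import copy

def checkpoint_step(model, epoch, all_done, model_save_interval,
                    fast_val, save_checkpoint, val_episodes=256):
    # Validate and checkpoint together, mirroring the branch in run() above.
    if all_done or (epoch % model_save_interval) == 0:
        # Evaluate a copy so eval-mode side effects never leak into training.
        no_aug_score = fast_val(copy.deepcopy(model), val_episodes=val_episodes)
        print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
        save_checkpoint(epoch)
```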
@@ -148,6 +148,7 @@ def _train_one_epoch(self, epoch):
data = get_random_problems(batch_size, problem_size=task_params[0], num_modes=task_params[1], cdist=task_params[-1], distribution='gaussian_mixture')
else:
raise NotImplementedError

env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
avg_score, avg_loss = self._train_one_batch(data, Env(**env_params))
score_AM.update(avg_score.item(), batch_size)
@@ -171,7 +172,6 @@ def _train_one_epoch(self, epoch):

def _train_one_batch(self, data, env):

# Prep
self.model.train()
batch_size = data.size(0)
env.load_problems(batch_size, problems=data, aug_factor=1)
@@ -208,10 +208,10 @@ def _train_one_batch(self, data, env):

return score_mean, loss_mean

def _fast_val(self, model, data=None, val_episodes=1000):
def _fast_val(self, model, data=None, val_episodes=256):
aug_factor = 1
if data is None:
val_path = "../../data/TSP/tsp100_tsplib.pkl"
val_path = "../../data/TSP/tsp50_tsplib.pkl"
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
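For context, _fast_val performs a single no-augmentation rollout over the validation instances (tsp50_tsplib.pkl after this commit) and reports the mean best-of-POMO tour length. The sketch below follows the usual POMO evaluation loop; the Env/Model method names are taken from this codebase, but their exact signatures are assumptions, and the meta-trainer's variant additionally returns a differentiable loss, which a no_grad version like this one would not:

```python
import torch

@torch.no_grad()
def fast_val_sketch(model, env, data):
    model.eval()
    env.load_problems(data.size(0), problems=data, aug_factor=1)
    reset_state, _, _ = env.reset()
    model.pre_forward(reset_state)

    state, reward, done = env.pre_step()
    while not done:
        selected, _ = model(state)               # one node per POMO rollout
        state, reward, done = env.step(selected)

    max_pomo_reward, _ = reward.max(dim=1)       # best of the pomo_size rollouts
    return -max_pomo_reward.float().mean().item()  # reward is negative tour length
```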

53 changes: 29 additions & 24 deletions POMO/TSP/TSPTrainer_Meta.py
@@ -53,10 +53,9 @@ def __init__(self,

# Main Components
self.meta_model = Model(**self.model_params)
self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer'])
self.alpha = self.meta_params['alpha'] # for reptile
self.task_set = generate_task_set(self.meta_params)
assert self.trainer_params['meta_params']['epochs'] == math.ceil((1000 * 100000) / (
assert self.trainer_params['meta_params']['epochs'] == math.ceil((self.trainer_params['epochs'] * self.trainer_params['train_episodes']) / (
self.trainer_params['meta_params']['B'] * self.trainer_params['meta_params']['k'] *
self.trainer_params['meta_params']['meta_batch_size'])), ">> meta-learning iteration does not match with POMO!"
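The rewritten assert ties the number of meta-model updates to the configured POMO budget rather than the hard-coded 1000 * 100000. With the values set in train_n100.py below, the arithmetic checks out (a quick verification, not repository code):

```python
import math

epochs, train_episodes = 500, 100_000   # trainer_params in train_n100.py
B, k, meta_batch_size = 3, 5, 64        # meta_params in train_n100.py

meta_epochs = math.ceil(epochs * train_episodes / (B * k * meta_batch_size))
print(meta_epochs)  # 52084, matching meta_params['epochs']
```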

@@ -79,6 +78,7 @@ def __init__(self,
def run(self):

self.time_estimator.reset(self.start_epoch)
val_res = []
for epoch in range(self.start_epoch, self.meta_params['epochs']+1):
self.logger.info('=================================================================')

@@ -87,12 +87,6 @@ def run(self):
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)

# Val
if epoch % self.trainer_params['val_interval'] == 0:
val_episodes = 1000
no_aug_score = self._fast_val(self.meta_model, val_episodes=val_episodes)
print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
@@ -111,6 +105,12 @@ def run(self):
self.result_log, labels=['train_loss'])

if all_done or (epoch % model_save_interval) == 0:
# val
val_episodes = 256
_, no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=val_episodes)
val_res.append(no_aug_score)
print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
# save checkpoint
self.logger.info("Saving trained_model")
checkpoint_dict = {
'epoch': epoch,
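As in TSPTrainer.py, validation is now coupled to checkpointing, and each score is collected in val_res so the full validation curve is printed once training ends. The hunk only shows the 'epoch' entry of checkpoint_dict, so the sketch below of the implied save/resume round trip uses assumed key and file names, with a small stand-in model:

```python
import torch
import torch.nn as nn

meta_model = nn.Linear(4, 4)      # stand-in for the POMO model
epoch, result_folder = 520, "."

checkpoint_dict = {
    'epoch': epoch,
    'model_state_dict': meta_model.state_dict(),   # assumed key name
}
torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(result_folder, epoch))

ckpt = torch.load('{}/checkpoint-{}.pt'.format(result_folder, epoch))
meta_model.load_state_dict(ckpt['model_state_dict'])
start_epoch = ckpt['epoch'] + 1   # typical resume logic (assumption)
```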
@@ -132,6 +132,7 @@ def run(self):
self.logger.info(" *** Training Done *** ")
self.logger.info("Now, printing log array...")
util_print_log_array(self.logger, self.result_log)
print(val_res)

def _train_one_epoch(self, epoch):
"""
@@ -168,17 +169,15 @@ def _train_one_epoch(self, epoch):

if step == self.meta_params['k']: continue
env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
avg_score, avg_loss = self._train_one_batch(task_model, data, Env(**env_params))
avg_score, avg_loss = self._train_one_batch(step, task_model, data, Env(**env_params))
score_AM.update(avg_score.item(), batch_size)
loss_AM.update(avg_loss.item(), batch_size)

if self.meta_params['meta_method'] == 'maml':
# cal loss on query(val) set - data
# val_loss += self._fast_val(task_model, val_episodes=64)
val_loss += self._fast_val(task_model, data=data)
val_loss += self._fast_val(task_model, data=data)[0]
elif self.meta_params['meta_method'] == 'fomaml':
loss, _ = self._fast_val(task_model, data=data)
val_loss += loss
val_loss = self._fast_val(task_model, data=data)[0]
task_model.train()
grad = torch.autograd.grad(val_loss, task_model.parameters())
fomaml_grad.append(grad)
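The hunk above is the inner loop of a meta-iteration: each of the B sampled tasks takes k adaptation steps (note the new step argument passed to _train_one_batch), after which the query-set loss is used differently per method. MAML accumulates a val_loss that keeps the autograd graph; FOMAML immediately takes first-order gradients of the query loss w.r.t. the adapted weights and stores them for the outer update. A compact sketch of that control flow, with inner_step and query_loss as placeholder callables:

```python
import copy
import torch

def meta_iteration(meta_model, tasks, k, meta_method, inner_step, query_loss):
    val_loss, fomaml_grads = 0.0, []
    for task in tasks:                          # B tasks per meta-iteration
        task_model = copy.deepcopy(meta_model)  # fresh per-task copy
        for step in range(k):                   # k inner adaptation steps
            inner_step(step, task_model, task)
        if meta_method == 'maml':
            val_loss = val_loss + query_loss(task_model, task)   # graph kept
        elif meta_method == 'fomaml':
            loss = query_loss(task_model, task)
            fomaml_grads.append(
                torch.autograd.grad(loss, tuple(task_model.parameters())))
    return val_loss, fomaml_grads
```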
@@ -188,14 +187,17 @@ def _train_one_epoch(self, epoch):
# update meta-model
if self.meta_params['meta_method'] == 'maml':
val_loss = val_loss / self.meta_params['B']
self.meta_optimizer.zero_grad()
val_loss.backward()
self.meta_optimizer.step()
gradients = torch.autograd.grad(val_loss, self.maml)
updated_weights = OrderedDict(
(name, param - self.optimizer_params['optimizer']['lr'] * grad)
for ((name, param), grad) in zip(self.meta_model.state_dict().items(), gradients)
)
self.meta_model.load_state_dict(updated_weights)
elif self.meta_params['meta_method'] == 'fomaml':
updated_weights = self.meta_model.state_dict()
for gradients in fomaml_grad:
updated_weights = OrderedDict(
(name, param - self.optimizer_params['optimizer']['lr'] * grad)
(name, param - self.optimizer_params['optimizer']['lr'] / self.meta_params['B'] * grad)
for ((name, param), grad) in zip(updated_weights.items(), gradients)
)
self.meta_model.load_state_dict(updated_weights)
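The outer update also changes here: the MAML branch no longer steps self.meta_optimizer; instead it differentiates the averaged val_loss w.r.t. the initial task parameters stored in self.maml and applies a manual SGD step to the meta-model's state dict, while the FOMAML branch now scales each stored task gradient by 1/B so the applied update is an average rather than a sum. A hedged sketch of those two manual updates, assuming a plain learning rate lr and a model whose state_dict contains only parameters (no buffers):

```python
from collections import OrderedDict

def apply_manual_sgd(meta_model, gradients, lr):
    # One functional SGD step on the meta-model, as in the maml branch above.
    updated = OrderedDict(
        (name, param - lr * grad)
        for (name, param), grad in zip(meta_model.state_dict().items(), gradients)
    )
    meta_model.load_state_dict(updated)

def apply_fomaml(meta_model, fomaml_grads, lr, B):
    # Sequentially apply each task's query gradient, scaled by 1/B (an average).
    for gradients in fomaml_grads:
        apply_manual_sgd(meta_model, gradients, lr / B)
```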
@@ -208,9 +210,8 @@ def _train_one_epoch(self, epoch):

return score_AM.avg, loss_AM.avg

def _train_one_batch(self, task_model, data, env):
def _train_one_batch(self, i, task_model, data, env):

# Prep
task_model.train()
batch_size = data.size(0)
env.load_problems(batch_size, problems=data, aug_factor=1)
@@ -242,7 +243,11 @@ def _train_one_batch(self, task_model, data, env):

# update model
create_graph = True if self.meta_params['meta_method'] == 'maml' else False
gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph)
if i == 0:
self.maml = list(task_model.parameters())
gradients = torch.autograd.grad(loss_mean, self.maml, create_graph=create_graph)
else:
gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph)
fast_weights = OrderedDict(
(name, param - self.optimizer_params['optimizer']['lr'] * grad)
for ((name, param), grad) in zip(task_model.state_dict().items(), gradients)
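In _train_one_batch, the parameters of the task model at its first inner step are recorded in self.maml, and under MAML the inner gradients are taken w.r.t. that list with create_graph=True, so the later query-set gradient can propagate back through the adaptation steps (second-order MAML); under FOMAML create_graph stays False. The fast weights are then built functionally instead of through an optimizer. A small, self-contained illustration of that functional pattern on a toy model (everything here is illustrative, not repository code):

```python
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(2, 1)
x, y = torch.randn(8, 2), torch.randn(8, 1)
lr = 0.1

# Inner step: gradients w.r.t. the current parameters, graph kept for MAML.
inner_loss = F.mse_loss(model(x), y)
params = OrderedDict(model.named_parameters())
grads = torch.autograd.grad(inner_loss, tuple(params.values()), create_graph=True)

# Functional update: fast weights remain differentiable w.r.t. the originals.
fast_weights = OrderedDict(
    (name, p - lr * g) for (name, p), g in zip(params.items(), grads)
)

# A query loss computed with the fast weights still backpropagates to the
# original parameters, which is what the maml branch relies on.
query_loss = F.mse_loss(F.linear(x, fast_weights['weight'], fast_weights['bias']), y)
meta_grads = torch.autograd.grad(query_loss, tuple(params.values()))
print([g.shape for g in meta_grads])
```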
@@ -251,10 +256,10 @@ def _train_one_batch(self, task_model, data, env):

return score_mean, loss_mean

def _fast_val(self, model, data=None, val_episodes=1000):
def _fast_val(self, model, data=None, val_episodes=256):
aug_factor = 1
if data is None:
val_path = "../../data/TSP/tsp100_tsplib.pkl"
val_path = "../../data/TSP/tsp50_tsplib.pkl"
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})

@@ -286,7 +291,7 @@ def _fast_val(self, model, data=None, val_episodes=1000):
# shape: (augmentation, batch)
no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value

return loss_mean, no_aug_score.item()
return loss_mean, no_aug_score.detach().item()

def _alpha_scheduler(self, iter):
self.alpha *= self.meta_params['alpha_decay']
self.alpha = max(self.alpha * self.meta_params['alpha_decay'], 0.0001)
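The final hunk of this file fixes _alpha_scheduler for Reptile: the step size is now decayed and clamped from below at 1e-4 instead of decaying toward zero, so late meta-iterations keep a small but non-zero interpolation weight. A tiny sketch of the resulting schedule; the starting alpha and decay factor are assumptions, since neither appears in this diff:

```python
alpha, alpha_decay = 1.0, 0.99   # assumed values for illustration
schedule = []
for it in range(1000):
    alpha = max(alpha * alpha_decay, 0.0001)   # same rule as _alpha_scheduler
    schedule.append(alpha)
print(schedule[0], schedule[499], schedule[-1])   # the floor kicks in near the end
```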
2 changes: 2 additions & 0 deletions POMO/TSP/TSProblemDef.py
@@ -225,6 +225,8 @@ def generate_tsp_dist(n_samples, n_nodes, distribution):
test seed: 2023
"""
path = "../../data/TSP"
if not os.path.exists(path):
os.makedirs(path)
seed_everything(seed=2023)

for dist in ["uniform", "uniform_rectangle", "gaussian", "cluster", "diagonal", "tsplib"]:
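The TSProblemDef.py change simply creates the output directory before datasets are written. A single os.makedirs call with exist_ok=True does the same job without a separate existence check and is robust if the directory is created concurrently:

```python
import os

path = "../../data/TSP"
os.makedirs(path, exist_ok=True)   # no error if the directory already exists
```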
17 changes: 8 additions & 9 deletions POMO/TSP/train_n100.py
@@ -17,8 +17,8 @@
# parameters

env_params = {
'problem_size': 100,
'pomo_size': 100,
'problem_size': 50,
'pomo_size': 50,
}

model_params = {
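The config retargets training from TSP100 to TSP50. pomo_size tracks problem_size because POMO launches one rollout per possible start node, the same convention behind the data.size(1) pattern used for env_params in the trainers above. Roughly:

```python
problem_size = 50
env_params = {
    'problem_size': problem_size,
    'pomo_size': problem_size,   # one rollout per start node in POMO
}
```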
@@ -47,13 +47,12 @@
'use_cuda': USE_CUDA,
'cuda_device_num': CUDA_DEVICE_NUM,
'seed': 1234,
'epochs': 1000, # will be overridden if meta_params['enable'] is True
'epochs': 500, # will be overridden if meta_params['enable'] is True
'train_episodes': 100000, # number of instances per epoch
'train_batch_size': 64,
'val_interval': 10,
'logging': {
'model_save_interval': 100,
'img_save_interval': 100,
'model_save_interval': 520,
'img_save_interval': 520,
'log_image_params_1': {
'json_foldername': 'log_image_style',
'filename': 'general.json'
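Since validation is now tied to checkpointing, model_save_interval also sets the validation cadence; 520 against the 52084 meta-epochs configured below gives roughly 100 validation/checkpoint points over the run, plus the final all_done save (a quick check):

```python
meta_epochs, model_save_interval = 52084, 520
print(meta_epochs // model_save_interval)   # 100 periodic checkpoints
```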
@@ -71,9 +70,9 @@
},
'meta_params': {
'enable': True, # whether use meta-learning or not
'meta_method': 'maml', # choose from ['maml', 'fomaml', 'reptile', 'ours']
'meta_method': 'fomaml', # choose from ['maml', 'fomaml', 'reptile', 'ours']
'data_type': 'distribution', # choose from ["size", "distribution", "size_distribution"]
'epochs': 104167, # the number of meta-model updates: (1000*100000) / (3*50*64)
'epochs': 52084, # the number of meta-model updates: (500*100000) / (3*5*64)
'B': 3, # the number of tasks in a mini-batch
'k': 5, # gradient descent steps in the inner-loop optimization of the meta-learning method
'meta_batch_size': 64, # the batch size of the inner-loop optimization
@@ -85,7 +84,7 @@

logger_params = {
'log_file': {
'desc': 'train_tsp_n100',
'desc': 'train_tsp_n50',
'filename': 'log.txt'
}
}