From 43cdc01d04fc1dd8bf2506de7e68b4942c84f3d7 Mon Sep 17 00:00:00 2001
From: RoyalSkye
Date: Fri, 9 Sep 2022 18:18:20 +0800
Subject: [PATCH] fix bugs

---
 POMO/TSP/TSPTrainer.py      | 18 ++++++-------
 POMO/TSP/TSPTrainer_Meta.py | 53 ++++++++++++++++++++-----------------
 POMO/TSP/TSProblemDef.py    |  2 ++
 POMO/TSP/train_n100.py      | 17 ++++++------
 4 files changed, 48 insertions(+), 42 deletions(-)

diff --git a/POMO/TSP/TSPTrainer.py b/POMO/TSP/TSPTrainer.py
index c2d999a..9c13736 100644
--- a/POMO/TSP/TSPTrainer.py
+++ b/POMO/TSP/TSPTrainer.py
@@ -1,3 +1,4 @@
+import copy
 import random
 import torch
 from logging import getLogger
@@ -77,12 +78,6 @@ def run(self):
             self.result_log.append('train_score', epoch, train_score)
             self.result_log.append('train_loss', epoch, train_loss)
 
-            # Val
-            if epoch % self.trainer_params['val_interval'] == 0:
-                val_episodes = 1000
-                no_aug_score = self._fast_val(self.model, val_episodes=val_episodes)
-                print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
-
             # Logs & Checkpoint
             elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
             self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
@@ -101,6 +96,11 @@
                                self.result_log, labels=['train_loss'])
 
             if all_done or (epoch % model_save_interval) == 0:
+                # val
+                val_episodes = 256
+                no_aug_score = self._fast_val(copy.deepcopy(self.model), val_episodes=val_episodes)
+                print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
+                # save checkpoint
                 self.logger.info("Saving trained_model")
                 checkpoint_dict = {
                     'epoch': epoch,
@@ -148,6 +148,7 @@ def _train_one_epoch(self, epoch):
                 data = get_random_problems(batch_size, problem_size=task_params[0], num_modes=task_params[1], cdist=task_params[-1], distribution='gaussian_mixture')
             else:
                 raise NotImplementedError
+            env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
 
             avg_score, avg_loss = self._train_one_batch(data, Env(**env_params))
             score_AM.update(avg_score.item(), batch_size)
@@ -171,7 +172,6 @@
 
     def _train_one_batch(self, data, env):
 
-        # Prep
         self.model.train()
         batch_size = data.size(0)
         env.load_problems(batch_size, problems=data, aug_factor=1)
@@ -208,10 +208,10 @@
 
         return score_mean, loss_mean
 
-    def _fast_val(self, model, data=None, val_episodes=1000):
+    def _fast_val(self, model, data=None, val_episodes=256):
         aug_factor = 1
         if data is None:
-            val_path = "../../data/TSP/tsp100_tsplib.pkl"
+            val_path = "../../data/TSP/tsp50_tsplib.pkl"
             data = torch.Tensor(load_dataset(val_path)[: val_episodes])
 
         env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
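Note on the validation change above: `_fast_val` is now called on `copy.deepcopy(self.model)`, so mode switches and any state created during inference never touch the live training model. A minimal sketch of that pattern, assuming only a generic `nn.Module` whose forward returns per-instance scores; the names `fast_val` and `val_data` are illustrative, not from the repository:

    import copy
    import torch

    def fast_val(model: torch.nn.Module, val_data: torch.Tensor) -> float:
        val_model = copy.deepcopy(model)  # snapshot; the training model is untouched
        val_model.eval()                  # e.g. fixes dropout/batch-norm behavior
        with torch.no_grad():             # no autograd graph during validation
            scores = val_model(val_data)  # assumed forward signature
        return scores.mean().item()
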
diff --git a/POMO/TSP/TSPTrainer_Meta.py b/POMO/TSP/TSPTrainer_Meta.py
index 0adfae3..f00ec0b 100644
--- a/POMO/TSP/TSPTrainer_Meta.py
+++ b/POMO/TSP/TSPTrainer_Meta.py
@@ -53,10 +53,9 @@ def __init__(self,
 
         # Main Components
         self.meta_model = Model(**self.model_params)
-        self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer'])
         self.alpha = self.meta_params['alpha']  # for reptile
         self.task_set = generate_task_set(self.meta_params)
-        assert self.trainer_params['meta_params']['epochs'] == math.ceil((1000 * 100000) / (
+        assert self.trainer_params['meta_params']['epochs'] == math.ceil((self.trainer_params['epochs'] * self.trainer_params['train_episodes']) / (
             self.trainer_params['meta_params']['B'] * self.trainer_params['meta_params']['k'] * self.trainer_params['meta_params']['meta_batch_size'])), ">> meta-learning iteration does not match with POMO!"
@@ -79,6 +78,7 @@
 
     def run(self):
         self.time_estimator.reset(self.start_epoch)
+        val_res = []
         for epoch in range(self.start_epoch, self.meta_params['epochs']+1):
             self.logger.info('=================================================================')
@@ -87,12 +87,6 @@
             self.result_log.append('train_score', epoch, train_score)
             self.result_log.append('train_loss', epoch, train_loss)
 
-            # Val
-            if epoch % self.trainer_params['val_interval'] == 0:
-                val_episodes = 1000
-                no_aug_score = self._fast_val(self.meta_model, val_episodes=val_episodes)
-                print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
-
             # Logs & Checkpoint
             elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
             self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}]".format(
@@ -111,6 +105,12 @@
                                self.result_log, labels=['train_loss'])
 
             if all_done or (epoch % model_save_interval) == 0:
+                # val
+                val_episodes = 256
+                _, no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=val_episodes)
+                val_res.append(no_aug_score)
+                print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))
+                # save checkpoint
                 self.logger.info("Saving trained_model")
                 checkpoint_dict = {
                     'epoch': epoch,
@@ -132,6 +132,7 @@
             self.logger.info(" *** Training Done *** ")
             self.logger.info("Now, printing log array...")
             util_print_log_array(self.logger, self.result_log)
+            print(val_res)
 
     def _train_one_epoch(self, epoch):
         """
@@ -168,17 +169,15 @@
                 if step == self.meta_params['k']:
                     continue
                 env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
-                avg_score, avg_loss = self._train_one_batch(task_model, data, Env(**env_params))
+                avg_score, avg_loss = self._train_one_batch(step, task_model, data, Env(**env_params))
                 score_AM.update(avg_score.item(), batch_size)
                 loss_AM.update(avg_loss.item(), batch_size)
 
             if self.meta_params['meta_method'] == 'maml':  # cal loss on query(val) set - data
-                # val_loss += self._fast_val(task_model, val_episodes=64)
-                val_loss += self._fast_val(task_model, data=data)
+                val_loss += self._fast_val(task_model, data=data)[0]
             elif self.meta_params['meta_method'] == 'fomaml':
-                loss, _ = self._fast_val(task_model, data=data)
-                val_loss += loss
+                val_loss = self._fast_val(task_model, data=data)[0]
                 task_model.train()
                 grad = torch.autograd.grad(val_loss, task_model.parameters())
                 fomaml_grad.append(grad)
@@ -188,14 +187,17 @@
         # update meta-model
         if self.meta_params['meta_method'] == 'maml':
             val_loss = val_loss / self.meta_params['B']
-            self.meta_optimizer.zero_grad()
-            val_loss.backward()
-            self.meta_optimizer.step()
+            gradients = torch.autograd.grad(val_loss, self.maml)
+            updated_weights = OrderedDict(
+                (name, param - self.optimizer_params['optimizer']['lr'] * grad)
+                for ((name, param), grad) in zip(self.meta_model.state_dict().items(), gradients)
+            )
+            self.meta_model.load_state_dict(updated_weights)
         elif self.meta_params['meta_method'] == 'fomaml':
             updated_weights = self.meta_model.state_dict()
             for gradients in fomaml_grad:
                 updated_weights = OrderedDict(
-                    (name, param - self.optimizer_params['optimizer']['lr'] * grad)
+                    (name, param - self.optimizer_params['optimizer']['lr'] / self.meta_params['B'] * grad)
                     for ((name, param), grad) in zip(updated_weights.items(), gradients)
                 )
             self.meta_model.load_state_dict(updated_weights)
@@ -208,9 +210,8 @@ def _train_one_epoch(self, epoch):
 
         return score_AM.avg, loss_AM.avg
 
-    def _train_one_batch(self, task_model, data, env):
+    def _train_one_batch(self, i, task_model, data, env):
 
-        # Prep
        task_model.train()
         batch_size = data.size(0)
         env.load_problems(batch_size, problems=data, aug_factor=1)
@@ -242,7 +243,11 @@
         # update model
         create_graph = True if self.meta_params['meta_method'] == 'maml' else False
-        gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph)
+        if i == 0:
+            self.maml = list(task_model.parameters())
+            gradients = torch.autograd.grad(loss_mean, self.maml, create_graph=create_graph)
+        else:
+            gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph)
         fast_weights = OrderedDict(
             (name, param - self.optimizer_params['optimizer']['lr'] * grad)
             for ((name, param), grad) in zip(task_model.state_dict().items(), gradients)
         )
@@ -251,10 +256,10 @@
 
         return score_mean, loss_mean
 
-    def _fast_val(self, model, data=None, val_episodes=1000):
+    def _fast_val(self, model, data=None, val_episodes=256):
         aug_factor = 1
         if data is None:
-            val_path = "../../data/TSP/tsp100_tsplib.pkl"
+            val_path = "../../data/TSP/tsp50_tsplib.pkl"
             data = torch.Tensor(load_dataset(val_path)[: val_episodes])
 
         env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
@@ -286,7 +291,7 @@
         # shape: (augmentation, batch)
         no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value
 
-        return loss_mean, no_aug_score.item()
+        return loss_mean, no_aug_score.detach().item()
 
     def _alpha_scheduler(self, iter):
-        self.alpha *= self.meta_params['alpha_decay']
+        self.alpha = max(self.alpha * self.meta_params['alpha_decay'], 0.0001)
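For reference, the outer-loop update that TSPTrainer_Meta.py now performs for 'fomaml' amounts to: collect each task's query-set gradient at the inner-adapted weights, then apply the gradients to the meta-weights with step size lr / B, which equals one SGD step on the averaged gradient. A minimal sketch under simplified assumptions: `inner_adapt` and `query_loss` are hypothetical helpers, and the model is assumed to have no buffers so that `state_dict()` aligns with `parameters()`:

    import copy
    from collections import OrderedDict
    import torch

    def fomaml_outer_step(meta_model, tasks, inner_adapt, query_loss, lr, B):
        task_grads = []
        for task in tasks:                          # B tasks per meta-batch
            task_model = copy.deepcopy(meta_model)  # inner loop starts from the meta-weights
            inner_adapt(task_model, task)           # k inner gradient steps (assumed helper)
            loss = query_loss(task_model, task)     # loss on the task's query set
            task_grads.append(torch.autograd.grad(loss, task_model.parameters()))
        weights = meta_model.state_dict()
        for grads in task_grads:                    # sequential lr/B steps == one averaged-gradient step
            weights = OrderedDict(
                (name, param - lr / B * grad)
                for (name, param), grad in zip(weights.items(), grads)
            )
        meta_model.load_state_dict(weights)
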
diff --git a/POMO/TSP/TSProblemDef.py b/POMO/TSP/TSProblemDef.py
index 02b8769..43218c5 100644
--- a/POMO/TSP/TSProblemDef.py
+++ b/POMO/TSP/TSProblemDef.py
@@ -225,6 +225,8 @@ def generate_tsp_dist(n_samples, n_nodes, distribution):
     test seed: 2023
     """
     path = "../../data/TSP"
+    if not os.path.exists(path):
+        os.makedirs(path)
     seed_everything(seed=2023)
 
     for dist in ["uniform", "uniform_rectangle", "gaussian", "cluster", "diagonal", "tsplib"]:
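Side note on the directory guard above: since Python 3.2, `os.makedirs` accepts `exist_ok=True`, which folds the existence check into a single call and is also robust if the directory appears between the check and the creation:

    import os

    os.makedirs("../../data/TSP", exist_ok=True)  # no error if the directory already exists
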
diff --git a/POMO/TSP/train_n100.py b/POMO/TSP/train_n100.py
index 01f5a52..6a6bbc0 100644
--- a/POMO/TSP/train_n100.py
+++ b/POMO/TSP/train_n100.py
@@ -17,8 +17,8 @@
 
 # parameters
 env_params = {
-    'problem_size': 100,
-    'pomo_size': 100,
+    'problem_size': 50,
+    'pomo_size': 50,
 }
 
 model_params = {
@@ -47,13 +47,12 @@
     'use_cuda': USE_CUDA,
     'cuda_device_num': CUDA_DEVICE_NUM,
     'seed': 1234,
-    'epochs': 1000,  # will be overridden if meta_params['enable'] is True
+    'epochs': 500,  # will be overridden if meta_params['enable'] is True
     'train_episodes': 100000,  # number of instances per epoch
     'train_batch_size': 64,
-    'val_interval': 10,
     'logging': {
-        'model_save_interval': 100,
-        'img_save_interval': 100,
+        'model_save_interval': 520,
+        'img_save_interval': 520,
         'log_image_params_1': {
             'json_foldername': 'log_image_style',
             'filename': 'general.json'
@@ -71,9 +70,9 @@
     },
     'meta_params': {
         'enable': True,  # whether use meta-learning or not
-        'meta_method': 'maml',  # choose from ['maml', 'fomaml', 'reptile', 'ours']
+        'meta_method': 'fomaml',  # choose from ['maml', 'fomaml', 'reptile', 'ours']
         'data_type': 'distribution',  # choose from ["size", "distribution", "size_distribution"]
-        'epochs': 104167,  # the number of meta-model updates: (1000*100000) / (3*50*64)
+        'epochs': 52084,  # the number of meta-model updates: (500*100000) / (3*5*64)
         'B': 3,  # the number of tasks in a mini-batch
         'k': 5,  # gradient decent steps in the inner-loop optimization of meta-learning method
         'meta_batch_size': 64,  # the batch size of the inner-loop optimization
@@ -85,7 +84,7 @@
 
 logger_params = {
     'log_file': {
-        'desc': 'train_tsp_n100',
+        'desc': 'train_tsp_n50',
         'filename': 'log.txt'
     }
 }
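As a sanity check of the updated `meta_params['epochs']` comment above: each meta-iteration consumes B * k * meta_batch_size training instances, so matching POMO's budget of epochs * train_episodes instances yields exactly the figure asserted in TSPTrainer_Meta.py:

    import math

    epochs, train_episodes = 500, 100000
    B, k, meta_batch_size = 3, 5, 64
    meta_epochs = math.ceil((epochs * train_episodes) / (B * k * meta_batch_size))
    assert meta_epochs == 52084, meta_epochs  # ceil(50,000,000 / 960)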