From 253156b26a57340a0b3d29ccc4ac1f57849c2d4a Mon Sep 17 00:00:00 2001 From: RoyalSkye Date: Thu, 8 Sep 2022 15:55:45 +0800 Subject: [PATCH] add maml and fomaml --- .gitignore | 3 +- POMO/TSP/TSPModel.py | 22 +++---- POMO/TSP/TSPTester.py | 2 +- POMO/TSP/TSPTrainer.py | 30 +++++---- POMO/TSP/TSPTrainer_Meta.py | 128 +++++++++++++++++++++++------------- POMO/TSP/train_n100.py | 8 +-- 6 files changed, 113 insertions(+), 80 deletions(-) diff --git a/.gitignore b/.gitignore index 31a28a7..a161774 100644 --- a/.gitignore +++ b/.gitignore @@ -12,10 +12,9 @@ __pycache__/ .idea/ # data & pretrain-model -AM/ -result/ backup/ data/ +pretrained/ # private files utils_plot* diff --git a/POMO/TSP/TSPModel.py b/POMO/TSP/TSPModel.py index 5a0e70f..98c269b 100644 --- a/POMO/TSP/TSPModel.py +++ b/POMO/TSP/TSPModel.py @@ -38,23 +38,19 @@ def forward(self, state): probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask) # shape: (batch, pomo, problem) - if self.training or self.model_params['eval_type'] == 'softmax': - while True: - selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \ - .squeeze(dim=1).reshape(batch_size, pomo_size) + while True: + if self.training or self.model_params['eval_type'] == 'softmax': + selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size) # shape: (batch, pomo) - - prob = probs[state.BATCH_IDX, state.POMO_IDX, selected] \ - .reshape(batch_size, pomo_size) + else: + selected = probs.argmax(dim=2) # shape: (batch, pomo) - if (prob != 0).all(): - break - - else: - selected = probs.argmax(dim=2) + prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size) # shape: (batch, pomo) - prob = None + + if (prob != 0).all(): + break return selected, prob diff --git a/POMO/TSP/TSPTester.py b/POMO/TSP/TSPTester.py index d0e6f8d..6db87d9 100644 --- a/POMO/TSP/TSPTester.py +++ b/POMO/TSP/TSPTester.py @@ -203,6 +203,6 @@ def _fine_tune_one_batch(self, fine_tune_data): score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value # Step & Return - self.model.zero_grad() + self.optimizer.zero_grad() loss_mean.backward() self.optimizer.step() diff --git a/POMO/TSP/TSPTrainer.py b/POMO/TSP/TSPTrainer.py index f4e9773..c2d999a 100644 --- a/POMO/TSP/TSPTrainer.py +++ b/POMO/TSP/TSPTrainer.py @@ -79,7 +79,9 @@ def run(self): # Val if epoch % self.trainer_params['val_interval'] == 0: - self._fast_val() + val_episodes = 1000 + no_aug_score = self._fast_val(self.model, val_episodes=val_episodes) + print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes)) # Logs & Checkpoint elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs']) @@ -148,8 +150,8 @@ def _train_one_epoch(self, epoch): raise NotImplementedError env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)} avg_score, avg_loss = self._train_one_batch(data, Env(**env_params)) - score_AM.update(avg_score, batch_size) - loss_AM.update(avg_loss, batch_size) + score_AM.update(avg_score.item(), batch_size) + loss_AM.update(avg_loss.item(), batch_size) episode += batch_size @@ -200,37 +202,37 @@ def _train_one_batch(self, data, env): score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value # Step & Return - self.model.zero_grad() + self.optimizer.zero_grad() loss_mean.backward() self.optimizer.step() - return score_mean.item(), loss_mean.item() - def _fast_val(self): - val_path = "../../data/TSP/tsp100_tsplib.pkl" - val_episodes = 5000 + return score_mean, loss_mean + + def _fast_val(self, model, data=None, val_episodes=1000): aug_factor = 1 - data = torch.Tensor(load_dataset(val_path)[: val_episodes]) + if data is None: + val_path = "../../data/TSP/tsp100_tsplib.pkl" + data = torch.Tensor(load_dataset(val_path)[: val_episodes]) env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)}) - self.model.eval() + model.eval() batch_size = data.size(0) with torch.no_grad(): env.load_problems(batch_size, problems=data, aug_factor=aug_factor) reset_state, _, _ = env.reset() - self.model.pre_forward(reset_state) + model.pre_forward(reset_state) state, reward, done = env.pre_step() while not done: - selected, _ = self.model(state) + selected, _ = model(state) # shape: (batch, pomo) state, reward, done = env.step(selected) # Return aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size) # shape: (augmentation, batch, pomo) - max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo # shape: (augmentation, batch) no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value - print(">> validation results: {}".format(no_aug_score.item())) + return no_aug_score.item() diff --git a/POMO/TSP/TSPTrainer_Meta.py b/POMO/TSP/TSPTrainer_Meta.py index 185c849..0adfae3 100644 --- a/POMO/TSP/TSPTrainer_Meta.py +++ b/POMO/TSP/TSPTrainer_Meta.py @@ -3,6 +3,7 @@ import random import torch from logging import getLogger +from collections import OrderedDict from TSPEnv import TSPEnv as Env from TSPModel import TSPModel as Model @@ -17,7 +18,9 @@ class TSPTrainer: """ - Implementation of POMO with Reptile. + Implementation of POMO with MAML / FOMAML / Reptile. + For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks"; + For Reptile, ref to "On First-Order Meta-Learning Algorithms". """ def __init__(self, env_params, @@ -50,7 +53,8 @@ def __init__(self, # Main Components self.meta_model = Model(**self.model_params) - self.alpha = self.meta_params['alpha'] + self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer']) + self.alpha = self.meta_params['alpha'] # for reptile self.task_set = generate_task_set(self.meta_params) assert self.trainer_params['meta_params']['epochs'] == math.ceil((1000 * 100000) / ( self.trainer_params['meta_params']['B'] * self.trainer_params['meta_params']['k'] * @@ -79,20 +83,15 @@ def run(self): self.logger.info('=================================================================') # Train - if self.meta_params['meta_method'] == 'maml': - pass - elif self.meta_params['meta_method'] == 'fomaml': - pass - elif self.meta_params['meta_method'] == 'reptile': - train_score, train_loss = self._train_one_epoch(epoch) - else: - raise NotImplementedError + train_score, train_loss = self._train_one_epoch(epoch) self.result_log.append('train_score', epoch, train_score) self.result_log.append('train_loss', epoch, train_loss) # Val if epoch % self.trainer_params['val_interval'] == 0: - self._fast_val() + val_episodes = 1000 + no_aug_score = self._fast_val(self.meta_model, val_episodes=val_episodes) + print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes)) # Logs & Checkpoint elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs']) @@ -137,23 +136,23 @@ def run(self): def _train_one_epoch(self, epoch): """ 1. Sample B training tasks from task distribution P(T) - 2. for a batch of tasks T_i, do reptile -> \theta_i - 3. update meta-model -> \theta_0 + 2. inner-loop: for a batch of tasks T_i, do reptile -> \theta_i + 3. outer-loop: update meta-model -> \theta_0 """ score_AM = AverageMeter() loss_AM = AverageMeter() - batch_size = self.meta_params['meta_batch_size'] + self._alpha_scheduler(epoch) slow_weights = copy.deepcopy(self.meta_model.state_dict()) - fast_weights = [] + fast_weights, val_loss, fomaml_grad = [], 0, [] + # sample a batch of tasks for i in range(self.meta_params['B']): - task_params = random.sample(self.task_set, 1)[0] # uniformly sample a task + task_params = random.sample(self.task_set, 1)[0] task_model = copy.deepcopy(self.meta_model) - optimizer = Optimizer(task_model.parameters(), **self.optimizer_params['optimizer']) - for step in range(self.meta_params['k']): + for step in range(self.meta_params['k'] + 1): # generate task-specific data if self.meta_params['data_type'] == 'distribution': assert len(task_params) == 2 @@ -167,23 +166,49 @@ def _train_one_epoch(self, epoch): else: raise NotImplementedError + if step == self.meta_params['k']: continue env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)} - avg_score, avg_loss = self._train_one_batch(task_model, data, optimizer, Env(**env_params)) - score_AM.update(avg_score, batch_size) - loss_AM.update(avg_loss, batch_size) + avg_score, avg_loss = self._train_one_batch(task_model, data, Env(**env_params)) + score_AM.update(avg_score.item(), batch_size) + loss_AM.update(avg_loss.item(), batch_size) - fast_weights.append(task_model.state_dict()) - - state_dict = {params_key: (slow_weights[params_key] + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - slow_weights[params_key] for fast_weight in fast_weights], dim=0), dim=0)) - for params_key in slow_weights} - self.meta_model.load_state_dict(state_dict) + if self.meta_params['meta_method'] == 'maml': + # cal loss on query(val) set - data + # val_loss += self._fast_val(task_model, val_episodes=64) + val_loss += self._fast_val(task_model, data=data) + elif self.meta_params['meta_method'] == 'fomaml': + loss, _ = self._fast_val(task_model, data=data) + val_loss += loss + task_model.train() + grad = torch.autograd.grad(val_loss, task_model.parameters()) + fomaml_grad.append(grad) + elif self.meta_params['meta_method'] == 'reptile': + fast_weights.append(task_model.state_dict()) + + # update meta-model + if self.meta_params['meta_method'] == 'maml': + val_loss = val_loss / self.meta_params['B'] + self.meta_optimizer.zero_grad() + val_loss.backward() + self.meta_optimizer.step() + elif self.meta_params['meta_method'] == 'fomaml': + updated_weights = self.meta_model.state_dict() + for gradients in fomaml_grad: + updated_weights = OrderedDict( + (name, param - self.optimizer_params['optimizer']['lr'] * grad) + for ((name, param), grad) in zip(updated_weights.items(), gradients) + ) + self.meta_model.load_state_dict(updated_weights) + elif self.meta_params['meta_method'] == 'reptile': + state_dict = {params_key: (slow_weights[params_key] + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - slow_weights[params_key] for fast_weight in fast_weights], dim=0), dim=0)) for params_key in slow_weights} + self.meta_model.load_state_dict(state_dict) # Log Once, for each epoch self.logger.info('Meta Iteration {:3d}: alpha: {:6f}, Score: {:.4f}, Loss: {:.4f}'.format(epoch, self.alpha, score_AM.avg, loss_AM.avg)) return score_AM.avg, loss_AM.avg - def _train_one_batch(self, task_model, data, optimizer, env): + def _train_one_batch(self, task_model, data, env): # Prep task_model.train() @@ -215,42 +240,53 @@ def _train_one_batch(self, task_model, data, optimizer, env): max_pomo_reward, _ = reward.max(dim=1) # get best results from pomo score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value - # Step & Return - task_model.zero_grad() - loss_mean.backward() - optimizer.step() + # update model + create_graph = True if self.meta_params['meta_method'] == 'maml' else False + gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph) + fast_weights = OrderedDict( + (name, param - self.optimizer_params['optimizer']['lr'] * grad) + for ((name, param), grad) in zip(task_model.state_dict().items(), gradients) + ) + task_model.load_state_dict(fast_weights) - return score_mean.item(), loss_mean.item() + return score_mean, loss_mean - def _fast_val(self): - val_path = "../../data/TSP/tsp100_tsplib.pkl" - val_episodes = 5000 + def _fast_val(self, model, data=None, val_episodes=1000): aug_factor = 1 - data = torch.Tensor(load_dataset(val_path)[: val_episodes]) + if data is None: + val_path = "../../data/TSP/tsp100_tsplib.pkl" + data = torch.Tensor(load_dataset(val_path)[: val_episodes]) env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)}) - self.meta_model.eval() + model.eval() batch_size = data.size(0) - with torch.no_grad(): + with torch.enable_grad(): env.load_problems(batch_size, problems=data, aug_factor=aug_factor) reset_state, _, _ = env.reset() - self.meta_model.pre_forward(reset_state) + model.pre_forward(reset_state) + prob_list = torch.zeros(size=(batch_size, env.pomo_size, 0)) - state, reward, done = env.pre_step() - while not done: - selected, _ = self.meta_model(state) - # shape: (batch, pomo) - state, reward, done = env.step(selected) + state, reward, done = env.pre_step() + while not done: + selected, prob = model(state) + # shape: (batch, pomo) + state, reward, done = env.step(selected) + prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2) + + # Loss + advantage = reward - reward.float().mean(dim=1, keepdims=True) + log_prob = prob_list.log().sum(dim=2) # for the first/last node, p=1 -> log_p=0 + loss = -advantage * log_prob # Minus Sign: To Increase REWARD + loss_mean = loss.mean() # Return aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size) # shape: (augmentation, batch, pomo) - max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo # shape: (augmentation, batch) no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value - print(">> validation results: {}".format(no_aug_score.item())) + return loss_mean, no_aug_score.item() def _alpha_scheduler(self, iter): self.alpha *= self.meta_params['alpha_decay'] diff --git a/POMO/TSP/train_n100.py b/POMO/TSP/train_n100.py index 5444b4f..01f5a52 100644 --- a/POMO/TSP/train_n100.py +++ b/POMO/TSP/train_n100.py @@ -70,12 +70,12 @@ }, 'meta_params': { - 'enable': False, # whether use meta-learning or not - 'meta_method': 'reptile', # choose from ['maml', 'reptile', 'ours'] + 'enable': True, # whether use meta-learning or not + 'meta_method': 'maml', # choose from ['maml', 'fomaml', 'reptile', 'ours'] 'data_type': 'distribution', # choose from ["size", "distribution", "size_distribution"] - 'epochs': 10417, # the number of meta-model updates: (1000*100000) / (3*50*64) + 'epochs': 104167, # the number of meta-model updates: (1000*100000) / (3*50*64) 'B': 3, # the number of tasks in a mini-batch - 'k': 50, # gradient decent steps in the inner-loop optimization of meta-learning method + 'k': 5, # gradient decent steps in the inner-loop optimization of meta-learning method 'meta_batch_size': 64, # the batch size of the inner-loop optimization 'num_task': 50, # the number of tasks in the training task set 'alpha': 0.99, # params for the outer-loop optimization of reptile