Commit: update code
RoyalSkye committed Oct 7, 2022
1 parent b136d7b commit 13c04ec
Showing 6 changed files with 263 additions and 96 deletions.
30 changes: 20 additions & 10 deletions POMO/TSP/TSPModel.py
@@ -307,24 +307,34 @@ class Add_And_Normalization_Module(nn.Module):
def __init__(self, **model_params):
super().__init__()
embedding_dim = model_params['embedding_dim']
# self.norm = nn.BatchNorm1d(embedding_dim, affine=True, track_running_stats=True)
self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

def forward(self, input1, input2, weights=None):
if weights is None:
# input.shape: (batch, problem, embedding)
added = input1 + input2
# shape: (batch, problem, embedding)
transposed = added.transpose(1, 2)
# shape: (batch, embedding, problem)
normalized = self.norm(transposed)
# shape: (batch, embedding, problem)
back_trans = normalized.transpose(1, 2)
# shape: (batch, problem, embedding)
if isinstance(self.norm, nn.InstanceNorm1d):
transposed = added.transpose(1, 2)
# shape: (batch, embedding, problem)
normalized = self.norm(transposed)
# shape: (batch, embedding, problem)
back_trans = normalized.transpose(1, 2)
# shape: (batch, problem, embedding)
elif isinstance(self.norm, nn.BatchNorm1d):
batch, problem, embedding = added.size()
normalized = self.norm(added.reshape(-1, embedding))
back_trans = normalized.reshape(batch, problem, embedding)
else:
added = input1 + input2
transposed = added.transpose(1, 2)
normalized = F.instance_norm(transposed, weight=weights['weight'], bias=weights['bias'])
back_trans = normalized.transpose(1, 2)
if isinstance(self.norm, nn.InstanceNorm1d):
transposed = added.transpose(1, 2)
normalized = F.instance_norm(transposed, weight=weights['weight'], bias=weights['bias'])
back_trans = normalized.transpose(1, 2)
elif isinstance(self.norm, nn.BatchNorm1d):
batch, problem, embedding = added.size()
normalized = F.batch_norm(added.reshape(-1, embedding), running_mean=self.norm.running_mean, running_var=self.norm.running_var, weight=weights['weight'], bias=weights['bias'], training=True)
back_trans = normalized.reshape(batch, problem, embedding)

return back_trans
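
Note: the isinstance branch above exists because the two norm layers expect different input layouts. A minimal standalone sketch of those two conventions (sizes and variable names here are illustrative, not from the repo):

import torch
import torch.nn as nn

batch, problem, embedding = 2, 5, 8
added = torch.randn(batch, problem, embedding)  # input1 + input2

# nn.InstanceNorm1d expects (N, C, L): embedding as channels, nodes as
# the length axis, hence the transpose before and after the norm.
inorm = nn.InstanceNorm1d(embedding, affine=True, track_running_stats=False)
out_in = inorm(added.transpose(1, 2)).transpose(1, 2)

# nn.BatchNorm1d on 2-D input expects (N, C): flatten batch and problem
# so each embedding channel is normalized across all nodes in the batch.
bnorm = nn.BatchNorm1d(embedding, affine=True, track_running_stats=True)
out_bn = bnorm(added.reshape(-1, embedding)).reshape(batch, problem, embedding)

assert out_in.shape == out_bn.shape == (batch, problem, embedding)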

160 changes: 116 additions & 44 deletions POMO/TSP/TSPTrainer_Meta.py
@@ -1,3 +1,4 @@
import os
import copy
import math
import time
@@ -11,16 +12,16 @@

from torch.optim import Adam as Optimizer
# from torch.optim import SGD as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler
# from torch.optim.lr_scheduler import MultiStepLR as Scheduler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts as Scheduler
from TSProblemDef import get_random_problems, generate_task_set

from utils.utils import *
from utils.functions import load_dataset
from utils.functions import *


class TSPTrainer:
"""
TODO: 1. val data vs. training data: for the k inner-loop steps, should we use the same batch of data?
Implementation of POMO with MAML / FOMAML / Reptile.
For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks";
For Reptile, ref to "On First-Order Meta-Learning Algorithms".
@@ -59,8 +60,10 @@ def __init__(self,
# Main Components
self.meta_model = Model(**self.model_params)
self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer'])
# self.scheduler = Scheduler(self.meta_optimizer, **self.optimizer_params['scheduler'])
self.alpha = self.meta_params['alpha'] # for reptile
self.task_set = generate_task_set(self.meta_params)
self.ema_est = {i[0]: 1 for i in self.task_set}

# Restore
self.start_epoch = 1
@@ -71,7 +74,7 @@ def __init__(self,
self.meta_model.load_state_dict(checkpoint['model_state_dict'])
self.start_epoch = 1 + model_load['epoch']
self.result_log.set_raw_data(checkpoint['result_log'])
# self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# self.scheduler.last_epoch = model_load['epoch']-1
self.logger.info('Saved Model Loaded !!')

@@ -87,19 +90,24 @@ def run(self):

# Train
train_score, train_loss = self._train_one_epoch(epoch)
# self.scheduler.step()
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)
# Val
if self.meta_params['meta_method'] in ['fomaml', 'reptile']:
no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=32, mode="eval")
else:
no_aug_score = self._fast_val(self.meta_model, val_episodes=32, mode="eval")
self.result_log.append('val_score', epoch, no_aug_score)
dir, no_aug_score_list = "../../data/TSP/", []
# for val_path in ["tsp50_uniform.pkl", "tsp100_uniform.pkl", "tsp100_diagonal.pkl", "tsp150_uniform.pkl", "tsp200_uniform.pkl"]:
for val_path in ["tsp50_uniform.pkl", "tsp100_uniform.pkl", "tsp150_uniform.pkl", "tsp200_uniform.pkl"]:
if self.meta_params['meta_method'] in ['fomaml', 'reptile']:
no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), path=os.path.join(dir, val_path), val_episodes=64, mode="eval")
else:
no_aug_score = self._fast_val(self.meta_model, path=os.path.join(dir, val_path), val_episodes=64, mode="eval")
no_aug_score_list.append(round(no_aug_score, 4))
self.result_log.append('val_score', epoch, no_aug_score_list)

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}], Val Score: {:.4f}".format(
epoch, self.meta_params['epochs'], epoch/self.meta_params['epochs']*100, elapsed_time_str, remain_time_str, no_aug_score))
self.logger.info("Epoch {:3d}/{:3d}({:.2f}%): Time Est.: Elapsed[{}], Remain[{}], Val Score: {}".format(
epoch, self.meta_params['epochs'], epoch/self.meta_params['epochs']*100, elapsed_time_str, remain_time_str, no_aug_score_list))

if self.trainer_params['stop_criterion'] == "epochs":
all_done = (epoch == self.meta_params['epochs'])
@@ -121,7 +129,7 @@ def run(self):
checkpoint_dict = {
'epoch': epoch,
'model_state_dict': self.meta_model.state_dict(),
# 'optimizer_state_dict': self.optimizer.state_dict(),
'optimizer_state_dict': self.meta_optimizer.state_dict(),
# 'scheduler_state_dict': self.scheduler.state_dict(),
'result_log': self.result_log.get_raw_data()
}
@@ -135,16 +143,16 @@ def run(self):

if all_done:
self.logger.info(" *** Training Done *** ")
self.logger.info("Now, printing log array...")
util_print_log_array(self.logger, self.result_log)
# self.logger.info("Now, printing log array...")
# util_print_log_array(self.logger, self.result_log)

def _train_one_epoch(self, epoch):
"""
1. Sample a batch of B training tasks from the task distribution P(T)
2. Inner loop: for each task T_i, run k adaptation steps -> \theta_i
3. Outer loop: update the meta-model -> \theta_0
"""
self.meta_model.train()
self.meta_optimizer.zero_grad()
score_AM = AverageMeter()
loss_AM = AverageMeter()
batch_size = self.meta_params['meta_batch_size']
@@ -165,36 +173,44 @@ def _train_one_epoch(self, epoch):
fast_weight = OrderedDict(self.meta_model.decoder.named_parameters())
for k in list(fast_weight.keys()):
fast_weight["decoder."+k] = fast_weight.pop(k)
optimizer = Optimizer(fast_weight.values(), **self.optimizer_params['optimizer'])
optimizer.load_state_dict(self.meta_optimizer.state_dict())

for step in range(self.meta_params['k'] + 1):
# generate task-specific data
# inner-loop optimization
for step in range(self.meta_params['k']):
data = self._get_data(batch_size, task_params)
if step == self.meta_params['k']: continue
data = self._generate_x_adv(data, eps=random.randint(10, 100)) if self.trainer_params['adv_train'] else data
env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}

self.meta_model.train()
if self.meta_params['meta_method'] in ['reptile', 'fomaml']:
avg_score, avg_loss = self._train_one_batch(task_model, data, Env(**env_params), optimizer)
elif self.meta_params['meta_method'] == 'maml':
avg_score, avg_loss, fast_weight = self._train_one_batch_maml(fast_weight, data, Env(**env_params))

avg_score, avg_loss, fast_weight = self._train_one_batch_maml(fast_weight, data, Env(**env_params), optimizer)
score_AM.update(avg_score.item(), batch_size)
loss_AM.update(avg_loss.item(), batch_size)

val_data = self._get_val_data(self.meta_params['val_batch_size'], task_params)
if self.meta_params['meta_method'] == 'maml':
# cal loss on query(val) set - data
val_loss += self._fast_val(fast_weight, data=data, mode="maml")
val_loss = self._fast_val(fast_weight, data=val_data, mode="maml")
val_loss /= self.meta_params['B']
val_loss.backward()
elif self.meta_params['meta_method'] == 'fomaml':
val_loss = self._fast_val(task_model, data=data, mode="fomaml")
val_loss = self._fast_val(task_model, data=val_data, mode="fomaml")
grad = torch.autograd.grad(val_loss, task_model.parameters())
fomaml_grad.append(grad)
self.meta_optimizer.load_state_dict(optimizer.state_dict())
elif self.meta_params['meta_method'] == 'reptile':
fast_weights.append(task_model.state_dict())

# update meta-model
# outer-loop optimization (update meta-model)
if self.meta_params['meta_method'] == 'maml':
val_loss /= self.meta_params['B']
self.meta_optimizer.zero_grad()
val_loss.backward()
# val_loss /= self.meta_params['B']
# self.meta_optimizer.zero_grad()
# val_loss.backward()
# print(self.meta_model.encoder.embedding.weight.grad.norm(p=2).cpu().item())
# print(self.meta_model.decoder.multi_head_combine.weight.grad.norm(p=2).cpu().item())
# grad_norms = clip_grad_norms(self.meta_optimizer.param_groups, max_norm=1.0)
# print(grad_norms[0])
self.meta_optimizer.step()
elif self.meta_params['meta_method'] == 'fomaml':
updated_weights = self.meta_model.state_dict()
@@ -213,7 +229,7 @@ def _train_one_epoch(self, epoch):

return score_AM.avg, loss_AM.avg
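
For reference, the reptile branch of the outer-loop update (partially hidden behind the fold above) collects one adapted state_dict per task and then interpolates the meta-parameters toward them. A minimal sketch of the standard Reptile rule this reduces to, assuming fast_weights holds those per-task state_dicts (names illustrative, not the repo's exact code):

import torch

def reptile_outer_step(meta_model, fast_weights, alpha):
    # theta <- theta + alpha * mean_i(theta_i - theta)
    meta_state = meta_model.state_dict()
    for name, param in meta_state.items():
        diff = torch.stack([fw[name] - param for fw in fast_weights]).mean(dim=0)
        meta_state[name] = param + alpha * diff
    meta_model.load_state_dict(meta_state)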

def _train_one_batch(self, task_model, data, env, optimizer):
def _train_one_batch(self, task_model, data, env, optimizer=None):

task_model.train()
batch_size = data.size(0)
@@ -231,7 +247,9 @@ def _train_one_batch(self, task_model, data, env, optimizer):
state, reward, done = env.step(selected)
prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

# Loss
# Loss & adjust reward
# self.ema_est[data.size(1)] = 0.99 * self.ema_est[data.size(1)] + (1 - 0.99) * (-reward.float().mean().item()) if self.ema_est[data.size(1)] != 1 else -reward.float().mean().item()
# reward = reward / self.ema_est[data.size(1)]
advantage = reward - reward.float().mean(dim=1, keepdims=True)
# shape: (batch, pomo)
log_prob = prob_list.log().sum(dim=2) # for the first/last node, p=1 -> log_p=0
@@ -252,7 +270,7 @@ def _train_one_batch(self, task_model, data, env, optimizer):

return score_mean, loss_mean
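
The commented-out ema_est lines in the loss block above (and in the MAML and val variants below) would keep a per-problem-size exponential moving average of the mean tour length and rescale rewards by it, so advantages stay comparable across sizes. A minimal sketch of that disabled logic in standalone form (names taken from the diff):

def ema_normalize(reward, ema_est, size, decay=0.99):
    # reward: (batch, pomo) negative tour lengths; ema_est maps a
    # problem size to a running estimate of the mean tour length,
    # with 1 as the "not yet initialized" sentinel (see self.ema_est).
    mean_len = -reward.float().mean().item()
    ema_est[size] = mean_len if ema_est[size] == 1 else decay * ema_est[size] + (1 - decay) * mean_len
    return reward / ema_est[size]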

def _train_one_batch_maml(self, fast_weight, data, env):
def _train_one_batch_maml(self, fast_weight, data, env, optimizer=None):

batch_size = data.size(0)
env.load_problems(batch_size, problems=data, aug_factor=1)
@@ -269,21 +287,25 @@ def _train_one_batch_maml(self, fast_weight, data, env):
state, reward, done = env.step(selected)
prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

# Loss
# Loss & adjust reward
# self.ema_est[data.size(1)] = 0.99 * self.ema_est[data.size(1)] + (1 - 0.99) * (-reward.float().mean().item()) if self.ema_est[data.size(1)] != 1 else -reward.float().mean().item()
# print(self.ema_est)
# reward = reward / self.ema_est[data.size(1)]
advantage = reward - reward.float().mean(dim=1, keepdims=True)
# shape: (batch, pomo)
log_prob = prob_list.log().sum(dim=2) # for the first/last node, p=1 -> log_p=0
# size = (batch, pomo)
loss = -advantage * log_prob # Minus Sign: To Increase REWARD
# shape: (batch, pomo)
loss_mean = loss.mean()

# update model
gradients = torch.autograd.grad(loss_mean, fast_weight.values(), create_graph=True)
fast_weight = OrderedDict(
(name, param - self.optimizer_params['optimizer']['lr'] * grad)
for ((name, param), grad) in zip(fast_weight.items(), gradients)
)
# gradients = torch.autograd.grad(loss_mean, fast_weight.values(), create_graph=True) # allow_unused=True
# fast_weight = OrderedDict(
# (name, param - self.optimizer_params['optimizer']['lr'] * grad)
# for ((name, param), grad) in zip(fast_weight.items(), gradients)
# )
optimizer.zero_grad()
loss_mean.backward(retain_graph=True, create_graph=True)
optimizer.step()

# Score
max_pomo_reward, _ = reward.max(dim=1) # get best results from pomo
@@ -292,12 +314,9 @@ def _train_one_batch_maml(self, fast_weight, data, env):

return score_mean, loss_mean, fast_weight
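
The commented-out gradient block above is the classic functional MAML inner step, which this commit replaces with an optimizer-driven update (backward with create_graph=True plus optimizer.step()). For comparison, a minimal sketch of the functional pattern (illustrative, not the repo's exact helper):

from collections import OrderedDict
import torch

def maml_inner_step(loss, fast_weight, lr):
    # Differentiable SGD step: create_graph=True keeps this update in
    # the autograd graph so the outer-loop (query-set) loss can
    # backpropagate through it.
    grads = torch.autograd.grad(loss, fast_weight.values(), create_graph=True)
    return OrderedDict(
        (name, param - lr * grad)
        for (name, param), grad in zip(fast_weight.items(), grads)
    )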

def _fast_val(self, model, data=None, val_episodes=32, mode="eval"):

def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval"):
aug_factor = 1
if data is None:
val_path = "../../data/TSP/tsp150_uniform.pkl"
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
data = torch.Tensor(load_dataset(path)[: val_episodes]) if data is None else data
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})

batch_size = data.size(0)
@@ -335,6 +354,8 @@ def _fast_val(self, model, data=None, val_episodes=32, mode="eval"):
bootstrap_reward = self._bootstrap(fast_weight, data)
advantage = reward - bootstrap_reward
else:
# self.ema_est[data.size(1)] = 0.99 * self.ema_est[data.size(1)] + (1 - 0.99) * (-reward.float().mean().item()) if self.ema_est[data.size(1)] != 1 else -reward.float().mean().item()
# reward = reward / self.ema_est[data.size(1)]
advantage = reward - reward.float().mean(dim=1, keepdims=True)
log_prob = prob_list.log().sum(dim=2) # for the first/last node, p=1 -> log_p=0
loss = -advantage * log_prob # Minus Sign: To Increase REWARD
@@ -410,8 +431,59 @@ def _get_data(self, batch_size, task_params):

return data

def _get_val_data(self, batch_size, task_params):
val_data = self._get_data(batch_size, task_params)
# val_path = "../../data/TSP/tsp150_uniform.pkl"
# val_data = torch.Tensor(load_dataset(val_path)[: batch_size])

return val_data

def _alpha_scheduler(self, iter):
"""
Decay alpha, the Reptile outer-loop step size (floored at 1e-4).
"""
self.alpha = max(self.alpha * self.meta_params['alpha_decay'], 0.0001)

def _generate_x_adv(self, data, eps=10.0):
"""
Generate adversarial data based on the current model; optimal solutions for x_adv would still need to be generated (see the commented-out Gurobi call below).
"""
from torch.autograd import Variable
def minmax(xy_):
# min_max normalization: [b,n,2]
xy_ = (xy_ - xy_.min(dim=1, keepdims=True)[0]) / (xy_.max(dim=1, keepdims=True)[0] - xy_.min(dim=1, keepdims=True)[0])
return xy_

# generate x_adv
print(">> Warning! Generating x_adv!")
self.meta_model.eval()
aug_factor, batch_size = 1, data.size(0)
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
with torch.enable_grad():
data.requires_grad_()
env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
reset_state, _, _ = env.reset()
self.meta_model.pre_forward(reset_state)
prob_list = torch.zeros(size=(aug_factor * batch_size, env.pomo_size, 0))
state, reward, done = env.pre_step()
while not done:
selected, prob = self.meta_model(state)
state, reward, done = env.step(selected)
prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size).permute(1, 0, 2).view(batch_size, -1)
baseline_reward = aug_reward.float().mean(dim=1, keepdims=True)
advantage = aug_reward - baseline_reward
log_prob = prob_list.log().sum(dim=2).reshape(aug_factor, batch_size, env.pomo_size).permute(1, 0, 2).view(batch_size, -1)

# delta = torch.autograd.grad(eps * ((advantage / baseline_reward) * log_prob).mean(), data)[0]
delta = torch.autograd.grad(eps * ((-advantage) * log_prob).mean(), data)[0]
data = data.detach() + delta
data = minmax(data)
data = Variable(data, requires_grad=False)

# generate opt sol
# opt_sol = solve_all_gurobi(data)
# return data, opt_sol

return data
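
Taken on its own, _generate_x_adv performs one FGSM-style ascent step on the node coordinates, using the REINFORCE surrogate as the attack loss, then min-max renormalizes each instance back to the unit square. A minimal sketch of that pattern with a generic differentiable loss (hypothetical names, not the repo's API):

import torch

def fgsm_like_step(coords, loss_fn, eps):
    # coords: (batch, n, 2) node coordinates in [0, 1]^2
    coords = coords.clone().detach().requires_grad_(True)
    delta = torch.autograd.grad(eps * loss_fn(coords), coords)[0]
    adv = (coords + delta).detach()
    lo = adv.min(dim=1, keepdim=True)[0]
    hi = adv.max(dim=1, keepdim=True)[0]
    return (adv - lo) / (hi - lo)  # rescale each instance to [0, 1]^2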
