add maml and fomaml
RoyalSkye committed Sep 8, 2022
1 parent 755887c commit 253156b
Showing 6 changed files with 113 additions and 80 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -12,10 +12,9 @@ __pycache__/
.idea/

# data & pretrain-model
AM/
result/
backup/
data/
pretrained/

# private files
utils_plot*
22 changes: 9 additions & 13 deletions POMO/TSP/TSPModel.py
@@ -38,23 +38,19 @@ def forward(self, state):
probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask)
# shape: (batch, pomo, problem)

if self.training or self.model_params['eval_type'] == 'softmax':
while True:
selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
.squeeze(dim=1).reshape(batch_size, pomo_size)
while True:
if self.training or self.model_params['eval_type'] == 'softmax':
selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)
# shape: (batch, pomo)

prob = probs[state.BATCH_IDX, state.POMO_IDX, selected] \
.reshape(batch_size, pomo_size)
else:
selected = probs.argmax(dim=2)
# shape: (batch, pomo)

if (prob != 0).all():
break

else:
selected = probs.argmax(dim=2)
prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
# shape: (batch, pomo)
prob = None

if (prob != 0).all():
break

return selected, prob

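The TSPModel change above moves the resampling loop outside the softmax branch: a single while True now wraps both sampling and greedy decoding, and the greedy path returns prob = None. Below is a minimal standalone sketch of that control flow; the select_nodes helper, its eval_type argument, and the gather-based indexing (in place of BATCH_IDX/POMO_IDX) are illustrative assumptions, not the repository's API.

    import torch

    def select_nodes(probs: torch.Tensor, training: bool, eval_type: str = "softmax"):
        # Illustrative helper, not part of the repository.
        # probs: (batch, pomo, problem) -- per-node selection probabilities from the decoder.
        batch_size, pomo_size, _ = probs.size()
        while True:
            if training or eval_type == "softmax":
                # Sample one node per (batch, pomo) rollout.
                selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                    .squeeze(dim=1).reshape(batch_size, pomo_size)
                # Probability of each sampled node, needed for the REINFORCE loss.
                prob = probs.gather(2, selected.unsqueeze(2)).squeeze(2)
                # Resample if any probability underflowed to exactly zero
                # (its log would be -inf in the loss).
                if (prob != 0).all():
                    return selected, prob
            else:
                # Greedy decoding: probabilities are not needed downstream.
                return probs.argmax(dim=2), None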
2 changes: 1 addition & 1 deletion POMO/TSP/TSPTester.py
@@ -203,6 +203,6 @@ def _fine_tune_one_batch(self, fine_tune_data):
score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value

# Step & Return
self.model.zero_grad()
self.optimizer.zero_grad()
loss_mean.backward()
self.optimizer.step()
30 changes: 16 additions & 14 deletions POMO/TSP/TSPTrainer.py
@@ -79,7 +79,9 @@ def run(self):

# Val
if epoch % self.trainer_params['val_interval'] == 0:
self._fast_val()
val_episodes = 1000
no_aug_score = self._fast_val(self.model, val_episodes=val_episodes)
print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
@@ -148,8 +150,8 @@ def _train_one_epoch(self, epoch):
raise NotImplementedError
env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
avg_score, avg_loss = self._train_one_batch(data, Env(**env_params))
score_AM.update(avg_score, batch_size)
loss_AM.update(avg_loss, batch_size)
score_AM.update(avg_score.item(), batch_size)
loss_AM.update(avg_loss.item(), batch_size)

episode += batch_size

@@ -200,37 +202,37 @@ def _train_one_batch(self, data, env):
score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value

# Step & Return
self.model.zero_grad()
self.optimizer.zero_grad()
loss_mean.backward()
self.optimizer.step()
return score_mean.item(), loss_mean.item()

def _fast_val(self):
val_path = "../../data/TSP/tsp100_tsplib.pkl"
val_episodes = 5000
return score_mean, loss_mean

def _fast_val(self, model, data=None, val_episodes=1000):
aug_factor = 1
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
if data is None:
val_path = "../../data/TSP/tsp100_tsplib.pkl"
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})

self.model.eval()
model.eval()
batch_size = data.size(0)
with torch.no_grad():
env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
reset_state, _, _ = env.reset()
self.model.pre_forward(reset_state)
model.pre_forward(reset_state)

state, reward, done = env.pre_step()
while not done:
selected, _ = self.model(state)
selected, _ = model(state)
# shape: (batch, pomo)
state, reward, done = env.step(selected)

# Return
aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size)
# shape: (augmentation, batch, pomo)

max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo
# shape: (augmentation, batch)
no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value

print(">> validation results: {}".format(no_aug_score.item()))
return no_aug_score.item()
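The loss/score pattern used throughout these trainers (and spelled out in the `_fast_val` additions to TSPTrainer_Meta.py below) is POMO's shared-baseline REINFORCE update: the baseline for each instance is the mean reward over its POMO rollouts, and the reported score is the best rollout. A minimal self-contained sketch with the shapes used in the diff; the helper name is an illustration, not the repository's API.

    import torch

    def pomo_reinforce_loss(reward: torch.Tensor, prob_list: torch.Tensor):
        # reward:    (batch, pomo) negative tour lengths, one per POMO rollout.
        # prob_list: (batch, pomo, steps) probability of each selected node along the tour.
        advantage = reward - reward.float().mean(dim=1, keepdim=True)   # shared baseline per instance
        log_prob = prob_list.log().sum(dim=2)                           # log-probability of the whole tour
        loss_mean = (-advantage * log_prob).mean()                      # minus sign: ascend the reward
        max_pomo_reward, _ = reward.max(dim=1)                          # best rollout per instance
        score_mean = -max_pomo_reward.float().mean()                    # negate to report a positive tour length
        return score_mean, loss_mean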
128 changes: 82 additions & 46 deletions POMO/TSP/TSPTrainer_Meta.py
@@ -3,6 +3,7 @@
import random
import torch
from logging import getLogger
from collections import OrderedDict

from TSPEnv import TSPEnv as Env
from TSPModel import TSPModel as Model
@@ -17,7 +18,9 @@

class TSPTrainer:
"""
Implementation of POMO with Reptile.
Implementation of POMO with MAML / FOMAML / Reptile.
For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks";
For Reptile, ref to "On First-Order Meta-Learning Algorithms".
"""
def __init__(self,
env_params,
@@ -50,7 +53,8 @@ def __init__(self,

# Main Components
self.meta_model = Model(**self.model_params)
self.alpha = self.meta_params['alpha']
self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer'])
self.alpha = self.meta_params['alpha'] # for reptile
self.task_set = generate_task_set(self.meta_params)
assert self.trainer_params['meta_params']['epochs'] == math.ceil((1000 * 100000) / (
self.trainer_params['meta_params']['B'] * self.trainer_params['meta_params']['k'] *
@@ -79,20 +83,15 @@ def run(self):
self.logger.info('=================================================================')

# Train
if self.meta_params['meta_method'] == 'maml':
pass
elif self.meta_params['meta_method'] == 'fomaml':
pass
elif self.meta_params['meta_method'] == 'reptile':
train_score, train_loss = self._train_one_epoch(epoch)
else:
raise NotImplementedError
train_score, train_loss = self._train_one_epoch(epoch)
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)

# Val
if epoch % self.trainer_params['val_interval'] == 0:
self._fast_val()
val_episodes = 1000
no_aug_score = self._fast_val(self.meta_model, val_episodes=val_episodes)
print(">> validation results: {} over {} instances".format(no_aug_score, val_episodes))

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
@@ -137,23 +136,23 @@ def run(self):
def _train_one_epoch(self, epoch):
"""
1. Sample B training tasks from task distribution P(T)
2. for a batch of tasks T_i, do reptile -> \theta_i
3. update meta-model -> \theta_0
2. inner-loop: for a batch of tasks T_i, do reptile -> \theta_i
3. outer-loop: update meta-model -> \theta_0
"""
score_AM = AverageMeter()
loss_AM = AverageMeter()

batch_size = self.meta_params['meta_batch_size']

self._alpha_scheduler(epoch)
slow_weights = copy.deepcopy(self.meta_model.state_dict())
fast_weights = []
fast_weights, val_loss, fomaml_grad = [], 0, []

# sample a batch of tasks
for i in range(self.meta_params['B']):
task_params = random.sample(self.task_set, 1)[0] # uniformly sample a task
task_params = random.sample(self.task_set, 1)[0]
task_model = copy.deepcopy(self.meta_model)
optimizer = Optimizer(task_model.parameters(), **self.optimizer_params['optimizer'])

for step in range(self.meta_params['k']):
for step in range(self.meta_params['k'] + 1):
# generate task-specific data
if self.meta_params['data_type'] == 'distribution':
assert len(task_params) == 2
@@ -167,23 +166,49 @@
else:
raise NotImplementedError

if step == self.meta_params['k']: continue  # keep this last batch (not trained on) as the query set for MAML/FOMAML
env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
avg_score, avg_loss = self._train_one_batch(task_model, data, optimizer, Env(**env_params))
score_AM.update(avg_score, batch_size)
loss_AM.update(avg_loss, batch_size)
avg_score, avg_loss = self._train_one_batch(task_model, data, Env(**env_params))
score_AM.update(avg_score.item(), batch_size)
loss_AM.update(avg_loss.item(), batch_size)

fast_weights.append(task_model.state_dict())

state_dict = {params_key: (slow_weights[params_key] + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - slow_weights[params_key] for fast_weight in fast_weights], dim=0), dim=0))
for params_key in slow_weights}
self.meta_model.load_state_dict(state_dict)
if self.meta_params['meta_method'] == 'maml':
# cal loss on query(val) set - data
# val_loss += self._fast_val(task_model, val_episodes=64)
loss, _ = self._fast_val(task_model, data=data)
val_loss += loss
elif self.meta_params['meta_method'] == 'fomaml':
loss, _ = self._fast_val(task_model, data=data)
val_loss += loss
task_model.train()
grad = torch.autograd.grad(loss, task_model.parameters())  # per-task gradient at the adapted weights
fomaml_grad.append(grad)
elif self.meta_params['meta_method'] == 'reptile':
fast_weights.append(task_model.state_dict())

# update meta-model
if self.meta_params['meta_method'] == 'maml':
val_loss = val_loss / self.meta_params['B']
self.meta_optimizer.zero_grad()
val_loss.backward()
self.meta_optimizer.step()
elif self.meta_params['meta_method'] == 'fomaml':
updated_weights = self.meta_model.state_dict()
for gradients in fomaml_grad:
updated_weights = OrderedDict(
(name, param - self.optimizer_params['optimizer']['lr'] * grad)
for ((name, param), grad) in zip(updated_weights.items(), gradients)
)
self.meta_model.load_state_dict(updated_weights)
elif self.meta_params['meta_method'] == 'reptile':
state_dict = {params_key: (slow_weights[params_key] + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - slow_weights[params_key] for fast_weight in fast_weights], dim=0), dim=0)) for params_key in slow_weights}
self.meta_model.load_state_dict(state_dict)

# Log Once, for each epoch
self.logger.info('Meta Iteration {:3d}: alpha: {:6f}, Score: {:.4f}, Loss: {:.4f}'.format(epoch, self.alpha, score_AM.avg, loss_AM.avg))

return score_AM.avg, loss_AM.avg

def _train_one_batch(self, task_model, data, optimizer, env):
def _train_one_batch(self, task_model, data, env):

# Prep
task_model.train()
@@ -215,42 +240,53 @@ def _train_one_batch(self, task_model, data, optimizer, env):
max_pomo_reward, _ = reward.max(dim=1) # get best results from pomo
score_mean = -max_pomo_reward.float().mean() # negative sign to make positive value

# Step & Return
task_model.zero_grad()
loss_mean.backward()
optimizer.step()
# update model
create_graph = True if self.meta_params['meta_method'] == 'maml' else False
gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph)
fast_weights = OrderedDict(
(name, param - self.optimizer_params['optimizer']['lr'] * grad)
for ((name, param), grad) in zip(task_model.state_dict().items(), gradients)
)
task_model.load_state_dict(fast_weights)

return score_mean.item(), loss_mean.item()
return score_mean, loss_mean

def _fast_val(self):
val_path = "../../data/TSP/tsp100_tsplib.pkl"
val_episodes = 5000
def _fast_val(self, model, data=None, val_episodes=1000):
aug_factor = 1
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
if data is None:
val_path = "../../data/TSP/tsp100_tsplib.pkl"
data = torch.Tensor(load_dataset(val_path)[: val_episodes])
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})

self.meta_model.eval()
model.eval()
batch_size = data.size(0)
with torch.no_grad():
with torch.enable_grad():
env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
reset_state, _, _ = env.reset()
self.meta_model.pre_forward(reset_state)
model.pre_forward(reset_state)
prob_list = torch.zeros(size=(batch_size, env.pomo_size, 0))

state, reward, done = env.pre_step()
while not done:
selected, _ = self.meta_model(state)
# shape: (batch, pomo)
state, reward, done = env.step(selected)
state, reward, done = env.pre_step()
while not done:
selected, prob = model(state)
# shape: (batch, pomo)
state, reward, done = env.step(selected)
prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

# Loss
advantage = reward - reward.float().mean(dim=1, keepdims=True)
log_prob = prob_list.log().sum(dim=2) # for the first/last node, p=1 -> log_p=0
loss = -advantage * log_prob # Minus Sign: To Increase REWARD
loss_mean = loss.mean()

# Return
aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size)
# shape: (augmentation, batch, pomo)

max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo
# shape: (augmentation, batch)
no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value

print(">> validation results: {}".format(no_aug_score.item()))
return loss_mean, no_aug_score.item()

def _alpha_scheduler(self, iter):
self.alpha *= self.meta_params['alpha_decay']
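For reference, the three outer-loop updates implemented in _train_one_epoch above can be summarized as follows. This is a sketch using the docstring's notation: \theta_0 is the meta-model, \theta_i the weights after k inner gradient steps on task T_i (each inner step is \theta <- \theta - \eta \nabla L_i^train(\theta) with \eta = optimizer_params['optimizer']['lr'], as in _train_one_batch), L_i^query the loss on the held-out batch generated at step k, B the number of sampled tasks, and \alpha the Reptile step size.

    MAML:    \theta_0 <- one self.meta_optimizer step on (1/B) \sum_i L_i^query(\theta_i), differentiating through the inner updates (create_graph=True)
    FOMAML:  \theta_0 <- \theta_0 - \eta \sum_i \nabla_{\theta_i} L_i^query(\theta_i), first-order gradients applied sequentially per task
    Reptile: \theta_0 <- \theta_0 + \alpha (1/B) \sum_i (\theta_i - \theta_0)

Because the MAML inner updates keep the computation graph, its outer gradient backpropagates through all k inner steps; FOMAML and Reptile drop that second-order term.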
8 changes: 4 additions & 4 deletions POMO/TSP/train_n100.py
@@ -70,12 +70,12 @@

},
'meta_params': {
'enable': False, # whether use meta-learning or not
'meta_method': 'reptile', # choose from ['maml', 'reptile', 'ours']
'enable': True, # whether use meta-learning or not
'meta_method': 'maml', # choose from ['maml', 'fomaml', 'reptile', 'ours']
'data_type': 'distribution', # choose from ["size", "distribution", "size_distribution"]
'epochs': 10417, # the number of meta-model updates: (1000*100000) / (3*50*64)
'epochs': 104167, # the number of meta-model updates: (1000*100000) / (3*5*64)
'B': 3, # the number of tasks in a mini-batch
'k': 50, # gradient decent steps in the inner-loop optimization of meta-learning method
'k': 5, # gradient decent steps in the inner-loop optimization of meta-learning method
'meta_batch_size': 64, # the batch size of the inner-loop optimization
'num_task': 50, # the number of tasks in the training task set
'alpha': 0.99, # params for the outer-loop optimization of reptile
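The new epochs value is consistent with the assert in TSPTrainer_Meta.py's __init__ (its last factor is cut off in the diff, but meta_batch_size is the natural reading and reproduces the number): with B = 3, k = 5 and meta_batch_size = 64, the denominator is 3 * 5 * 64 = 960, so ceil((1000 * 100000) / 960) = 104167. A quick check:

    import math

    B, k, meta_batch_size = 3, 5, 64            # values from train_n100.py above
    total_instances = 1000 * 100000             # budget used in the assert
    epochs = math.ceil(total_instances / (B * k * meta_batch_size))
    print(epochs)                               # 104167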
