Skip to content

Commit

Permalink
update curri strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
RoyalSkye committed Oct 20, 2022
1 parent 5c68ec3 commit 1f4821c
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 75 deletions.
4 changes: 3 additions & 1 deletion POMO/TSP/TSPTester.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from TSPEnv import TSPEnv as Env
from TSPModel import TSPModel as Model

from baselines import solve_all_gurobi
from TSP_gurobi import solve_all_gurobi
from utils.utils import *
from utils.functions import load_dataset, save_dataset

Expand Down Expand Up @@ -60,6 +60,8 @@ def __init__(self,
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # TODO: which performance is good? load or not load?
self.logger.info(">> Model loaded from {}".format(checkpoint_fullname))

# utility
self.time_estimator = TimeEstimator()
Expand Down
85 changes: 64 additions & 21 deletions POMO/TSP/TSPTrainer_Meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,33 @@

from torch.optim import Adam as Optimizer
# from torch.optim import SGD as Optimizer
# from torch.optim.lr_scheduler import MultiStepLR as Scheduler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts as Scheduler
from TSProblemDef import get_random_problems, generate_task_set

from utils.utils import *
from utils.functions import *
from TSP_baseline import *


class TSPTrainer:
"""
Implementation of POMO with MAML / FOMAML / Reptile.
Implementation of POMO with MAML / FOMAML / Reptile on TSP.
For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks";
For Reptile, ref to "On First-Order Meta-Learning Algorithms".
Refer to "https://lilianweng.github.io/posts/2018-11-30-meta-learning"
MAML's time and space complexity (i.e., GPU memory) is high, so we only update decoder in inner-loop (similar performance).
"""
def __init__(self,
env_params,
model_params,
optimizer_params,
trainer_params):
trainer_params,
meta_params):

# save arguments
self.env_params = env_params
self.model_params = model_params
self.optimizer_params = optimizer_params
self.trainer_params = trainer_params
self.meta_params = trainer_params['meta_params']
self.meta_params = meta_params

# result folder, logger
self.logger = getLogger(name='trainer')
Expand All @@ -60,9 +59,11 @@ def __init__(self,
# Main Components
self.meta_model = Model(**self.model_params)
self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer'])
# self.scheduler = Scheduler(self.meta_optimizer, **self.optimizer_params['scheduler'])
self.alpha = self.meta_params['alpha'] # for reptile
self.task_set = generate_task_set(self.meta_params)
self.min_n, self.max_n, self.task_interval = self.task_set[0][0], self.task_set[-1][0], 5 # [20, 150] / [0, 100]
# self.task_w = {start: 1/(len(self.task_set)//5) for start in range(self.min_n, self.max_n, self.task_interval)}
self.task_w = torch.full((len(self.task_set)//self.task_interval,), 1/(len(self.task_set)//self.task_interval))
self.ema_est = {i[0]: 1 for i in self.task_set}

# Restore
Expand All @@ -75,8 +76,7 @@ def __init__(self,
self.start_epoch = 1 + model_load['epoch']
self.result_log.set_raw_data(checkpoint['result_log'])
self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# self.scheduler.last_epoch = model_load['epoch']-1
self.logger.info('Saved Model Loaded !!')
self.logger.info(">> Model loaded from {}".format(checkpoint_fullname))

# utility
self.time_estimator = TimeEstimator()
Expand All @@ -90,15 +90,14 @@ def run(self):

# Train
train_score, train_loss = self._train_one_epoch(epoch)
# self.scheduler.step()
self.result_log.append('train_score', epoch, train_score)
self.result_log.append('train_loss', epoch, train_loss)
# Val
dir, no_aug_score_list = "../../data/TSP/", []
if self.meta_params["data_type"] == "size":
paths = ["tsp50_uniform.pkl", "tsp100_uniform.pkl", "tsp200_uniform.pkl"]
elif self.meta_params["data_type"] == "distribution":
paths = ["tsp100_uniform.pkl", "tsp100_gaussian.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl", "tsp100_tsplib.pkl"]
paths = ["tsp100_uniform.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl"]
elif self.meta_params["data_type"] == "size_distribution":
pass
for val_path in paths:
Expand Down Expand Up @@ -135,7 +134,6 @@ def run(self):
'epoch': epoch,
'model_state_dict': self.meta_model.state_dict(),
'optimizer_state_dict': self.meta_optimizer.state_dict(),
# 'scheduler_state_dict': self.scheduler.state_dict(),
'result_log': self.result_log.get_raw_data()
}
torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))
Expand All @@ -162,16 +160,23 @@ def _train_one_epoch(self, epoch):
loss_AM = AverageMeter()

"""
Curriculum learning:
Curriculum learning / Adaptive task scheduler:
for size: gradually increase the problem size
for distribution: gradually increase adversarial budgets (i.e., \epsilon)
for distribution: adversarial budgets (i.e., \epsilon) may not be correlated with the hardness of constructed
data distribution. Instead, we evaluate the relative gaps (w.r.t. LKH3) of dist/eps sampled
from each interval every X iters. Hopefully, it can indicate the hardness of its neighbor.
"""
if self.meta_params["data_type"] in ["size", "distribution"]:
self.min_n, self.max_n = self.task_set[0][0], self.task_set[-1][0] # [20, 150] / [0, 130]
if self.meta_params["data_type"] == "size":
# start = self.min_n + int(epoch/self.meta_params['epochs'] * (self.max_n - self.min_n)) # linear
start = self.min_n + int(1/2 * (1-math.cos(math.pi * min(epoch/self.meta_params['epochs'], 1))) * (self.max_n - self.min_n)) # cosine
end = min(start + 10, self.max_n) # 10 is the size of the sliding window
if self.meta_params["curriculum"]: print(">> training task {}".format((start, end)))
elif self.meta_params["data_type"] == "distribution":
# Every X iters, evaluating 50 instances for each interval (e.g., [1, 6) / [6, 11) / ...) using LKH3
if epoch != 0 and epoch % self.meta_params['update_weight'] == 0:
self._update_task_weight()
start = torch.multinomial(self.task_w, 1).item() * self.task_interval
end = min(start + self.task_interval, self.max_n)
elif self.meta_params["data_type"] == "size_distribution":
pass

Expand Down Expand Up @@ -215,6 +220,7 @@ def _train_one_epoch(self, epoch):
loss_AM.update(avg_loss.item(), batch_size)

val_data = self._get_val_data(batch_size, task_params)
self.meta_model.train()
if self.meta_params['meta_method'] == 'maml':
val_loss = self._fast_val(fast_weight, data=val_data, mode="maml")
val_loss /= self.meta_params['B']
Expand Down Expand Up @@ -340,7 +346,7 @@ def _train_one_batch_maml(self, fast_weight, data, env, optimizer=None):

return score_mean, loss_mean, fast_weight

def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval"):
def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval", return_all=False):
aug_factor = 1
data = torch.Tensor(load_dataset(path)[: val_episodes]) if data is None else data
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
Expand Down Expand Up @@ -395,9 +401,13 @@ def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval"):
max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo
# shape: (augmentation, batch)
no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value
print(no_aug_score)

if mode == "eval":
return no_aug_score.detach().item()
if return_all:
return -max_pomo_reward[0, :].float()
else:
return no_aug_score.detach().item()
else:
return loss_mean

Expand Down Expand Up @@ -460,11 +470,14 @@ def _get_data(self, batch_size, task_params):
return data

def _get_val_data(self, batch_size, task_params):
if self.meta_params["data_type"] in ["size", "distribution"]:
if self.meta_params["data_type"] == "size":
start1, end1 = min(task_params[0] + 10, self.max_n), min(task_params[0] + 20, self.max_n)
val_size = random.sample(range(start1, end1 + 1), 1)[0]
elif self.meta_params["data_type"] == "distribution":
val_size = task_params[0]
elif self.meta_params["data_type"] == "size_distribution":
pass
val_size = random.sample(range(start1, end1 + 1), 1)[0]

val_data = self._get_data(batch_size, (val_size,))

return val_data
Expand All @@ -487,7 +500,6 @@ def minmax(xy_):

if eps == 0: return data
# generate x_adv
print(">> Warning! Generating x_adv!")
self.meta_model.eval()
aug_factor, batch_size = 1, data.size(0)
env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
Expand Down Expand Up @@ -519,3 +531,34 @@ def minmax(xy_):
# return data, opt_sol

return data

def _update_task_weight(self):
    """
    Re-estimate the sampling weights of curriculum tasks.

    For each task interval (a contiguous band of ``self.task_interval`` task ids),
    one task is sampled, 50 instances are generated for it, and the meta-model's
    solution cost is compared against LKH3's cost. The mean relative gap (%) per
    interval is then passed through a softmax to produce ``self.task_w``, the
    categorical distribution used elsewhere (``torch.multinomial``) to pick the
    next training task band — harder bands (larger gap) get more weight.

    Side effects: overwrites ``self.task_w``; defines/overwrites the module-global
    ``run_func``; runs the external LKH3 solver and writes its logs under
    ``./LKH3_result``; prints diagnostics to stdout.
    """
    # One gap estimate per task interval.
    gap = torch.zeros(len(self.task_set)//self.task_interval)
    for i in range(gap.size(0)):
        # Interval [start, end] in task-id space; end is clamped to the largest task id.
        start = i * self.task_interval
        end = min(start + self.task_interval, self.max_n)
        # Pick a single representative task id from this interval at random,
        # hoping its hardness is indicative of its neighbors (see class docstring).
        selected = random.sample([j for j in range(start, end+1)], 1)[0]
        data = self._get_data(batch_size=50, task_params=(selected, ))
        # return_all=True makes _fast_val return the per-instance (non-augmented)
        # tour lengths rather than their mean.
        model_score = self._fast_val(self.meta_model, data=data, mode="eval", return_all=True)
        model_score = model_score.tolist()

        # get results from LKH3 (~14s)
        # start_t = time.time()
        # NOTE(review): an ArgumentParser instance is (ab)used here as a plain
        # attribute namespace for run_all_in_pool — argparse.Namespace would be
        # the conventional choice, but behavior is the same.
        opts = argparse.ArgumentParser()
        opts.cpus, opts.n, opts.progress_bar_mininterval = None, None, 0.1
        # run_all_in_pool expects each dataset entry as an args tuple.
        dataset = [(instance.cpu().numpy(),) for instance in data]
        executable = get_lkh_executable()
        # run_func must live at module scope so it is picklable if
        # run_all_in_pool ever uses multiprocessing; hence the `global`.
        global run_func
        def run_func(args):
            return solve_lkh_log(executable, *args, runs=1, disable_cache=True)  # otherwise it directly loads data from dir
        results, _ = run_all_in_pool(run_func, "./LKH3_result", dataset, opts, use_multiprocessing=False)
        # Relative optimality gap in percent: (model_cost - lkh_cost) / lkh_cost * 100.
        # results[j][0] is assumed to be the LKH3 tour cost — TODO confirm against
        # solve_lkh_log's return format.
        gap_list = [(model_score[j]-results[j][0])/results[j][0]*100 for j in range(len(results))]
        gap[i] = sum(gap_list)/len(gap_list)
        # print(">> LKH3 finished within {}s".format(time.time()-start_t))
    print(gap)
    print(">> Old task weights: {}".format(self.task_w))
    # Softmax over the per-interval gaps: larger gap -> larger sampling weight.
    self.task_w = torch.softmax(gap, dim=0)
    print(">> New task weights: {}".format(self.task_w))
13 changes: 7 additions & 6 deletions POMO/TSP/TSPTrainer_pomo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ def __init__(self,
env_params,
model_params,
optimizer_params,
trainer_params):
trainer_params,
meta_params):

# save arguments
self.env_params = env_params
self.model_params = model_params
self.optimizer_params = optimizer_params
self.trainer_params = trainer_params
self.meta_params = trainer_params['meta_params']
self.meta_params = meta_params

# result folder, logger
self.logger = getLogger(name='trainer')
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(self,
self.meta_model.load_state_dict(checkpoint['model_state_dict'])
self.start_epoch = 1 + model_load['epoch']
self.result_log.set_raw_data(checkpoint['result_log'])
# self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# self.scheduler.last_epoch = model_load['epoch']-1
self.logger.info('Saved Model Loaded !!')

Expand All @@ -89,13 +90,13 @@ def run(self):
if self.meta_params["data_type"] == "size":
paths = ["tsp50_uniform.pkl", "tsp100_uniform.pkl", "tsp200_uniform.pkl"]
elif self.meta_params["data_type"] == "distribution":
paths = ["tsp100_uniform.pkl", "tsp100_gaussian.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl", "tsp100_tsplib.pkl"]
paths = ["tsp100_uniform.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl"]
elif self.meta_params["data_type"] == "size_distribution":
pass
for val_path in paths:
no_aug_score = self._fast_val(self.meta_model, path=os.path.join(dir, val_path), val_episodes=64)
no_aug_score_list.append(round(no_aug_score, 4))
self.result_log.append('val_score', epoch, no_aug_score_list[-1])
self.result_log.append('val_score', epoch, no_aug_score_list)

# Logs & Checkpoint
elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs'])
Expand All @@ -122,7 +123,7 @@ def run(self):
checkpoint_dict = {
'epoch': epoch,
'model_state_dict': self.meta_model.state_dict(),
# 'optimizer_state_dict': self.optimizer.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
# 'scheduler_state_dict': self.scheduler.state_dict(),
'result_log': self.result_log.get_raw_data()
}
Expand Down
15 changes: 8 additions & 7 deletions POMO/TSP/TSP_baseline.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import argparse
import numpy as np
import os, re
import os, re, sys
import time
from datetime import timedelta
from scipy.spatial import distance_matrix
from subprocess import check_call, check_output, CalledProcessError
import torch
from torch.utils.data import Dataset
from urllib.parse import urlparse
from tqdm import tqdm
os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, "..") # for utils
Expand Down Expand Up @@ -99,7 +100,7 @@ def solve_concorde_log(executable, directory, name, loc, disable_cache=False):

def get_lkh_executable(url="http://www.akira.ruc.dk/~keld/research/LKH-3/LKH-3.0.7.tgz"):

cwd = os.path.abspath(os.path.join("problems", "vrp", "lkh"))
cwd = os.path.abspath(os.path.join("lkh"))
os.makedirs(cwd, exist_ok=True)

file = os.path.join(cwd, os.path.split(urlparse(url).path)[-1])
Expand Down Expand Up @@ -132,7 +133,7 @@ def solve_lkh_log(executable, directory, name, loc, runs=1, disable_cache=False)
try:
# May have already been run
if os.path.isfile(output_filename) and not disable_cache:
tour, duration = load_dataset(output_filename)
tour, duration = load_dataset(output_filename, disable_print=True)
else:
write_tsplib(problem_filename, loc, name=name)

Expand Down Expand Up @@ -371,17 +372,17 @@ def solve_all_nn(dataset_path, eval_batch_size=1024, no_cuda=False, dataset_n=No
if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("method", type=str, default='concorde', choices=['nn', "gurobi", "gurobigap", "gurobit", "concorde", "lkh", "random_insertion", "nearest_insertion", "farthest_insertion"])
parser.add_argument("datasets", nargs='+', help="Filename of the dataset(s) to evaluate")
parser.add_argument("--method", type=str, default='lkh', choices=['nn', "gurobi", "gurobigap", "gurobit", "concorde", "lkh", "random_insertion", "nearest_insertion", "farthest_insertion"])
parser.add_argument("--datasets", nargs='+', default=["../../data/TSP/tsp50_uniform.pkl", ], help="Filename of the dataset(s) to evaluate")
parser.add_argument("-f", action='store_true', help="Set true to overwrite")
parser.add_argument("-o", default=None, help="Name of the results file to write")
parser.add_argument("--cpus", type=int, help="Number of CPUs to use, defaults to all cores")
parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA (only for Tsiligirides)')
parser.add_argument('--disable_cache', action='store_true', help='Disable caching')
parser.add_argument('--max_calc_batch_size', type=int, default=1000, help='Size for subbatches')
parser.add_argument('--progress_bar_mininterval', type=float, default=0.1, help='Minimum interval')
parser.add_argument('-n', type=int, help="Number of instances to process")
parser.add_argument('--offset', type=int, help="Offset where to start processing")
parser.add_argument('-n', type=int, default=1000, help="Number of instances to process")
parser.add_argument('--offset', type=int, default=0, help="Offset where to start processing")
parser.add_argument('--results_dir', default='results', help="Name of results directory")

opts = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion POMO/TSP/TSP_gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def solve_all_gurobi(dataset):
for i, instance in enumerate(dataset):
print("Solving instance {}".format(i))
# some hard instances may take prohibitively long time, and ultimately kill the solver, so we set tl=1800s for TSP100 to avoid that.
result = solve_euclidian_tsp(instance, timeout=1800)
result = solve_euclidian_tsp(instance)
results.append(result)
return results

Expand Down
5 changes: 1 addition & 4 deletions POMO/TSP/TSProblemDef.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@

def generate_task_set(meta_params):
if meta_params['data_type'] == "distribution": # focus on the TSP100 with different distributions
# task_set = [(m, l) for l in [1, 10, 20, 30, 50] for m in range(1, 1 + meta_params['num_task'] // 5)] + [(0, 0)]
task_set = [(eps,) for eps in range(0, 0 + meta_params['num_task'] + 1)]
elif meta_params['data_type'] == "size": # focus on uniform distribution with different sizes
task_set = [(n,) for n in range(20, 20 + meta_params['num_task'] + 1)]
elif meta_params['data_type'] == "size_distribution":
# task_set = [(m, l) for l in [1, 10, 20, 30, 50] for m in range(1, 11)] + [(0, 0)]
# task_set = [(n, m, l) for n in [25, 50, 75, 100, 125, 150] for (m, l) in task_set]
task_set = [(n, eps) for n in range(20, 20 + meta_params['num_task'] + 1) for eps in range(0, 0 + meta_params['num_task'] + 1, 10)]
else:
raise NotImplementedError
Expand Down Expand Up @@ -115,7 +112,7 @@ def generate_tsp_dist(n_samples, n_nodes, distribution):
for i in range(n_samples):
print(n_nodes, i)
loc = []
n_cluster = np.random.randint(low=3, high=9)
n_cluster = np.random.randint(low=2, high=9)
loc.append(np.random.randint(1000, size=[1, n_cluster, 2]))
prob = np.zeros((1000, 1000))
coord = np.concatenate([np.tile(np.arange(1000).reshape(-1, 1, 1), [1, 1000, 1]),
Expand Down
3 changes: 1 addition & 2 deletions POMO/TSP/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
'test_set_path': '../../data/TSP/tsp100_uniform.pkl',
'test_set_opt_sol_path': '../../data/TSP/gurobi/tsp100_uniform.pkl',
'fine_tune_params': {
'enable': True, # evaluate few-shot generalization
'enable': False, # evaluate few-shot generalization
'fine_tune_episodes': 3000, # how many data used to fine-tune the pretrained model
'k': 20, # gradient decent steps in the inner-loop optimization of meta-learning method
'fine_tune_batch_size': 64, # the batch size of the inner-loop optimization
Expand Down Expand Up @@ -101,5 +101,4 @@ def _print_config():


if __name__ == "__main__":
# TODO: 1. why not use test dataset to fine-tune the model?
main()
Loading

0 comments on commit 1f4821c

Please sign in to comment.