Skip to content

Commit

Permalink
add bootstrap code
Browse files — browse the repository at this point in the history
  • Loading branch information
RoyalSkye committed Dec 7, 2022
1 parent 238cbc8 commit cdeec9e
Show file tree
Hide file tree
Showing 16 changed files with 658 additions and 353 deletions.
32 changes: 19 additions & 13 deletions POMO/CVRP/CVRPModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def pre_forward(self, reset_state, weights=None):
# shape: (batch, problem+1, embedding)
self.decoder.set_kv(self.encoded_nodes, weights=weights)

def forward(self, state, weights=None):
def forward(self, state, weights=None, selected=None, return_probs=False):
batch_size = state.BATCH_IDX.size(0)
pomo_size = state.BATCH_IDX.size(1)

if state.selected_count == 0: # First Move, depot
selected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long)
prob = torch.ones(size=(batch_size, pomo_size))
probs = torch.ones(size=(batch_size, pomo_size, self.encoded_nodes.size(1)))
# shape: (batch, pomo, problem_size+1)

# # Use Averaged encoded nodes for decoder input_1
# encoded_nodes_mean = self.encoded_nodes.mean(dim=1, keepdim=True)
Expand All @@ -53,27 +55,31 @@ def forward(self, state, weights=None):
elif state.selected_count == 1: # Second Move, POMO
selected = torch.arange(start=1, end=pomo_size+1)[None, :].expand(batch_size, pomo_size)
prob = torch.ones(size=(batch_size, pomo_size))
probs = torch.ones(size=(batch_size, pomo_size, self.encoded_nodes.size(1)))

else:
encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
# shape: (batch, pomo, embedding)
probs = self.decoder(encoded_last_node, state.load, ninf_mask=state.ninf_mask, weights=weights)
# shape: (batch, pomo, problem+1)

while True:
if self.training or self.model_params['eval_type'] == 'softmax':
selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)
# shape: (batch, pomo)
else:
selected = probs.argmax(dim=2)
if selected is None:
while True:
if self.training or self.model_params['eval_type'] == 'softmax':
selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)
# shape: (batch, pomo)
else:
selected = probs.argmax(dim=2)
# shape: (batch, pomo)
prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
# shape: (batch, pomo)

if (prob != 0).all():
break
else:
selected = selected
prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
# shape: (batch, pomo)

if (prob != 0).all():
break

if return_probs:
return selected, prob, probs
return selected, prob


Expand Down
5 changes: 2 additions & 3 deletions POMO/CVRP/CVRPTester.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Model

from TSP_gurobi import solve_all_gurobi
from utils.utils import *
from utils.functions import load_dataset, save_dataset

Expand Down Expand Up @@ -106,8 +105,8 @@ def _test(self, store_res=True):
score_AM.update(score, batch_size)
aug_score_AM.update(aug_score, batch_size)
episode += batch_size
gap = [max(all_score[i].item() - opt_sol[i], 0) / opt_sol[i] * 100 for i in range(batch_size)]
aug_gap = [max(all_aug_score[i].item() - opt_sol[i], 0) / opt_sol[i] * 100 for i in range(batch_size)]
gap = [(all_score[i].item() - opt_sol[i]) / opt_sol[i] * 100 for i in range(batch_size)]
aug_gap = [(all_aug_score[i].item() - opt_sol[i]) / opt_sol[i] * 100 for i in range(batch_size)]
gap_AM.update(sum(gap) / batch_size, batch_size)
aug_gap_AM.update(sum(aug_gap) / batch_size, batch_size)

Expand Down
229 changes: 141 additions & 88 deletions POMO/CVRP/CVRPTrainer_meta.py

Large diffs are not rendered by default.

81 changes: 71 additions & 10 deletions POMO/CVRP/CVRPTrainer_pomo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,16 @@ def __init__(self,
torch.set_default_tensor_type('torch.FloatTensor')

# Main Components
self.model_params["norm"] = "instance" # Original "POMO" Paper uses instance/batch normalization
self.model_params["norm"] = "batch" # Original "POMO" Paper uses batch normalization
self.model = Model(**self.model_params)
self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
self.task_set = generate_task_set(self.meta_params)
self.task_w = torch.full((len(self.task_set),), 1 / len(self.task_set))
self.val_data, self.val_opt = {}, {} # for lkh3_offline
assert not (self.meta_params['curriculum'] and self.meta_params["data_type"] in ["size", "distribution"]), "Not Implemented!"
if self.meta_params["data_type"] == "size_distribution":
# hardcoded - task_set: range(self.min_n, self.max_n, self.task_interval) * self.num_dist
self.min_n, self.max_n, self.task_interval, self.num_dist = 50, 200, 5, 11
self.task_w = torch.full(((self.max_n - self.min_n) // 5 + 1, self.num_dist), 1 / self.num_dist)

# Restore
self.start_epoch = 1
Expand Down Expand Up @@ -100,7 +105,8 @@ def run(self):
dir = "../../data/CVRP/Distribution/"
paths = ["cvrp100_uniform.pkl", "cvrp100_gaussian.pkl", "cvrp100_cluster.pkl", "cvrp100_diagonal.pkl", "cvrp100_cvrplib.pkl"]
elif self.meta_params["data_type"] == "size_distribution":
pass
dir = "../../data/CVRP/Size_Distribution/"
paths = ["cvrp200_uniform.pkl", "cvrp200_gaussian.pkl", "cvrp300_rotation.pkl"]
if epoch <= 1 or (epoch % img_save_interval) == 0:
for val_path in paths:
no_aug_score = self._fast_val(self.model, path=os.path.join(dir, val_path), val_episodes=64)
Expand Down Expand Up @@ -145,20 +151,28 @@ def _train_one_epoch(self, epoch):
loss_AM = AverageMeter()
batch_size = self.meta_params['meta_batch_size']

# Adaptive task scheduler
start, end = 0, 0
pass
# Adaptive task scheduler - Not implemented for "size" and "distribution"
if self.meta_params['curriculum']:
if self.meta_params["data_type"] == "size_distribution":
start = self.min_n + int(min(epoch / self.meta_params['sch_epoch'], 1) * (self.max_n - self.min_n)) # linear
# start = self.min_n + int(1 / 2 * (1 - math.cos(math.pi * min(epoch / self.meta_params['sch_epoch'], 1))) * (self.max_n - self.min_n)) # cosine
n = start // 5 * 5
idx = (n - self.min_n) // 5
tasks, weights = self.task_set[idx * 11: (idx + 1) * 11], self.task_w[idx]
if epoch % self.meta_params['update_weight'] == 0:
self.task_w[idx] = self._update_task_weight(tasks, weights, epoch)

# sample a batch of tasks
for b in range(self.meta_params['B']):
for step in range(self.meta_params['k']):
if self.meta_params["data_type"] == "size":
task_params = random.sample(range(start, end + 1), 1) if self.meta_params['curriculum'] else random.sample(self.task_set, 1)[0]
# batch_size = self.meta_params['meta_batch_size'] if task_params[0] <= 100 else self.meta_params['meta_batch_size'] // 2
task_params = random.sample(self.task_set, 1)[0]
batch_size = self.meta_params['meta_batch_size'] if task_params[0] <= 150 else self.meta_params['meta_batch_size'] // 2
elif self.meta_params["data_type"] == "distribution":
task_params = self.task_set[torch.multinomial(self.task_w, 1).item()] if self.meta_params['curriculum'] else random.sample(self.task_set, 1)[0]
task_params = random.sample(self.task_set, 1)[0]
elif self.meta_params["data_type"] == "size_distribution":
pass
task_params = tasks[torch.multinomial(self.task_w[idx], 1).item()] if self.meta_params['curriculum'] else random.sample(self.task_set, 1)[0]
batch_size = self.meta_params['meta_batch_size'] if task_params[0] <= 150 else self.meta_params['meta_batch_size'] // 2

data = self._get_data(batch_size, task_params)
env_params = {'problem_size': data[-1].size(1), 'pomo_size': data[-1].size(1)}
Expand Down Expand Up @@ -272,3 +286,50 @@ def _get_data(self, batch_size, task_params):
data = (depot_xy, node_xy, node_demand)

return data

def _update_task_weight(self, tasks, weights, epoch):
    """
    Re-estimate the scheduler weight of each task from its optimality gap.

    For every task, solve a validation batch with the current model and
    (for the "lkh3_offline" solver) compare against cached LKH3 solutions;
    tasks with larger gaps get larger softmax weights, so the curriculum
    focuses on the hardest tasks. For LKH3, MAX_TRIALS is set to 100 to
    reduce running time.

    Args:
        tasks: sequence of task parameters, one entry per weight slot.
        weights: 1-D tensor of current task weights; only its length and its
            value in the log output are used — a new tensor is returned.
        epoch: current epoch number (unused here; kept for interface
            compatibility — TODO confirm no caller relies on it).

    Returns:
        1-D tensor of new task weights: softmax(gap / temperature).

    Raises:
        NotImplementedError: for any solver other than "lkh3_offline".
    """
    # run_func must live at module scope so run_all_in_pool can pickle it
    # if multiprocessing is enabled — hence the `global` declaration.
    global run_func
    start_t, gap = time.time(), torch.zeros(weights.size(0))
    # Larger batch for the offline solver, since its results are cached once.
    batch_size = 200 if self.meta_params["solver"] == "lkh3_offline" else 50
    idx = torch.randperm(batch_size)[:50]  # evaluate on 50 random instances
    for i in range(gap.size(0)):
        selected = tasks[i]
        data = self._get_data(batch_size=batch_size, task_params=selected)

        # Only run LKH3 the first time a task is seen; afterwards reuse the
        # cached instances/solutions in self.val_data / self.val_opt.
        if self.meta_params["solver"] == "lkh3_offline":
            if selected not in self.val_data:  # dict membership; .keys() is redundant
                self.val_data[selected] = data
                opts = argparse.ArgumentParser()
                opts.cpus, opts.n, opts.progress_bar_mininterval = None, None, 0.1
                dataset = [(instance.cpu().numpy(),) for instance in data]
                executable = get_lkh_executable()
                def run_func(args):
                    return solve_lkh_log(executable, *args, runs=1, disable_cache=True, MAX_TRIALS=100)  # otherwise it directly loads data from dir
                results, _ = run_all_in_pool(run_func, "./LKH3_result", dataset, opts, use_multiprocessing=False)
                self.val_opt[selected] = [j[0] for j in results]
            data = self.val_data[selected][idx]

        # BUGFIX: this trainer stores its network as self.model (see __init__);
        # self.meta_model is never defined in this class and raised AttributeError.
        model_score = self._fast_val(self.model, data=data, mode="eval", return_all=True)
        model_score = model_score.tolist()

        if self.meta_params["solver"] == "lkh3_offline":
            lkh_score = [self.val_opt[selected][j] for j in idx.tolist()]
            gap_list = [(ms - ls) / ls * 100 for ms, ls in zip(model_score, lkh_score)]
            gap[i] = sum(gap_list) / len(gap_list)
        else:
            raise NotImplementedError
    print(">> Finish updating task weights within {}s".format(round(time.time() - start_t, 2)))

    # Softmax with temperature: a low temperature (0.25) sharpens the
    # distribution toward the largest-gap (hardest) tasks.
    temp = 0.25
    gap_temp = torch.Tensor([i / temp for i in gap.tolist()])
    print(gap, temp)
    print(">> Old task weights: {}".format(weights))
    weights = torch.softmax(gap_temp, dim=0)
    print(">> New task weights: {}".format(weights))

    return weights
Loading

0 comments on commit cdeec9e

Please sign in to comment.