
Commit

update POMO code
RoyalSkye committed May 25, 2023
1 parent d4a6534 commit 8db4b45
Showing 7 changed files with 81 additions and 202 deletions.
33 changes: 13 additions & 20 deletions POMO/CVRP/CVRPTrainer_meta.py
@@ -18,7 +18,7 @@

class CVRPTrainer:
"""
Implementation of POMO with MAML / FOMAML / Reptile on CVRP.
Implementation of POMO with MAML / FOMAML / Reptile / Bootstrap Meta-learning on CVRP.
For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks";
For Reptile, ref to "On First-Order Meta-Learning Algorithms" and "On the generalization of neural combinatorial optimization heuristics".
"""
@@ -35,6 +35,7 @@ def __init__(self,
self.optimizer_params = optimizer_params
self.trainer_params = trainer_params
self.meta_params = meta_params
assert self.meta_params['data_type'] == "size_distribution", "Not supported, need to modify the code!"

# result folder, logger
self.logger = getLogger(name='trainer')
@@ -62,7 +63,6 @@ def __init__(self,
self.alpha = self.meta_params['alpha'] # for reptile
self.task_set = generate_task_set(self.meta_params)
self.val_data, self.val_opt = {}, {} # for lkh3_offline
assert not (self.meta_params['curriculum'] and self.meta_params["data_type"] in ["size", "distribution"]), "Not Implemented!"
if self.meta_params["data_type"] == "size_distribution":
# hardcoded - task_set: range(self.min_n, self.max_n, self.task_interval) * self.num_dist
self.min_n, self.max_n, self.task_interval, self.num_dist = 50, 200, 5, 11
@@ -71,14 +71,22 @@
# Restore
self.start_epoch = 1
model_load = trainer_params['model_load']
pretrain_load = trainer_params['pretrain_load']
if model_load['enable']:
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
self.meta_model.load_state_dict(checkpoint['model_state_dict'])
self.start_epoch = 1 + model_load['epoch']
self.result_log.set_raw_data(checkpoint['result_log'])
self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.logger.info(">> Model loaded from {}".format(checkpoint_fullname))
self.logger.info('Checkpoint loaded successfully from {}'.format(checkpoint_fullname))

elif pretrain_load['enable']: # meta-training on top of a pretrained model
self.logger.info(">> Loading pretrained model: be careful with the type of the normalization layer!")
checkpoint_fullname = '{path}'.format(**pretrain_load)
checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
self.meta_model.load_state_dict(checkpoint['model_state_dict'])
self.logger.info('Pretrained model loaded successfully from {}'.format(checkpoint_fullname))
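The normalization-layer warning above matters because a checkpoint saved with one normalization variant (e.g. instance norm) can contain parameter names or shapes that a model built with another variant does not expect. A defensive loading pattern, shown here purely as an illustrative sketch and not as the repository's code, filters the state dict before loading:

import torch

def load_compatible_weights(model, checkpoint_path, device):
    # Keep only entries whose names and shapes match the current model, and
    # return the skipped keys so mismatches can be logged rather than hidden.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    pretrained = checkpoint['model_state_dict']
    current = model.state_dict()
    compatible = {k: v for k, v in pretrained.items()
                  if k in current and v.shape == current[k].shape}
    skipped = sorted(set(pretrained) - set(compatible))
    current.update(compatible)
    model.load_state_dict(current)
    return skipped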

# utility
self.time_estimator = TimeEstimator()
@@ -105,13 +113,7 @@ def run(self):
img_save_interval = self.trainer_params['logging']['img_save_interval']
# Val
no_aug_score_list = []
if self.meta_params["data_type"] == "size":
dir = "../../data/CVRP/Size/"
paths = ["cvrp100_uniform.pkl", "cvrp200_uniform.pkl", "cvrp300_uniform.pkl"]
elif self.meta_params["data_type"] == "distribution":
dir = "../../data/CVRP/Distribution/"
paths = ["cvrp100_uniform.pkl", "cvrp100_gaussian.pkl", "cvrp100_cluster.pkl", "cvrp100_diagonal.pkl", "cvrp100_cvrplib.pkl"]
elif self.meta_params["data_type"] == "size_distribution":
if self.meta_params["data_type"] == "size_distribution":
dir = "../../data/CVRP/Size_Distribution/"
paths = ["cvrp200_uniform.pkl", "cvrp300_rotation.pkl"]
if epoch <= 1 or (epoch % img_save_interval) == 0:
@@ -196,13 +198,7 @@ def _train_one_epoch(self, epoch):
# sample a batch of tasks
w, selected_tasks = [1.0] * self.meta_params['B'], []
for b in range(self.meta_params['B']):
if self.meta_params["data_type"] == "size":
task_params = random.sample(self.task_set, 1)[0]
batch_size = meta_batch_size if task_params[0] <= 150 else meta_batch_size // 2
elif self.meta_params["data_type"] == "distribution":
task_params = random.sample(self.task_set, 1)[0]
batch_size = meta_batch_size
elif self.meta_params["data_type"] == "size_distribution":
if self.meta_params["data_type"] == "size_distribution":
selected = torch.multinomial(self.task_w[idx], 1).item()
task_params = tasks[selected] if self.meta_params['curriculum'] else random.sample(self.task_set, 1)[0]
batch_size = meta_batch_size if task_params[0] <= 150 else meta_batch_size // 2
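For reference, torch.multinomial draws an index with probability proportional to its (non-negative) weight, which is what the curriculum branch above relies on when it samples a task from self.task_w. A small self-contained illustration with made-up weights:

import torch

task_w = torch.tensor([0.5, 2.0, 1.0, 0.5])  # hypothetical per-task weights
counts = torch.zeros(len(task_w))
for _ in range(10000):
    idx = torch.multinomial(task_w, 1).item()  # index drawn in proportion to its weight
    counts[idx] += 1
print(counts / counts.sum())  # roughly tensor([0.125, 0.500, 0.250, 0.125])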
Expand Down Expand Up @@ -525,9 +521,6 @@ def _get_data(self, batch_size, task_params, return_capacity=False):

def _get_val_data(self, batch_size, task_params):
if self.meta_params["data_type"] == "size":
# start1, end1 = min(task_params[0] + 10, self.max_n), min(task_params[0] + 20, self.max_n)
# val_size = random.sample(range(start1, end1 + 1), 1)[0]
# val_data = self._get_data(batch_size, (val_size,))
val_data = self._get_data(batch_size, task_params)
elif self.meta_params["data_type"] == "distribution":
val_data = self._get_data(batch_size, task_params)
39 changes: 18 additions & 21 deletions POMO/CVRP/CVRPTrainer_pomo.py
@@ -33,6 +33,7 @@ def __init__(self,
self.optimizer_params = optimizer_params
self.trainer_params = trainer_params
self.meta_params = meta_params
assert self.meta_params['data_type'] == "size_distribution", "Not supported, need to modify the code!"

# result folder, logger
self.logger = getLogger(name='trainer')
@@ -56,7 +57,6 @@ def __init__(self,
self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
self.task_set = generate_task_set(self.meta_params)
self.val_data, self.val_opt = {}, {} # for lkh3_offline
assert not (self.meta_params['curriculum'] and self.meta_params["data_type"] in ["size", "distribution"]), "Not Implemented!"
if self.meta_params["data_type"] == "size_distribution":
# hardcoded - task_set: range(self.min_n, self.max_n, self.task_interval) * self.num_dist
self.min_n, self.max_n, self.task_interval, self.num_dist = 50, 200, 5, 11
@@ -65,14 +65,22 @@
# Restore
self.start_epoch = 1
model_load = trainer_params['model_load']
pretrain_load = trainer_params['pretrain_load']
if model_load['enable']:
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.start_epoch = 1 + model_load['epoch']
self.result_log.set_raw_data(checkpoint['result_log'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.logger.info('Saved Model Loaded !!')
self.logger.info('Checkpoint loaded successfully from {}'.format(checkpoint_fullname))

elif pretrain_load['enable']: # meta-training on top of a pretrained model
self.logger.info(">> Loading pretrained model: be careful with the type of the normalization layer!")
checkpoint_fullname = '{path}'.format(**pretrain_load)
checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.logger.info('Pretrained model loaded successfully from {}'.format(checkpoint_fullname))

# utility
self.time_estimator = TimeEstimator()
@@ -83,12 +91,12 @@ def run(self):
for epoch in range(self.start_epoch, self.meta_params['epochs']+1):
self.logger.info('=================================================================')

# lr decay (by 10) to speed up convergence at 90th and 95th iterations
# if epoch in [int(self.meta_params['epochs'] * 0.9)]:
# self.optimizer_params['optimizer']['lr'] /= 10
# for group in self.optimizer.param_groups:
# group["lr"] /= 10
# print(">> LR decay to {}".format(group["lr"]))
# lr decay (by 10) at 90% of the total epochs to speed up convergence
if epoch in [int(self.meta_params['epochs'] * 0.9)]:
self.optimizer_params['optimizer']['lr'] /= 10
for group in self.optimizer.param_groups:
group["lr"] /= 10
print(">> LR decay to {}".format(group["lr"]))

# Train
train_score, train_loss = self._train_one_epoch(epoch)
@@ -98,13 +106,7 @@ def run(self):
img_save_interval = self.trainer_params['logging']['img_save_interval']
# Val
no_aug_score_list = []
if self.meta_params["data_type"] == "size":
dir = "../../data/CVRP/Size/"
paths = ["cvrp100_uniform.pkl", "cvrp200_uniform.pkl", "cvrp300_uniform.pkl"]
elif self.meta_params["data_type"] == "distribution":
dir = "../../data/CVRP/Distribution/"
paths = ["cvrp100_uniform.pkl", "cvrp100_gaussian.pkl", "cvrp100_cluster.pkl", "cvrp100_diagonal.pkl", "cvrp100_cvrplib.pkl"]
elif self.meta_params["data_type"] == "size_distribution":
if self.meta_params["data_type"] == "size_distribution":
dir = "../../data/CVRP/Size_Distribution/"
paths = ["cvrp200_uniform.pkl", "cvrp300_rotation.pkl"]
if epoch <= 1 or (epoch % img_save_interval) == 0:
@@ -165,12 +167,7 @@ def _train_one_epoch(self, epoch):
# sample a batch of tasks
for b in range(self.meta_params['B']):
for step in range(self.meta_params['k']):
if self.meta_params["data_type"] == "size":
task_params = random.sample(self.task_set, 1)[0]
batch_size = self.meta_params['meta_batch_size'] if task_params[0] <= 150 else self.meta_params['meta_batch_size'] // 2
elif self.meta_params["data_type"] == "distribution":
task_params = random.sample(self.task_set, 1)[0]
elif self.meta_params["data_type"] == "size_distribution":
if self.meta_params["data_type"] == "size_distribution":
task_params = tasks[torch.multinomial(self.task_w[idx], 1).item()] if self.meta_params['curriculum'] else random.sample(self.task_set, 1)[0]
batch_size = self.meta_params['meta_batch_size'] if task_params[0] <= 150 else self.meta_params['meta_batch_size'] // 2

8 changes: 6 additions & 2 deletions POMO/CVRP/train.py
@@ -30,7 +30,6 @@
'ff_hidden_dim': 512,
'eval_type': 'argmax',
'meta_update_encoder': True,
# 'norm': 'batch_no_track'
}

optimizer_params = {
@@ -57,11 +56,16 @@
'filename': 'style_loss_1.json'
},
},
# load previous checkpoint for meta-training
'model_load': {
'enable': False, # enable loading a previously saved checkpoint (resume meta-training)
'path': './result/saved_CVRP20_model', # directory where the checkpoint and log files are saved
'epoch': 2000, # epoch of the checkpoint to load

},
# load a pretrained model for meta-training instead of meta-training from scratch
'pretrain_load': {
'enable': False,
'path': '../../pretrained/POMO-CVRP/checkpoint-30500-cvrp100-instance-norm.pt', # be careful with the type of the normalization layer
}
}
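To meta-train on top of the packaged POMO checkpoint instead of starting from scratch, only the pretrain_load switch needs to change; an illustrative override, assuming the checkpoint file listed above exists at that path:

trainer_params['pretrain_load'] = {
    'enable': True,
    'path': '../../pretrained/POMO-CVRP/checkpoint-30500-cvrp100-instance-norm.pt',
}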
