fix pretrain bugs
RoyalSkye committed May 31, 2023
1 parent 2ccc19c · commit f281406
Showing 6 changed files with 9 additions and 5 deletions.
5 changes: 3 additions & 2 deletions POMO/CVRP/CVRPTrainer_meta.py
@@ -86,7 +86,8 @@ def __init__(self,
  self.logger.info(">> Loading pretrained model: be careful with the type of the normalization layer!")
  checkpoint_fullname = '{path}'.format(**pretrain_load)
  checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
- self.model.load_state_dict(checkpoint['model_state_dict'])
+ self.meta_model.load_state_dict(checkpoint['model_state_dict'])
+ self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # otherwise, unstable meta-training (nan problem)
  self.logger.info('Pretrained model loaded successfully from {}'.format(checkpoint_fullname))

  # utility
@@ -372,7 +373,7 @@ def _train_one_batch_maml(self, fast_weight, data, env, optimizer=None, create_g
  w_t, (beta1, beta2), eps = [], self.meta_optimizer.param_groups[0]['betas'], self.meta_optimizer.param_groups[0]['eps']
  lr, weight_decay = self.optimizer_params['optimizer']['lr'], self.optimizer_params['optimizer']['weight_decay']
  for i, ((name, param), grad) in enumerate(zip(fast_weight.items(), gradients)):
-     print(i, name)
+     # print(i, name)
      if self.meta_optimizer.state_dict()['state'] != {}:
          # (with batch/instance norm layer): i \in [0, 86], where encoder \in [0, 81] + decoder \in [82, 86]
          # (with rezero norm layer): i \in [0, 74], where encoder \in [0, 69] + decoder \in [70, 74]
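Note on the second hunk: _train_one_batch_maml rebuilds an Adam-style step by hand over the MAML fast weights, pulling betas, eps, lr, and weight_decay from the meta-optimizer, apparently so the inner-loop update matches the outer optimizer's settings; the debug print is simply commented out. A minimal sketch of such an update follows. It is an illustration only: the helper name adam_fast_update and the per-parameter opt_state layout are assumptions, not code from this repository.

from collections import OrderedDict

def adam_fast_update(fast_weight, gradients, opt_state, lr, beta1, beta2, eps, weight_decay):
    """One Adam-style step over MAML fast weights (hypothetical helper, not the repo's code).

    fast_weight: OrderedDict of name -> torch tensor; gradients: grads in the same order.
    """
    new_weight = OrderedDict()
    for i, ((name, param), grad) in enumerate(zip(fast_weight.items(), gradients)):
        if weight_decay != 0:
            grad = grad + weight_decay * param  # L2 term folded into the gradient, as in torch.optim.Adam
        state = opt_state[i]  # assumed per-parameter dict: {'step', 'exp_avg', 'exp_avg_sq'}
        step = state['step'] + 1
        exp_avg = beta1 * state['exp_avg'] + (1 - beta1) * grad            # first-moment estimate
        exp_avg_sq = beta2 * state['exp_avg_sq'] + (1 - beta2) * grad * grad  # second-moment estimate
        bias1, bias2 = 1 - beta1 ** step, 1 - beta2 ** step                # bias corrections
        new_weight[name] = param - lr * (exp_avg / bias1) / ((exp_avg_sq / bias2).sqrt() + eps)
        opt_state[i] = {'step': step, 'exp_avg': exp_avg, 'exp_avg_sq': exp_avg_sq}
    return new_weight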
1 change: 1 addition & 0 deletions POMO/CVRP/CVRPTrainer_pomo.py
@@ -80,6 +80,7 @@ def __init__(self,
  checkpoint_fullname = '{path}'.format(**pretrain_load)
  checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
  self.model.load_state_dict(checkpoint['model_state_dict'])
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  self.logger.info('Pretrained model loaded successfully from {}'.format(checkpoint_fullname))

  # utility
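Both the meta and the POMO trainers now restore the optimizer state alongside the model weights, which assumes the pretraining run saved its checkpoint with both keys. A self-contained sketch of that assumed layout is below; the stand-in Linear model, hyperparameters, and file name are illustrative, not the repository's actual pretraining code.

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the POMO policy network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)

# Assumed checkpoint layout: both state dicts stored under these keys.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pt')

# Restoring mirrors the pattern in the diffs above.
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])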
2 changes: 1 addition & 1 deletion POMO/CVRP/train.py
@@ -108,7 +108,7 @@ def main():
  if not meta_params['enable']:
      print(">> Start CVRP-POMO Training.")
      trainer = Trainer_pomo(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params, meta_params=meta_params)
- elif meta_params['meta_method'] in ['maml', 'fomaml', 'reptile']:
+ elif meta_params['meta_method'] in ['maml', 'fomaml', 'maml_fomaml', 'reptile']:
      print(">> Start CVRP-POMO-{} Training.".format(meta_params['meta_method']))
      trainer = Trainer_meta(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params, meta_params=meta_params)
  else:
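The train.py change only widens the set of accepted meta_params['meta_method'] values so that 'maml_fomaml' is routed to Trainer_meta instead of falling through to the else branch; the diff does not define the method itself (presumably a schedule combining MAML and first-order MAML, but that is an assumption). A hypothetical configuration snippet, with parameter names taken from train.py and values illustrative only:

meta_params = {
    'enable': True,                # use the meta-training path
    'meta_method': 'maml_fomaml',  # accepted after this commit
}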
3 changes: 2 additions & 1 deletion POMO/TSP/TSPTrainer_meta.py
@@ -87,7 +87,8 @@ def __init__(self,
  self.logger.info(">> Loading pretrained model: be careful with the type of the normalization layer!")
  checkpoint_fullname = '{path}'.format(**pretrain_load)
  checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
- self.model.load_state_dict(checkpoint['model_state_dict'])
+ self.meta_model.load_state_dict(checkpoint['model_state_dict'])
+ self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # otherwise, unstable meta-training (nan problem)
  self.logger.info('Pretrained model loaded successfully from {}'.format(checkpoint_fullname))

  # utility
1 change: 1 addition & 0 deletions POMO/TSP/TSPTrainer_pomo.py
@@ -81,6 +81,7 @@ def __init__(self,
  checkpoint_fullname = '{path}'.format(**pretrain_load)
  checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
  self.model.load_state_dict(checkpoint['model_state_dict'])
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  self.logger.info('Pretrained model loaded successfully from {}'.format(checkpoint_fullname))

  # utility
2 changes: 1 addition & 1 deletion POMO/TSP/train.py
@@ -108,7 +108,7 @@ def main():
  if not meta_params['enable']:
      print(">> Start TSP-POMO Training.")
      trainer = Trainer_pomo(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params, meta_params=meta_params)
- elif meta_params['meta_method'] in ['maml', 'fomaml', 'reptile']:
+ elif meta_params['meta_method'] in ['maml', 'fomaml', 'maml_fomaml', 'reptile']:
      print(">> Start TSP-POMO-{} Training.".format(meta_params['meta_method']))
      trainer = Trainer_meta(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params, meta_params=meta_params)
  else:
