diff --git a/POMO/TSP/TSPModel.py b/POMO/TSP/TSPModel.py
index 98c269b..6efacea 100644
--- a/POMO/TSP/TSPModel.py
+++ b/POMO/TSP/TSPModel.py
@@ -15,12 +15,12 @@ def __init__(self, **model_params):
         self.encoded_nodes = None
         # shape: (batch, problem, EMBEDDING_DIM)

-    def pre_forward(self, reset_state):
-        self.encoded_nodes = self.encoder(reset_state.problems)
+    def pre_forward(self, reset_state, weights=None):
+        self.encoded_nodes = self.encoder(reset_state.problems, weights=weights)
         # shape: (batch, problem, EMBEDDING_DIM)
-        self.decoder.set_kv(self.encoded_nodes)
+        self.decoder.set_kv(self.encoded_nodes, weights=weights)

-    def forward(self, state):
+    def forward(self, state, weights=None):
         batch_size = state.BATCH_IDX.size(0)
         pomo_size = state.BATCH_IDX.size(1)

@@ -30,12 +30,12 @@ def forward(self, state):

             encoded_first_node = _get_encoding(self.encoded_nodes, selected)
             # shape: (batch, pomo, embedding)
-            self.decoder.set_q1(encoded_first_node)  # pre-compute fixed part of the context embedding
+            self.decoder.set_q1(encoded_first_node, weights=weights)  # pre-compute fixed part of the context embedding

         else:
             encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
             # shape: (batch, pomo, embedding)
-            probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask)
+            probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask, weights=weights)
             # shape: (batch, pomo, problem)

             while True:
@@ -86,15 +86,19 @@ def __init__(self, **model_params):
         self.embedding = nn.Linear(2, embedding_dim)
         self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])

-    def forward(self, data):
-        # data.shape: (batch, problem, 2)
-
-        embedded_input = self.embedding(data)
-        # shape: (batch, problem, embedding)
-
-        out = embedded_input
-        for layer in self.layers:
-            out = layer(out)
+    def forward(self, data, weights=None):
+        if weights is None:
+            # data.shape: (batch, problem, 2)
+            embedded_input = self.embedding(data)
+            # shape: (batch, problem, embedding)
+            out = embedded_input
+            for layer in self.layers:
+                out = layer(out)
+        else:
+            embedded_input = F.linear(data, weights['encoder.embedding.weight'], weights['encoder.embedding.bias'])
+            out = embedded_input
+            for idx, layer in enumerate(self.layers):
+                out = layer(out, weights=weights, index=idx)

         return out
@@ -116,24 +120,32 @@ def __init__(self, **model_params):
         self.feedForward = Feed_Forward_Module(**model_params)
         self.addAndNormalization2 = Add_And_Normalization_Module(**model_params)

-    def forward(self, input1):
-        # input.shape: (batch, problem, EMBEDDING_DIM)
-        head_num = self.model_params['head_num']
-
-        q = reshape_by_heads(self.Wq(input1), head_num=head_num)
-        k = reshape_by_heads(self.Wk(input1), head_num=head_num)
-        v = reshape_by_heads(self.Wv(input1), head_num=head_num)
-        # q shape: (batch, HEAD_NUM, problem, KEY_DIM)
-
-        out_concat = multi_head_attention(q, k, v)
-        # shape: (batch, problem, HEAD_NUM*KEY_DIM)
-
-        multi_head_out = self.multi_head_combine(out_concat)
-        # shape: (batch, problem, EMBEDDING_DIM)
-
-        out1 = self.addAndNormalization1(input1, multi_head_out)
-        out2 = self.feedForward(out1)
-        out3 = self.addAndNormalization2(out1, out2)
+    def forward(self, input1, weights=None, index=0):
+        if weights is None:
+            # input.shape: (batch, problem, EMBEDDING_DIM)
+            head_num = self.model_params['head_num']
+            q = reshape_by_heads(self.Wq(input1), head_num=head_num)
+            k = reshape_by_heads(self.Wk(input1), head_num=head_num)
+            v = reshape_by_heads(self.Wv(input1), head_num=head_num)
+            # q shape: (batch, HEAD_NUM, problem, KEY_DIM)
+            out_concat = multi_head_attention(q, k, v)
+            # shape: (batch, problem, HEAD_NUM*KEY_DIM)
+            multi_head_out = self.multi_head_combine(out_concat)
+            # shape: (batch, problem, EMBEDDING_DIM)
+            out1 = self.addAndNormalization1(input1, multi_head_out)
+            out2 = self.feedForward(out1)
+            out3 = self.addAndNormalization2(out1, out2)
+        else:
+            head_num = self.model_params['head_num']
+            q = reshape_by_heads(F.linear(input1, weights['encoder.layers.{}.Wq.weight'.format(index)], bias=None), head_num=head_num)
+            k = reshape_by_heads(F.linear(input1, weights['encoder.layers.{}.Wk.weight'.format(index)], bias=None), head_num=head_num)
+            v = reshape_by_heads(F.linear(input1, weights['encoder.layers.{}.Wv.weight'.format(index)], bias=None), head_num=head_num)
+            out_concat = multi_head_attention(q, k, v)
+            multi_head_out = F.linear(out_concat, weights['encoder.layers.{}.multi_head_combine.weight'.format(index)], weights['encoder.layers.{}.multi_head_combine.bias'.format(index)])
+            out1 = self.addAndNormalization1(input1, multi_head_out, weights={'weight': weights['encoder.layers.{}.addAndNormalization1.norm.weight'.format(index)], 'bias': weights['encoder.layers.{}.addAndNormalization1.norm.bias'.format(index)]})
+            out2 = self.feedForward(out1, weights={'weight1': weights['encoder.layers.{}.feedForward.W1.weight'.format(index)], 'bias1': weights['encoder.layers.{}.feedForward.W1.bias'.format(index)],
+                                                   'weight2': weights['encoder.layers.{}.feedForward.W2.weight'.format(index)], 'bias2': weights['encoder.layers.{}.feedForward.W2.bias'.format(index)]})
+            out3 = self.addAndNormalization2(out1, out2, weights={'weight': weights['encoder.layers.{}.addAndNormalization2.norm.weight'.format(index)], 'bias': weights['encoder.layers.{}.addAndNormalization2.norm.bias'.format(index)]})

         return out3
         # shape: (batch, problem, EMBEDDING_DIM)
@@ -163,60 +175,71 @@ def __init__(self, **model_params):
         self.single_head_key = None  # saved, for single-head attention
         self.q_first = None  # saved q1, for multi-head attention

-    def set_kv(self, encoded_nodes):
-        # encoded_nodes.shape: (batch, problem, embedding)
-        head_num = self.model_params['head_num']
-
-        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
-        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
-        # shape: (batch, head_num, pomo, qkv_dim)
-        self.single_head_key = encoded_nodes.transpose(1, 2)
-        # shape: (batch, embedding, problem)
-
-    def set_q1(self, encoded_q1):
-        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
-        head_num = self.model_params['head_num']
-
-        self.q_first = reshape_by_heads(self.Wq_first(encoded_q1), head_num=head_num)
-        # shape: (batch, head_num, n, qkv_dim)
-
-    def forward(self, encoded_last_node, ninf_mask):
-        # encoded_last_node.shape: (batch, pomo, embedding)
-        # ninf_mask.shape: (batch, pomo, problem)
-
-        head_num = self.model_params['head_num']
-
-        # Multi-Head Attention
-        #######################################################
-        q_last = reshape_by_heads(self.Wq_last(encoded_last_node), head_num=head_num)
-        # shape: (batch, head_num, pomo, qkv_dim)
-
-        q = self.q_first + q_last
-        # shape: (batch, head_num, pomo, qkv_dim)
-
-        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
-        # shape: (batch, pomo, head_num*qkv_dim)
-
-        mh_atten_out = self.multi_head_combine(out_concat)
-        # shape: (batch, pomo, embedding)
-
-        # Single-Head Attention, for probability calculation
-        #######################################################
-        score = torch.matmul(mh_atten_out, self.single_head_key)
-        # shape: (batch, pomo, problem)
-
-        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
-        logit_clipping = self.model_params['logit_clipping']
-
-        score_scaled = score / sqrt_embedding_dim
-        # shape: (batch, pomo, problem)
-
-        score_clipped = logit_clipping * torch.tanh(score_scaled)
-
-        score_masked = score_clipped + ninf_mask
-
-        probs = F.softmax(score_masked, dim=2)
-        # shape: (batch, pomo, problem)
+    def set_kv(self, encoded_nodes, weights=None):
+        if weights is None:
+            # encoded_nodes.shape: (batch, problem, embedding)
+            head_num = self.model_params['head_num']
+            self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
+            self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
+            # shape: (batch, head_num, pomo, qkv_dim)
+            self.single_head_key = encoded_nodes.transpose(1, 2)
+            # shape: (batch, embedding, problem)
+        else:
+            head_num = self.model_params['head_num']
+            self.k = reshape_by_heads(F.linear(encoded_nodes, weights['decoder.Wk.weight'], bias=None), head_num=head_num)
+            self.v = reshape_by_heads(F.linear(encoded_nodes, weights['decoder.Wv.weight'], bias=None), head_num=head_num)
+            self.single_head_key = encoded_nodes.transpose(1, 2)
+
+    def set_q1(self, encoded_q1, weights=None):
+        if weights is None:
+            # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
+            head_num = self.model_params['head_num']
+            self.q_first = reshape_by_heads(self.Wq_first(encoded_q1), head_num=head_num)
+            # shape: (batch, head_num, n, qkv_dim)
+        else:
+            head_num = self.model_params['head_num']
+            self.q_first = reshape_by_heads(F.linear(encoded_q1, weights['decoder.Wq_first.weight'], bias=None), head_num=head_num)
+
+    def forward(self, encoded_last_node, ninf_mask, weights=None):
+        if weights is None:
+            # encoded_last_node.shape: (batch, pomo, embedding)
+            # ninf_mask.shape: (batch, pomo, problem)
+            head_num = self.model_params['head_num']
+            # Multi-Head Attention
+            #######################################################
+            q_last = reshape_by_heads(self.Wq_last(encoded_last_node), head_num=head_num)
+            # shape: (batch, head_num, pomo, qkv_dim)
+            q = self.q_first + q_last
+            # shape: (batch, head_num, pomo, qkv_dim)
+            out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
+            # shape: (batch, pomo, head_num*qkv_dim)
+            mh_atten_out = self.multi_head_combine(out_concat)
+            # shape: (batch, pomo, embedding)
+            # Single-Head Attention, for probability calculation
+            #######################################################
+            score = torch.matmul(mh_atten_out, self.single_head_key)
+            # shape: (batch, pomo, problem)
+            sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
+            logit_clipping = self.model_params['logit_clipping']
+            score_scaled = score / sqrt_embedding_dim
+            # shape: (batch, pomo, problem)
+            score_clipped = logit_clipping * torch.tanh(score_scaled)
+            score_masked = score_clipped + ninf_mask
+            probs = F.softmax(score_masked, dim=2)
+            # shape: (batch, pomo, problem)
+        else:
+            head_num = self.model_params['head_num']
+            q_last = reshape_by_heads(F.linear(encoded_last_node, weights['decoder.Wq_last.weight'], bias=None), head_num=head_num)
+            q = self.q_first + q_last
+            out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
+            mh_atten_out = F.linear(out_concat, weights['decoder.multi_head_combine.weight'], weights['decoder.multi_head_combine.bias'])
+            score = torch.matmul(mh_atten_out, self.single_head_key)
+            sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
+            logit_clipping = self.model_params['logit_clipping']
+            score_scaled = score / sqrt_embedding_dim
+            score_clipped = logit_clipping * torch.tanh(score_scaled)
+            score_masked = score_clipped + ninf_mask
+            probs = F.softmax(score_masked, dim=2)

         return probs
@@ -283,20 +306,22 @@ def __init__(self, **model_params):
         embedding_dim = model_params['embedding_dim']
         self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

-    def forward(self, input1, input2):
-        # input.shape: (batch, problem, embedding)
-
-        added = input1 + input2
-        # shape: (batch, problem, embedding)
-
-        transposed = added.transpose(1, 2)
-        # shape: (batch, embedding, problem)
-
-        normalized = self.norm(transposed)
-        # shape: (batch, embedding, problem)
-
-        back_trans = normalized.transpose(1, 2)
-        # shape: (batch, problem, embedding)
+    def forward(self, input1, input2, weights=None):
+        if weights is None:
+            # input.shape: (batch, problem, embedding)
+            added = input1 + input2
+            # shape: (batch, problem, embedding)
+            transposed = added.transpose(1, 2)
+            # shape: (batch, embedding, problem)
+            normalized = self.norm(transposed)
+            # shape: (batch, embedding, problem)
+            back_trans = normalized.transpose(1, 2)
+            # shape: (batch, problem, embedding)
+        else:
+            added = input1 + input2
+            transposed = added.transpose(1, 2)
+            normalized = F.instance_norm(transposed, weight=weights['weight'], bias=weights['bias'])
+            back_trans = normalized.transpose(1, 2)

         return back_trans
@@ -310,7 +335,10 @@ def __init__(self, **model_params):
         self.W1 = nn.Linear(embedding_dim, ff_hidden_dim)
         self.W2 = nn.Linear(ff_hidden_dim, embedding_dim)

-    def forward(self, input1):
-        # input.shape: (batch, problem, embedding)
-
-        return self.W2(F.relu(self.W1(input1)))
+    def forward(self, input1, weights=None):
+        if weights is None:
+            # input.shape: (batch, problem, embedding)
+            return self.W2(F.relu(self.W1(input1)))
+        else:
+            output = F.relu(F.linear(input1, weights['weight1'], bias=weights['bias1']))
+            return F.linear(output, weights['weight2'], bias=weights['bias2'])
diff --git a/POMO/TSP/TSPTrainer_Meta.py b/POMO/TSP/TSPTrainer_Meta.py
index 70ed98d..818ff20 100644
--- a/POMO/TSP/TSPTrainer_Meta.py
+++ b/POMO/TSP/TSPTrainer_Meta.py
@@ -19,9 +19,12 @@ class TSPTrainer:
     """
+    TODO: 1. Which val data should we use? And should the k inner-loop steps reuse the same batch of training data?
+          2. Meta-update only a subset of the POMO parameters?
     Implementation of POMO with MAML / FOMAML / Reptile.
     For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks";
     For Reptile, ref to "On First-Order Meta-Learning Algorithms".
+    Refer to "https://lilianweng.github.io/posts/2018-11-30-meta-learning"
     """
     def __init__(self,
                  env_params,
@@ -54,6 +57,7 @@ def __init__(self,

         # Main Components
         self.meta_model = Model(**self.model_params)
+        self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer'])
         self.alpha = self.meta_params['alpha']  # for reptile
         self.task_set = generate_task_set(self.meta_params)
         # assert self.trainer_params['meta_params']['epochs'] == math.ceil((self.trainer_params['epochs'] * self.trainer_params['train_episodes']) / (
@@ -88,7 +92,10 @@ def run(self):
             self.result_log.append('train_score', epoch, train_score)
             self.result_log.append('train_loss', epoch, train_loss)
             # Val
-            _, no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=64)
+            if self.meta_params['meta_method'] in ['fomaml', 'reptile']:
+                no_aug_score = self._fast_val(copy.deepcopy(self.meta_model), val_episodes=32, mode="eval")
+            else:
+                no_aug_score = self._fast_val(self.meta_model, val_episodes=32, mode="eval")
             self.result_log.append('val_score', epoch, no_aug_score)

             # Logs & Checkpoint
@@ -140,45 +147,42 @@ def _train_one_epoch(self, epoch):
        2. inner-loop: for a batch of tasks T_i, do reptile -> \theta_i
        3. outer-loop: update meta-model -> \theta_0
        """
+        self.meta_model.train()
         score_AM = AverageMeter()
         loss_AM = AverageMeter()
         batch_size = self.meta_params['meta_batch_size']
         self._alpha_scheduler(epoch)
-        slow_weights = copy.deepcopy(self.meta_model.state_dict())
         fast_weights, val_loss, fomaml_grad = [], 0, []

         # sample a batch of tasks
         for i in range(self.meta_params['B']):
             task_params = random.sample(self.task_set, 1)[0]
-            task_model = copy.deepcopy(self.meta_model)
+            if self.meta_params['meta_method'] in ['fomaml', 'reptile']:
+                task_model = copy.deepcopy(self.meta_model)
+                optimizer = Optimizer(task_model.parameters(), **self.optimizer_params['optimizer'])
+            elif self.meta_params['meta_method'] == 'maml':
+                fast_weight = OrderedDict(self.meta_model.named_parameters())

             for step in range(self.meta_params['k'] + 1):
                 # generate task-specific data
-                if self.meta_params['data_type'] == 'distribution':
-                    assert len(task_params) == 2
-                    data = get_random_problems(batch_size, self.env_params['problem_size'], num_modes=task_params[0], cdist=task_params[-1], distribution='gaussian_mixture')
-                elif self.meta_params['data_type'] == 'size':
-                    assert len(task_params) == 1
-                    data = get_random_problems(batch_size, task_params[0], num_modes=0, cdist=0, distribution='uniform')
-                elif self.meta_params['data_type'] == "size_distribution":
-                    assert len(task_params) == 3
-                    data = get_random_problems(batch_size, problem_size=task_params[0], num_modes=task_params[1], cdist=task_params[-1], distribution='gaussian_mixture')
-                else:
-                    raise NotImplementedError
-
+                data = self._get_data(batch_size, task_params)
                 if step == self.meta_params['k']: continue
                 env_params = {'problem_size': data.size(1), 'pomo_size': data.size(1)}
-                avg_score, avg_loss = self._train_one_batch(step, task_model, data, Env(**env_params))
+
+                if self.meta_params['meta_method'] in ['reptile', 'fomaml']:
+                    avg_score, avg_loss = self._train_one_batch(task_model, data, Env(**env_params), optimizer)
+                elif self.meta_params['meta_method'] == 'maml':
+                    avg_score, avg_loss, fast_weight = self._train_one_batch_maml(fast_weight, data, Env(**env_params))
+
                 score_AM.update(avg_score.item(), batch_size)
                 loss_AM.update(avg_loss.item(), batch_size)

             if self.meta_params['meta_method'] == 'maml':
                 # cal loss on query(val) set - data
-                val_loss += self._fast_val(task_model, data=data)[0]
+                val_loss += self._fast_val(fast_weight, data=data, mode="maml")
             elif self.meta_params['meta_method'] == 'fomaml':
-                val_loss = self._fast_val(task_model, data=data)[0]
-                task_model.train()
+                val_loss = self._fast_val(task_model, data=data, mode="fomaml")
                 grad = torch.autograd.grad(val_loss, task_model.parameters())
                 fomaml_grad.append(grad)
             elif self.meta_params['meta_method'] == 'reptile':
@@ -186,13 +190,10 @@ def _train_one_epoch(self, epoch):

         # update meta-model
         if self.meta_params['meta_method'] == 'maml':
-            val_loss = val_loss / self.meta_params['B']
-            gradients = torch.autograd.grad(val_loss, self.maml)
-            updated_weights = OrderedDict(
-                (name, param - self.optimizer_params['optimizer']['lr'] * grad)
-                for ((name, param), grad) in zip(self.meta_model.state_dict().items(), gradients)
-            )
-            self.meta_model.load_state_dict(updated_weights)
+            val_loss /= self.meta_params['B']
+            self.meta_optimizer.zero_grad()
+            val_loss.backward()
+            self.meta_optimizer.step()
         elif self.meta_params['meta_method'] == 'fomaml':
             updated_weights = self.meta_model.state_dict()
             for gradients in fomaml_grad:
@@ -202,7 +203,7 @@
             )
             self.meta_model.load_state_dict(updated_weights)
         elif self.meta_params['meta_method'] == 'reptile':
-            state_dict = {params_key: (slow_weights[params_key] + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - slow_weights[params_key] for fast_weight in fast_weights], dim=0), dim=0)) for params_key in slow_weights}
+            state_dict = {params_key: (self.meta_model.state_dict()[params_key] + self.alpha * torch.mean(torch.stack([fast_weight[params_key] - self.meta_model.state_dict()[params_key] for fast_weight in fast_weights], dim=0), dim=0)) for params_key in self.meta_model.state_dict()}
             self.meta_model.load_state_dict(state_dict)

         # Log Once, for each epoch
@@ -210,7 +211,7 @@

         return score_AM.avg, loss_AM.avg

-    def _train_one_batch(self, i, task_model, data, env):
+    def _train_one_batch(self, task_model, data, env, optimizer):
         task_model.train()

         batch_size = data.size(0)
@@ -238,51 +239,101 @@ def _train_one_batch(self, i, task_model, data, env):
         loss_mean = loss.mean()

         # update model
-        create_graph = True if self.meta_params['meta_method'] == 'maml' else False
-        if i == 0:
-            self.maml = list(task_model.parameters())
-            gradients = torch.autograd.grad(loss_mean, self.maml, create_graph=create_graph)
-        else:
-            gradients = torch.autograd.grad(loss_mean, task_model.parameters(), create_graph=create_graph)
-        fast_weights = OrderedDict(
+        optimizer.zero_grad()
+        loss_mean.backward()
+        optimizer.step()
+
+        # Score
+        max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
+        score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value
+        print(score_mean)
+
+        return score_mean, loss_mean
+
+    def _train_one_batch_maml(self, fast_weight, data, env):
+
+        batch_size = data.size(0)
+        env.load_problems(batch_size, problems=data, aug_factor=1)
+        reset_state, _, _ = env.reset()
+        self.meta_model.pre_forward(reset_state, weights=fast_weight)
+        prob_list = torch.zeros(size=(batch_size, env.pomo_size, 0))
+        # shape: (batch, pomo, 0~problem)
+
+        # POMO Rollout, please note that the reward is negative (i.e., -length of route).
+        state, reward, done = env.pre_step()
+        while not done:
+            selected, prob = self.meta_model(state, weights=fast_weight)
+            # shape: (batch, pomo)
+            state, reward, done = env.step(selected)
+            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)
+
+        # Loss
+        advantage = reward - reward.float().mean(dim=1, keepdims=True)
+        # shape: (batch, pomo)
+        log_prob = prob_list.log().sum(dim=2)  # for the first/last node, p=1 -> log_p=0
+        # size = (batch, pomo)
+        loss = -advantage * log_prob  # Minus Sign: To Increase REWARD
+        # shape: (batch, pomo)
+        loss_mean = loss.mean()
+
+        # update model
+        gradients = torch.autograd.grad(loss_mean, fast_weight.values(), create_graph=True)
+        fast_weight = OrderedDict(
             (name, param - self.optimizer_params['optimizer']['lr'] * grad)
-            for ((name, param), grad) in zip(task_model.state_dict().items(), gradients)
+            for ((name, param), grad) in zip(fast_weight.items(), gradients)
         )
-        task_model.load_state_dict(fast_weights)

         # Score
         max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
         score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value
+        print(score_mean)

-        return score_mean, loss_mean
+        return score_mean, loss_mean, fast_weight

-    def _fast_val(self, model, data=None, val_episodes=64):
+    def _fast_val(self, model, data=None, val_episodes=32, mode="eval"):
         aug_factor = 1
         if data is None:
-            val_path = "../../data/TSP/tsp50_tsplib.pkl"
+            val_path = "../../data/TSP/tsp150_uniform.pkl"
             data = torch.Tensor(load_dataset(val_path)[: val_episodes])

         env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
-        model.eval()

         batch_size = data.size(0)
-        with torch.enable_grad():
+        if mode == "eval":
+            model.eval()
+            with torch.no_grad():
+                env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
+                reset_state, _, _ = env.reset()
+                model.pre_forward(reset_state)
+                state, reward, done = env.pre_step()
+                while not done:
+                    selected, _ = model(state)
+                    # shape: (batch, pomo)
+                    state, reward, done = env.step(selected)
+        elif mode in ["maml", "fomaml"]:
+            fast_weight = model
             env.load_problems(batch_size, problems=data, aug_factor=aug_factor)
             reset_state, _, _ = env.reset()
-            model.pre_forward(reset_state)
+            if mode == "maml":
+                self.meta_model.pre_forward(reset_state, weights=fast_weight)
+            else:
+                model.pre_forward(reset_state)
             prob_list = torch.zeros(size=(batch_size, env.pomo_size, 0))
-
             state, reward, done = env.pre_step()
             while not done:
-                selected, prob = model(state)
+                if mode == "maml":
+                    selected, prob = self.meta_model(state, weights=fast_weight)
+                else:
+                    selected, prob = model(state)
                 # shape: (batch, pomo)
                 state, reward, done = env.step(selected)
                 prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)
-
-            # Loss
-            advantage = reward - reward.float().mean(dim=1, keepdims=True)
-            log_prob = prob_list.log().sum(dim=2)  # for the first/last node, p=1 -> log_p=0
-            loss = -advantage * log_prob  # Minus Sign: To Increase REWARD
-            loss_mean = loss.mean()
+            # Loss
+            advantage = reward - reward.float().mean(dim=1, keepdims=True)
+            log_prob = prob_list.log().sum(dim=2)  # for the first/last node, p=1 -> log_p=0
+            loss = -advantage * log_prob  # Minus Sign: To Increase REWARD
+            loss_mean = loss.mean()
+        else:
+            raise NotImplementedError

         # Return
         aug_reward = reward.reshape(aug_factor, batch_size, env.pomo_size)
@@ -291,7 +342,25 @@
         # shape: (augmentation, batch)
         no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

-        return loss_mean, no_aug_score.detach().item()
+        if mode == "eval":
+            return no_aug_score.detach().item()
+        else:
+            return loss_mean
+
+    def _get_data(self, batch_size, task_params):
+        if self.meta_params['data_type'] == 'distribution':
+            assert len(task_params) == 2
+            data = get_random_problems(batch_size, self.env_params['problem_size'], num_modes=task_params[0], cdist=task_params[-1], distribution='gaussian_mixture')
+        elif self.meta_params['data_type'] == 'size':
+            assert len(task_params) == 1
+            data = get_random_problems(batch_size, task_params[0], num_modes=0, cdist=0, distribution='uniform')
+        elif self.meta_params['data_type'] == "size_distribution":
+            assert len(task_params) == 3
+            data = get_random_problems(batch_size, problem_size=task_params[0], num_modes=task_params[1], cdist=task_params[-1], distribution='gaussian_mixture')
+        else:
+            raise NotImplementedError
+
+        return data

     def _alpha_scheduler(self, iter):
         self.alpha = max(self.alpha * self.meta_params['alpha_decay'], 0.0001)
diff --git a/POMO/TSP/TSPTrainer_pomo.py b/POMO/TSP/TSPTrainer_pomo.py
index 9103fe8..ca818d3 100644
--- a/POMO/TSP/TSPTrainer_pomo.py
+++ b/POMO/TSP/TSPTrainer_pomo.py
@@ -86,7 +86,7 @@ def run(self):
             self.result_log.append('train_score', epoch, train_score)
             self.result_log.append('train_loss', epoch, train_loss)
             # Val
-            no_aug_score = self._fast_val(self.meta_model, val_episodes=64)
+            no_aug_score = self._fast_val(self.meta_model, val_episodes=32)
             self.result_log.append('val_score', epoch, no_aug_score)

             # Logs & Checkpoint
@@ -138,6 +138,7 @@ def _train_one_epoch(self, epoch):
        2. inner-loop: for a batch of tasks T_i, do reptile -> \theta_i
        3. outer-loop: update meta-model -> \theta_0
        """
+        self.meta_model.train()
         score_AM = AverageMeter()
         loss_AM = AverageMeter()
         batch_size = self.meta_params['meta_batch_size']
@@ -172,7 +173,6 @@ def _train_one_epoch(self, epoch):

     def _train_one_batch(self, data, env):

-        self.meta_model.train()
         batch_size = data.size(0)
         env.load_problems(batch_size, problems=data, aug_factor=1)
         reset_state, _, _ = env.reset()
@@ -208,10 +208,10 @@ def _train_one_batch(self, data, env):

         return score_mean, loss_mean

-    def _fast_val(self, model, data=None, val_episodes=64):
+    def _fast_val(self, model, data=None, val_episodes=32):
         aug_factor = 1
         if data is None:
-            val_path = "../../data/TSP/tsp50_tsplib.pkl"
+            val_path = "../../data/TSP/tsp150_uniform.pkl"
             data = torch.Tensor(load_dataset(val_path)[: val_episodes])

         env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)})
diff --git a/POMO/TSP/TSProblemDef.py b/POMO/TSP/TSProblemDef.py
index 3539b47..5e00b68 100644
--- a/POMO/TSP/TSProblemDef.py
+++ b/POMO/TSP/TSProblemDef.py
@@ -12,7 +12,7 @@ def generate_task_set(meta_params):
     if meta_params['data_type'] == "distribution":  # focus on the TSP100 with different distributions
         task_set = [(m, l) for l in [1, 10, 20, 30, 50] for m in range(1, 1 + meta_params['num_task'] // 5)] + [(0, 0)]
     elif meta_params['data_type'] == "size":  # focus on uniform distribution with different sizes
-        task_set = [(n,) for n in range(5, 5 + 5 * meta_params['num_task'], 5)]
+        task_set = [(n,) for n in range(10, 10 + 10 * meta_params['num_task'], 10)]
     elif meta_params['data_type'] == "size_distribution":
         task_set = [(m, l) for l in [1, 10, 20, 30, 50] for m in range(1, 11)] + [(0, 0)]
         task_set = [(n, m, l) for n in [25, 50, 75, 100, 125, 150] for (m, l) in task_set]
diff --git a/POMO/TSP/train_n100.py b/POMO/TSP/train_n100.py
index 27c6749..ec7aa5a 100644
--- a/POMO/TSP/train_n100.py
+++ b/POMO/TSP/train_n100.py
@@ -8,11 +8,10 @@
 from TSPTrainer import TSPTrainer as Trainer
 from TSPTrainer_pomo import TSPTrainer as Trainer_Pomo
 from TSPTrainer_Meta import TSPTrainer as Trainer_Meta
-from TSPTrainer_Scheduler import TSPTrainer as Trainer_Scheduler

 DEBUG_MODE = False
 USE_CUDA = not DEBUG_MODE and torch.cuda.is_available()
-CUDA_DEVICE_NUM = 0  # $ nohup python -u train_n100.py 2>&1 &, no need to use CUDA_VISIBLE_DEVICES=0
+CUDA_DEVICE_NUM = 1  # $ nohup python -u train_n100.py 2>&1 &, no need to use CUDA_VISIBLE_DEVICES=0

 ##########################################################################################
 # parameters
@@ -50,7 +49,7 @@
     'seed': 1234,
     'epochs': 500,
     'time_limit': 86400,
-    'stop_criterion': 'epochs',  # epochs or time
+    'stop_criterion': 'time',  # epochs or time
     'train_episodes': 100000,  # number of instances per epoch
     'train_batch_size': 64,
     'logging': {
@@ -71,15 +70,17 @@
         # 'epoch': 510,  # epoch version of pre-trained model to laod.
     },

+    # For fomaml, k needs to be small (1 or 2), but the performance is still inferior.
+    # For reptile, performance is quite good; however, after several iterations, the improvement in the inner loop is trivial.
     'meta_params': {
         'enable': True,  # whether use meta-learning or not
-        'meta_method': 'reptile',  # choose from ['maml', 'fomaml', 'reptile', 'ours']
+        'meta_method': 'maml',  # choose from ['maml', 'fomaml', 'reptile', 'ours']
         'data_type': 'size',  # choose from ["size", "distribution", "size_distribution"]
         'epochs': 52084,  # the number of meta-model updates: (500*100000) / (3*50*64)
-        'B': 3,  # the number of tasks in a mini-batch
-        'k': 50,  # gradient decent steps in the inner-loop optimization of meta-learning method
+        'B': 1,  # the number of tasks in a mini-batch
+        'k': 3,  # gradient descent steps in the inner-loop optimization of meta-learning method
         'meta_batch_size': 64,  # the batch size of the inner-loop optimization
-        'num_task': 20,  # the number of tasks in the training task set
+        'num_task': 10,  # the number of tasks in the training task set
         'alpha': 0.99,  # params for the outer-loop optimization of reptile
         'alpha_decay': 0.999,  # params for the outer-loop optimization of reptile
     }
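
Note: the weights=None branches added to TSPModel.py implement a "functional" forward pass. When an OrderedDict of parameters is supplied, each nn.Linear / nn.InstanceNorm1d call is replaced by its torch.nn.functional counterpart applied to those external tensors, so MAML's inner-loop updates stay differentiable through to the meta-parameters. The following is a minimal, self-contained sketch of the same pattern; TinyNet, its layer sizes, and the learning rate are hypothetical stand-ins, not the POMO model.

# Sketch only: hypothetical TinyNet illustrating the functional-forward pattern
# used in TSPModel.py (weights=None -> module path, else F.linear against the dict).
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 16)
        self.fc2 = nn.Linear(16, 1)

    def forward(self, x, weights=None):
        if weights is None:
            # normal path: use the module's own registered parameters
            return self.fc2(F.relu(self.fc1(x)))
        # functional path: same computation, but against an external parameter dict,
        # so gradients can flow back through inner-loop weight updates
        h = F.relu(F.linear(x, weights['fc1.weight'], weights['fc1.bias']))
        return F.linear(h, weights['fc2.weight'], weights['fc2.bias'])


net = TinyNet()
x, y = torch.randn(8, 2), torch.randn(8, 1)

# one differentiable inner-loop step, mirroring _train_one_batch_maml
fast_weight = OrderedDict(net.named_parameters())
inner_loss = F.mse_loss(net(x, weights=fast_weight), y)
grads = torch.autograd.grad(inner_loss, fast_weight.values(), create_graph=True)
fast_weight = OrderedDict(
    (name, param - 1e-2 * grad)
    for (name, param), grad in zip(fast_weight.items(), grads)
)

# outer (query) loss through the adapted weights; backward() reaches the original
# parameters because create_graph=True kept the inner update differentiable
outer_loss = F.mse_loss(net(x, weights=fast_weight), y)
outer_loss.backward()
print(net.fc1.weight.grad is not None)  # True: the second-order path is connected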
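
Note: the Reptile branch of _train_one_epoch performs the standard interpolation theta <- theta + alpha * mean_i(theta_i - theta) over the B adapted task models, written inline over state_dicts. A compact equivalent is sketched below; the helper name reptile_update is hypothetical, and it also avoids re-calling state_dict() inside the comprehension.

# Sketch only: the Reptile outer-loop update as a standalone helper.
import torch

def reptile_update(meta_state, fast_states, alpha):
    # meta_state: the slow (meta) weights; fast_states: one adapted state_dict per task
    return {
        key: meta_state[key]
        + alpha * torch.mean(torch.stack([fs[key] - meta_state[key] for fs in fast_states], dim=0), dim=0)
        for key in meta_state
    }

# usage (hypothetical): self.meta_model.load_state_dict(
#     reptile_update(self.meta_model.state_dict(), fast_weights, self.alpha))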