Skip to content

Commit

Permalink
update tsp_distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
RoyalSkye committed Nov 2, 2022
1 parent 1f4821c commit fdb346e
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 118 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ __pycache__/

# data & pretrain-model
backup/
AM/
data/
pretrained/

Expand Down
4 changes: 2 additions & 2 deletions POMO/TSP/TSPTester.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def _fine_tune_one_batch(self, fine_tune_data):
prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

# Loss
aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size).permute(1, 0, 2).view(batch_size, -1)
aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size).permute(1, 0, 2).reshape(batch_size, -1)
# shape: (batch, augmentation * pomo)
advantage = aug_reward - aug_reward.float().mean(dim=1, keepdims=True)
# shape: (batch, augmentation * pomo)
log_prob = prob_list.log().sum(dim=2).reshape(aug_factor, batch_size, self.env.pomo_size).permute(1, 0, 2).view(batch_size, -1)
log_prob = prob_list.log().sum(dim=2).reshape(aug_factor, batch_size, self.env.pomo_size).permute(1, 0, 2).reshape(batch_size, -1)
# size = (batch, augmentation * pomo)
loss = -advantage * log_prob # Minus Sign: To Increase REWARD
# shape: (batch, augmentation * pomo)
Expand Down
Loading

0 comments on commit fdb346e

Please sign in to comment.