Skip to content

Commit

Permalink
Merge pull request #226 from abcdhhhh/floyd-atsp
Browse files Browse the repository at this point in the history
implement floyd on tmat_class atsp generation
  • Loading branch information
fedebotu authored Oct 23, 2024
2 parents f979492 + c024697 commit adac2f6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
11 changes: 11 additions & 0 deletions rl4co/data/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"op": ["const", "unif", "dist"],
"mdpp": [None],
"pdp": [None],
"atsp": [None]
}


Expand Down Expand Up @@ -212,6 +213,16 @@ def generate_mdpp_data(
"action_mask": available.astype(bool),
}

def generate_atsp_data(dataset_size, atsp_size, tmat_class: bool = True):
cost_matrix = np.random.uniform(size=(dataset_size, atsp_size, atsp_size))
cost_matrix[..., np.arange(atsp_size), np.arange(atsp_size)] = 0
if tmat_class:
for i in range(atsp_size):
cost_matrix = np.minimum(cost_matrix, cost_matrix[..., :, [i]] + cost_matrix[..., [i], :])
return {
"cost_matrix": cost_matrix.astype(np.float32)
}


def generate_dataset(
filename: Union[str, List[str]] = None,
Expand Down
9 changes: 2 additions & 7 deletions rl4co/envs/routing/atsp/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ def _generate(self, batch_size) -> TensorDict:
dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0
log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class))
if self.tmat_class:
while True:
old_dms = dms.clone()
dms, _ = (
dms[..., :, None, :] + dms[..., None, :, :].transpose(-2, -1)
).min(dim=-1)
if (dms == old_dms).all():
break
for i in range(self.num_loc):
dms = torch.minimum(dms, dms[..., :, [i]] + dms[..., [i], :])
return TensorDict({"cost_matrix": dms}, batch_size=batch_size)

0 comments on commit adac2f6

Please sign in to comment.