diff --git a/rl4co/data/generate_data.py b/rl4co/data/generate_data.py index a2122369..5d837ba1 100644 --- a/rl4co/data/generate_data.py +++ b/rl4co/data/generate_data.py @@ -20,6 +20,7 @@ "op": ["const", "unif", "dist"], "mdpp": [None], "pdp": [None], + "atsp": [None] } @@ -212,6 +213,16 @@ def generate_mdpp_data( "action_mask": available.astype(bool), } +def generate_atsp_data(dataset_size, atsp_size, tmat_class: bool = True): + cost_matrix = np.random.uniform(size=(dataset_size, atsp_size, atsp_size)) + cost_matrix[..., np.arange(atsp_size), np.arange(atsp_size)] = 0 + if tmat_class: + for i in range(atsp_size): + cost_matrix = np.minimum(cost_matrix, cost_matrix[..., :, [i]] + cost_matrix[..., [i], :]) + return { + "cost_matrix": cost_matrix.astype(np.float32) + } + def generate_dataset( filename: Union[str, List[str]] = None, diff --git a/rl4co/envs/routing/atsp/generator.py b/rl4co/envs/routing/atsp/generator.py index 89e381ca..31208005 100644 --- a/rl4co/envs/routing/atsp/generator.py +++ b/rl4co/envs/routing/atsp/generator.py @@ -61,11 +61,6 @@ def _generate(self, batch_size) -> TensorDict: dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0 log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class)) if self.tmat_class: - while True: - old_dms = dms.clone() - dms, _ = ( - dms[..., :, None, :] + dms[..., None, :, :].transpose(-2, -1) - ).min(dim=-1) - if (dms == old_dms).all(): - break + for i in range(self.num_loc): + dms = torch.minimum(dms, dms[..., :, [i]] + dms[..., [i], :]) return TensorDict({"cost_matrix": dms}, batch_size=batch_size)