diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ae63296
--- /dev/null
+++ b/README.md
@@ -0,0 +1,87 @@
+# 2D-Ptr
+Source code for the paper "2D-Ptr: 2D Array Pointer Network for Solving the Heterogeneous Capacitated Vehicle Routing Problem"
+
+## Dependencies
+
+- Python>=3.8
+- NumPy
+- SciPy
+- [PyTorch](http://pytorch.org/)>=1.12.1
+- tqdm
+- [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger)
+
+## Quick start
+
+The implementation of the 2D-Ptr model is mainly in the file `./nets/attention_model.py`.
+
+To test HCVRP instances with 60 customers and 5 vehicles (V5-U60) using the pre-trained model:
+
+```shell
+# greedy
+python eval.py data/hcvrp/hcvrp_v5_60_seed24610.pkl --model outputs/hcvrp_v5_60 --obj min-max --decode_strategy greedy --eval_batch_size 1
+# sample1280
+python eval.py data/hcvrp/hcvrp_v5_60_seed24610.pkl --model outputs/hcvrp_v5_60 --obj min-max --decode_strategy sample --width 1280 --eval_batch_size 1
+# sample12800
+python eval.py data/hcvrp/hcvrp_v5_60_seed24610.pkl --model outputs/hcvrp_v5_60 --obj min-max --decode_strategy sample --width 12800 --eval_batch_size 1
+```
+
+Since AAMAS limits the submission file size to 25 MB, we can only provide the pre-trained model for V5-U60.
+
+## Usage
+
+### Generating data
+
+We provide all pre-generated test datasets in `./data`, and you can also generate each test set yourself:
+
+```shell
+python generate_data.py --dataset_size 1280 --veh_num 3 --graph_size 40
+```
+
+- `--graph_size`, `--veh_num`, and `--dataset_size` specify the number of customers, vehicles, and generated instances, respectively.
+
+- The default random seed is 24610, and you can change it in `./generate_data.py`.
+- The test set will be stored in `./data/hcvrp/`.
+
+### Training
+
+For training HCVRP instances with 40 customers and 3 vehicles (V3-U40):
+
+```shell
+python run.py --graph_size 40 --veh_num 3 --baseline rollout --run_name hcvrp_v3_40_rollout --obj min-max
+```
+
+- `--run_name` is automatically suffixed with a timestamp, which serves as the unique subpath for logs and checkpoints.
+- TensorBoard logs are stored in `./log/`, and checkpoints (i.e., the trained models) are stored in `./outputs/`.
+- `--obj` selects the objective function: `min-max` or `min-sum`.
+
+By default, training runs on all available GPUs. Edit the code in `./run.py` to use only specific GPUs:
+
+```python
+if __name__ == "__main__":
+    warnings.filterwarnings('ignore')
+    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    run(get_options())
+```
+
+### Evaluation
+
+You can test a trained model on HCVRP instances of any problem size:
+
+```shell
+# greedy
+python eval.py data/hcvrp/hcvrp_v3_40_seed24610.pkl --model outputs/hcvrp_v3_40 --obj min-max --decode_strategy greedy --eval_batch_size 1
+# sample1280
+python eval.py data/hcvrp/hcvrp_v3_40_seed24610.pkl --model outputs/hcvrp_v3_40 --obj min-max --decode_strategy sample --width 1280 --eval_batch_size 1
+# sample12800
+python eval.py data/hcvrp/hcvrp_v3_40_seed24610.pkl --model outputs/hcvrp_v3_40 --obj min-max --decode_strategy sample --width 12800 --eval_batch_size 1
+```
+
+- `--model` is the directory containing the trained model.
+- The positional argument `$filename$.pkl` is the test set to evaluate.
+- `--width` is the number of samples, used only when `--decode_strategy` is `sample`.
+- `--eval_batch_size` is set to 1 for serial evaluation.
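+
+### Inspecting a generated dataset
+
+The following is a minimal sketch of how to load and inspect a generated test set. It assumes the `.pkl` file was produced by `generate_data.py`, i.e. a pickled dict of NumPy arrays with keys `depot`, `loc`, `demand`, `capacity`, and `speed`; if `utils.data_utils.save_dataset` stores the data differently in your copy, adjust the loading accordingly.
+
+```python
+import pickle
+
+# Load a test set written by generate_data.py (assumed to be a pickled dict of NumPy arrays).
+with open('data/hcvrp/hcvrp_v3_40_seed24610.pkl', 'rb') as f:
+    data = pickle.load(f)
+
+print(list(data.keys()))    # ['depot', 'loc', 'demand', 'capacity', 'speed']
+print(data['loc'].shape)    # (dataset_size, graph_size, 2) -> customer coordinates
+print(data['capacity'][0])  # heterogeneous capacities of the vehicles in instance 0
+```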
+ + + + + diff --git a/data/hcvrp/hcvrp_v3_100_seed24610.pkl b/data/hcvrp/hcvrp_v3_100_seed24610.pkl new file mode 100644 index 0000000..bba3745 Binary files /dev/null and b/data/hcvrp/hcvrp_v3_100_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v3_40_seed24610.pkl b/data/hcvrp/hcvrp_v3_40_seed24610.pkl new file mode 100644 index 0000000..ea959e5 Binary files /dev/null and b/data/hcvrp/hcvrp_v3_40_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v3_60_seed24610.pkl b/data/hcvrp/hcvrp_v3_60_seed24610.pkl new file mode 100644 index 0000000..784bdda Binary files /dev/null and b/data/hcvrp/hcvrp_v3_60_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v3_80_seed24610.pkl b/data/hcvrp/hcvrp_v3_80_seed24610.pkl new file mode 100644 index 0000000..a13bdbe Binary files /dev/null and b/data/hcvrp/hcvrp_v3_80_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v5_100_seed24610.pkl b/data/hcvrp/hcvrp_v5_100_seed24610.pkl new file mode 100644 index 0000000..f7b002e Binary files /dev/null and b/data/hcvrp/hcvrp_v5_100_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v5_40_seed24610.pkl b/data/hcvrp/hcvrp_v5_40_seed24610.pkl new file mode 100644 index 0000000..c919918 Binary files /dev/null and b/data/hcvrp/hcvrp_v5_40_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v5_60_seed24610.pkl b/data/hcvrp/hcvrp_v5_60_seed24610.pkl new file mode 100644 index 0000000..985060b Binary files /dev/null and b/data/hcvrp/hcvrp_v5_60_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v5_80_seed24610.pkl b/data/hcvrp/hcvrp_v5_80_seed24610.pkl new file mode 100644 index 0000000..b5f91c8 Binary files /dev/null and b/data/hcvrp/hcvrp_v5_80_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v7_100_seed24610.pkl b/data/hcvrp/hcvrp_v7_100_seed24610.pkl new file mode 100644 index 0000000..fb9f3fd Binary files /dev/null and b/data/hcvrp/hcvrp_v7_100_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v7_40_seed24610.pkl b/data/hcvrp/hcvrp_v7_40_seed24610.pkl new file mode 100644 index 0000000..fe9793a Binary files /dev/null and b/data/hcvrp/hcvrp_v7_40_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v7_60_seed24610.pkl b/data/hcvrp/hcvrp_v7_60_seed24610.pkl new file mode 100644 index 0000000..9cd91c2 Binary files /dev/null and b/data/hcvrp/hcvrp_v7_60_seed24610.pkl differ diff --git a/data/hcvrp/hcvrp_v7_80_seed24610.pkl b/data/hcvrp/hcvrp_v7_80_seed24610.pkl new file mode 100644 index 0000000..546ce4c Binary files /dev/null and b/data/hcvrp/hcvrp_v7_80_seed24610.pkl differ diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..a506fb8 --- /dev/null +++ b/eval.py @@ -0,0 +1,225 @@ +# used after model is completely trained, and test for results + +import math +import torch +import os +import argparse +import numpy as np +import itertools +from tqdm import tqdm +from utils import load_model, move_to +from utils.data_utils import save_dataset +from torch.utils.data import DataLoader +import time +from datetime import timedelta +from utils.functions import parse_softmax_temperature +import warnings + +mp = torch.multiprocessing.get_context('spawn') + + +def get_best(sequences, cost, veh_lists, ids=None, batch_size=None): + """ + Ids contains [0, 0, 0, 1, 1, 2, ..., n, n, n] if 3 solutions found for 0th instance, 2 for 1st, etc + :param sequences: + :param lengths: + :param ids: + :return: list with n sequences and list with n lengths of solutions + """ + if ids is None: + idx = cost.argmin() + return sequences[idx:idx + 1, ...], cost[idx:idx + 1, ...], veh_lists[idx:idx + 1, ...] 
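+    # Illustration of the grouped argmin computed below: with cost = [3., 1., 5., 2.]
+    # and ids = [0, 0, 1, 1], splits becomes [0, 2], mincosts becomes [1., 2.], and
+    # the selected indices are 1 and 3, i.e. the cheapest solution found per instance.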
+ + splits = np.hstack([0, np.where(ids[:-1] != ids[1:])[0] + 1]) + mincosts = np.minimum.reduceat(cost, splits) + + group_lengths = np.diff(np.hstack([splits, len(ids)])) + all_argmin = np.flatnonzero(np.repeat(mincosts, group_lengths) == cost) + result = np.full(len(group_lengths) if batch_size is None else batch_size, -1, dtype=int) + + result[ids[all_argmin[::-1]]] = all_argmin[::-1] + + return [sequences[i] if i >= 0 else None for i in result], [cost[i] if i >= 0 else math.inf for i in result], [ + veh_lists[i] if i >= 0 else None for i in result] + + +def eval_dataset_mp(args): + (dataset_path, width, softmax_temp, opts, i, num_processes) = args + + model, _ = load_model(opts.model, opts.obj) + val_size = opts.val_size // num_processes + dataset = model.problem.make_dataset(filename=dataset_path, num_samples=val_size, offset=opts.offset + val_size * i) + device = torch.device("cuda:{}".format(i)) + + return _eval_dataset(model, dataset, width, softmax_temp, opts, device) + + +def eval_dataset(dataset_path, width, softmax_temp, opts): + # Even with multiprocessing, we load the model here since it contains the name where to write results + model, _ = load_model(opts.model, opts.obj) + use_cuda = torch.cuda.is_available() and not opts.no_cuda + if opts.multiprocessing: + assert use_cuda, "Can only do multiprocessing with cuda" + num_processes = torch.cuda.device_count() + assert opts.val_size % num_processes == 0 + + with mp.Pool(num_processes) as pool: + results = list(itertools.chain.from_iterable(pool.map( + eval_dataset_mp, + [(dataset_path, width, softmax_temp, opts, i, num_processes) for i in range(num_processes)] + ))) + + else: + device = torch.device("cuda:0" if use_cuda else "cpu") + dataset = model.problem.make_dataset(filename=dataset_path, num_samples=opts.val_size, offset=opts.offset) + results = _eval_dataset(model, dataset, width, softmax_temp, opts, device) + + # This is parallelism, even if we use multiprocessing (we report as if we did not use multiprocessing, e.g. 1 GPU) + parallelism = opts.eval_batch_size + + costs, tours, veh_lists, durations = zip(*results) # Not really costs since they should be negative + + print("Average cost: {} +- {}".format(np.mean(costs), 2 * np.std(costs) / np.sqrt(len(costs)))) + print("Average serial duration: {} +- {}".format( + np.mean(durations), 2 * np.std(durations) / np.sqrt(len(durations)))) + print("Average parallel duration: {}".format(np.mean(durations) / parallelism)) + print("Calculated total duration: {}".format(timedelta(seconds=int(np.sum(durations) / parallelism)))) + # print('tour is', costs[0], len(tours), len(tours[0]), tours[0]) + # print('veh', veh_lists[1]) + # print('tour is', costs[1], len(tours), len(tours[1]), tours[1]) + + dataset_basename, ext = os.path.splitext(os.path.split(dataset_path)[-1]) + model_name = "_".join(os.path.normpath(os.path.splitext(opts.model)[0]).split(os.sep)[-2:]) + if opts.o is None: + results_dir = os.path.join(opts.results_dir, model.problem.NAME, dataset_basename) + os.makedirs(results_dir, exist_ok=True) + + out_file = os.path.join(results_dir, "{}-{}-{}{}-t{}-{}-{}{}".format( + dataset_basename, model_name, + opts.decode_strategy, + width if opts.decode_strategy != 'greedy' else '', + softmax_temp, opts.offset, opts.offset + len(costs), ext + )) + else: + out_file = opts.o + + assert opts.f or not os.path.isfile( + out_file), "File already exists! Try running with -f option to overwrite." 
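+    # Each entry of `results` is a (cost, tour, vehicle_assignment, duration) tuple
+    # produced by _eval_dataset below; it is saved together with the parallelism
+    # (eval_batch_size) used for this run.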
+ + save_dataset((results, parallelism), out_file) + + return costs, tours, durations + + +def _eval_dataset(model, dataset, width, softmax_temp, opts, device): + # print('data', dataset[0]) + model.to(device) + model.eval() + + model.set_decode_type( + "greedy" if opts.decode_strategy in ('bs', 'greedy') else "sampling", + temp=softmax_temp) + + dataloader = DataLoader(dataset, batch_size=opts.eval_batch_size) + + results = [] + for batch in tqdm(dataloader, disable=opts.no_progress_bar): + batch = move_to(batch, device) + start = time.time() + with torch.no_grad(): + if opts.decode_strategy in ('sample', 'greedy'): + if opts.decode_strategy == 'greedy': + assert width == 0, "Do not set width when using greedy" + assert opts.eval_batch_size <= opts.max_calc_batch_size, \ + "eval_batch_size should be smaller than calc batch size" + batch_rep = 1 + iter_rep = 1 + elif width * opts.eval_batch_size > opts.max_calc_batch_size: + assert opts.eval_batch_size == 1 + assert width % opts.max_calc_batch_size == 0 + batch_rep = opts.max_calc_batch_size + iter_rep = width // opts.max_calc_batch_size + else: + batch_rep = width + iter_rep = 1 + assert batch_rep > 0 + # This returns (batch_size, iter_rep shape) + sequences, costs, veh_lists = model.sample_many(batch, batch_rep=batch_rep, iter_rep=iter_rep) + print('cost', costs) + batch_size = len(costs) + ids = torch.arange(batch_size, dtype=torch.int64, device=costs.device) + else: + assert opts.decode_strategy == 'bs' + + cum_log_p, sequences, costs, ids, batch_size = model.beam_search( + batch, beam_size=width, + compress_mask=opts.compress_mask, + max_calc_batch_size=opts.max_calc_batch_size + ) + + if sequences is None: + sequences = [None] * batch_size + costs = [math.inf] * batch_size + veh_lists = [None] * batch_size + else: + sequences, costs, veh_lists = get_best( + sequences.cpu().numpy(), costs.cpu().numpy(), veh_lists.cpu().numpy(), + ids.cpu().numpy() if ids is not None else None, + batch_size + ) + + duration = time.time() - start + for seq, cost, veh_list in zip(sequences, costs, veh_lists): + if model.problem.NAME in ("hcvrp"): + seq = seq.tolist() # No need to trim as all are same length + else: + assert False, "Unkown problem: {}".format(model.problem.NAME) + # Note VRP only + results.append((cost, seq, veh_list, duration)) + + return results + + +if __name__ == "__main__": + warnings.filterwarnings('ignore') + + parser = argparse.ArgumentParser() + parser.add_argument("datasets", nargs='+', help="Filename of the dataset(s) to evaluate") + parser.add_argument("-f", action='store_true', help="Set true to overwrite") + parser.add_argument("-o", default=None, help="Name of the results file to write") + parser.add_argument('--val_size', type=int, default=10000, + help='Number of instances used for reporting validation performance') + parser.add_argument('--offset', type=int, default=0, + help='Offset where to start in dataset (default 0)') + parser.add_argument('--eval_batch_size', type=int, default=1024, + help="Batch size to use during (baseline) evaluation") + # parser.add_argument('--decode_type', type=str, default='greedy', + # help='Decode type, greedy or sampling') + parser.add_argument('--width', type=int, nargs='+', + help='Sizes of beam to use for beam search (or number of samples for sampling), ' + '0 to disable (default), -1 for infinite') + parser.add_argument('--decode_strategy', type=str, + help='Beam search (bs), Sampling (sample) or Greedy (greedy)') + parser.add_argument('--softmax_temperature', 
type=parse_softmax_temperature, default=1, + help="Softmax temperature (sampling or bs)") + parser.add_argument('--model', type=str) + parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA') + parser.add_argument('--no_progress_bar', action='store_true', help='Disable progress bar') + parser.add_argument('--compress_mask', action='store_true', help='Compress mask into long') + parser.add_argument('--max_calc_batch_size', type=int, default=10000000, help='Size for subbatches') + parser.add_argument('--results_dir', default='results', help="Name of results directory") + parser.add_argument('--obj', default=['min-max', 'min-sum']) + parser.add_argument('--multiprocessing', action='store_true', + help='Use multiprocessing to parallelize over multiple GPUs') + + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + opts = parser.parse_args() + + assert opts.o is None or (len(opts.datasets) == 1 and len(opts.width) <= 1), \ + "Cannot specify result filename with more than one dataset or more than one width" + + widths = opts.width if opts.width is not None else [0] + + for width in widths: + for dataset_path in opts.datasets: + eval_dataset(dataset_path, width, opts.softmax_temperature, opts) diff --git a/generate_data.py b/generate_data.py new file mode 100644 index 0000000..e88e81c --- /dev/null +++ b/generate_data.py @@ -0,0 +1,53 @@ +import os +import numpy as np +from utils.data_utils import check_extension, save_dataset +import torch +import pickle +import argparse + +def generate_hcvrp_data(seed,dataset_size, hcvrp_size, veh_num): + rnd = np.random.RandomState(seed) + + loc = rnd.uniform(0, 1, size=(dataset_size, hcvrp_size + 1, 2)) + depot = loc[:, -1] + cust = loc[:, :-1] + d = rnd.randint(1, 10, [dataset_size, hcvrp_size + 1]) + d = d[:, :-1] # the demand of depot is 0, which do not need to generate here + + # vehicle feature + speed = rnd.uniform(0.5, 1, size=(dataset_size, veh_num)) + cap = rnd.randint(20, 41, size=(dataset_size, veh_num)) + + data = { + 'depot': depot.astype(np.float32), + 'loc': cust.astype(np.float32), + 'demand': d.astype(np.float32), + 'capacity': cap.astype(np.float32), + 'speed': speed.astype(np.float32) + } + return data + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # parser.add_argument("--filename", help="Filename of the dataset to create (ignores datadir)") + parser.add_argument("--dataset_size", type=int, default=1280, help="Size of the dataset") + parser.add_argument("--veh_num", type=int, default=3, help="number of the vehicles") + parser.add_argument('--graph_size', type=int, default=40, + help="Number of customers") + + opts = parser.parse_args() + data_dir = 'data' + problem = 'hcvrp' + datadir = os.path.join(data_dir, problem) + os.makedirs(datadir, exist_ok=True) + seed = 24610 # the last seed used for generating HCVRP data + # np.random.seed(seed) + print(opts.dataset_size, opts.graph_size, opts.veh_num) + filename = os.path.join(datadir, '{}_v{}_{}_seed{}.pkl'.format(problem, opts.veh_num, opts.graph_size, seed)) + + dataset = generate_hcvrp_data(seed,opts.dataset_size, opts.graph_size, opts.veh_num) + print({k:dataset[k][0] for k in dataset}) + save_dataset(dataset, filename) + + + diff --git a/main.py b/main.py new file mode 100644 index 0000000..e69de29 diff --git a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nets/attention_model.py b/nets/attention_model.py new file mode 100644 index 0000000..d5f5b6f --- /dev/null +++ b/nets/attention_model.py @@ -0,0 +1,230 @@ 
+import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +import math +from typing import NamedTuple + +from problems.hcvrp.hcvrp import HcvrpEnv +from utils.tensor_functions import compute_in_batches + +from nets.graph_encoder import GraphAttentionEncoder, MultiHeadAttention, MultiHeadAttentionLayer +from torch.nn import DataParallel +from utils.beam_search import CachedLookup +from utils.functions import sample_many +import copy +import random + + +def set_decode_type(model, decode_type): + if isinstance(model, DataParallel): + model = model.module + model.set_decode_type(decode_type) + + +class AttentionModelFixed(NamedTuple): + """ + Context for AttentionModel decoder that is fixed during decoding so can be precomputed/cached + This class allows for efficient indexing of multiple Tensors at once + """ + node_embeddings: torch.Tensor + context_node_projected: torch.Tensor + glimpse_key: torch.Tensor + glimpse_val: torch.Tensor + logit_key: torch.Tensor + + def __getitem__(self, key): + # if torch.is_tensor(key) or isinstance(key, slice): + assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor: + return AttentionModelFixed( + node_embeddings=self.node_embeddings[key], + context_node_projected=self.context_node_projected[key], + glimpse_key=self.glimpse_key[:, key], # dim 0 are the heads + glimpse_val=self.glimpse_val[:, key], # dim 0 are the heads + logit_key=self.logit_key[key] + ) + # return super(AttentionModelFixed, self).__getitem__(key) + + +# 2D-Ptr +class AttentionModel(nn.Module): + + def __init__(self, + embedding_dim, + hidden_dim, + obj, + problem, + n_encode_layers=2, + tanh_clipping=10., + mask_inner=True, + mask_logits=True, + normalization='batch', + n_heads=8, + checkpoint_encoder=False, + shrink_size=None): + super(AttentionModel, self).__init__() + + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim # deprecated + self.obj = obj + self.n_encode_layers = n_encode_layers + self.decode_type = None + self.temp = 1.0 + self.is_hcvrp = problem.NAME == 'hcvrp' + self.feed_forward_hidden = 4*embedding_dim + + self.tanh_clipping = tanh_clipping + + self.mask_inner = mask_inner + self.mask_logits = mask_logits + + self.problem = problem + self.n_heads = n_heads + self.checkpoint_encoder = checkpoint_encoder + self.shrink_size = shrink_size + self.depot_token = nn.Parameter(torch.randn(embedding_dim)) # depot token + self.init_embed = nn.Linear(3, embedding_dim) # embed linear in customer encoder + self.node_encoder = GraphAttentionEncoder( + n_heads=n_heads, + embed_dim=embedding_dim, + n_layers=self.n_encode_layers, + normalization=normalization, + feed_forward_hidden=self.feed_forward_hidden + ) + + self.veh_encoder_mlp = nn.Sequential( + nn.Linear(4, embedding_dim * 4), + nn.ReLU(), + nn.Linear(embedding_dim * 4, embedding_dim) + ) + self.veh_encoder_self_attention = MultiHeadAttention(n_heads=n_heads, input_dim=embedding_dim, embed_dim=embedding_dim) + self.veh_encoder_ca_node_linear_kv = nn.Linear(embedding_dim,2*embedding_dim) # veh-node-cross-attn w_k,w_v + self.veh_encoder_ca_veh_linear_q = nn.Linear(embedding_dim, embedding_dim) # veh-node-cross-attn w_q + self.veh_encoder_ca_linear_o = nn.Linear(embedding_dim, embedding_dim) # veh-node-cross-attn w_o + self.veh_encoder_w = nn.Linear(2*embedding_dim,embedding_dim) + assert embedding_dim % n_heads == 0 + + def pre_calculate_node(self,input): + nhead = self.n_heads + env = 
HcvrpEnv(input, scale=(1, 40, 1)) + # embed node (depot and customer) + node_embedding = self.init_embed(env.get_all_node_state()) + # add depot token + node_embedding[:, 0] = node_embedding[:, 0] + self.depot_token + node_embedding = self.node_encoder(node_embedding)[0] + bs,N,d = node_embedding.size() + # pre-calculate the K,V of the cross-attention in vehcle encoder, avoid double calculation + kv = self.veh_encoder_ca_node_linear_kv(node_embedding).reshape(bs,N,nhead,-1).transpose(1,2) # bs,nhead,N,d_k*2 + k,v = torch.chunk(kv,2,-1) # bs,nhead,n,d_k,bs,nhead,n,d_k, + return input, node_embedding, (k, v) + def veh_encoder_cross_attention(self,veh_em,node_kv,mask=None): + ''' + :param veh_em: + :param node_kv: + :param action_mask: bs,M,N + :return: + ''' + bs,M,d = veh_em.size() + nhead = self.n_heads + k,v = node_kv + q = self.veh_encoder_ca_veh_linear_q(veh_em).reshape(bs,M,nhead,-1).transpose(1,2) # bs,nhead,M,d_k + attn = q @ k.transpose(-1,-2)/np.sqrt(q.size(-1)) # bs,nhead,M,N + if mask is not None: + attn[mask.unsqueeze(1).expand(attn.size())]=-math.inf + attn = attn.softmax(-1) #bs,nhead,M,N + out = attn @ v # bs,nhead,M,d_k + out = self.veh_encoder_ca_linear_o(out.transpose(1,2).reshape(bs,M,-1)) # bs,M,d + return out + + def set_decode_type(self, decode_type, temp=None): + self.decode_type = decode_type + if temp is not None: # Do not change temperature if not provided + self.temp = temp + + def forward(self, input, return_pi=False): + """ + :param input: (batch_size, graph_size, node_dim) input node features or dictionary with multiple tensors + :param return_pi: whether to return the output sequences, this is optional as it is not compatible with + using DataParallel as the results may be of different lengths on different GPUs + :return: + """ + input, node_embedding,node_kv = self.pre_calculate_node(input) + # input, node_embedding, veh_embedding = self.initial_em(input) + ll, pi, veh_list, cost = self._inner(input, node_embedding, node_kv) + if return_pi: + return cost, ll, pi + return cost, ll + + def _inner(self,input ,node_embeddings,node_kv): + env = HcvrpEnv(input, scale=(1, 40, 1)) + ll,pi,veh_list=[],[],[] + while not env.all_finished(): + # update vehicle embeddings + veh_embeddings = self.veh_encoder(node_embeddings,node_kv,env) + # select action + veh, node, log_p = self.decoder(veh_embeddings, node_embeddings, mask=env.get_action_mask()) + # update env + env.update(veh,node) + + veh_list.append(veh) + pi.append(node) + ll.append(log_p) + + # get the final cost + cost = env.get_cost(self.obj) + ll = torch.stack(ll, 1) # bs,step + pi = torch.stack(pi, 1) # bs,step + veh_list = torch.stack(veh_list, 1) # bs,step + return ll.sum(1), pi, veh_list,cost + + def decoder(self,q_em,k_em,mask=None): + ''' + :param q_em: Q: bs,m,d + :param k_em: K: bs,n,d + :param mask: bs,m,n + :return: selected index,log_pro + ''' + bs,m,d = q_em.size() + _,n,_ = k_em.size() + bs_index = torch.arange(bs,device=q_em.device) + logits = (q_em @ k_em.transpose(1, 2) / np.sqrt(d)) # bs,m,n + if self.tanh_clipping > 0: # 10 + logits = torch.tanh(logits) * self.tanh_clipping + if self.mask_logits: # True + if mask is not None: + logits[mask] = -math.inf + logits = logits.reshape(bs,-1) # bs,M*N + p = logits.softmax(1) # bs,M*N + if self.decode_type=='greedy': + selected = p.max(1)[1] # bs + else: + selected = p.multinomial(1).squeeze(1) + log_p = p[bs_index,selected].log() + veh,node = selected//n,selected%n + return veh,node,log_p + def veh_encoder(self,node_embeddings,node_kv,env): + 
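+        # Vehicle encoder, re-run at every decoding step:
+        #  1) embed the dynamic vehicle state (time, capacity, used capacity, speed) with an MLP,
+        #  2) concatenate the embedding of the node each vehicle currently occupies and project back to d,
+        #  3) refine with self-attention over vehicles and cross-attention to the pre-computed
+        #     node keys/values, masking visited customers (the depot is never masked here).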
veh_embeddings = self.veh_encoder_mlp(env.get_all_veh_state()) + bs, N, d = node_embeddings.size() + bs, M, d = veh_embeddings.size() + bs_index = torch.arange(bs, device=node_embeddings.device) + veh_node_em = node_embeddings[bs_index.unsqueeze(-1),env.veh_cur_node.clone()] # PE:bs,M,d + veh_embeddings = self.veh_encoder_w(torch.cat([veh_node_em,veh_embeddings],dim=-1)) + mask = env.visited.clone() + # depot will not be masked + mask[:,0] = False + mask = mask.unsqueeze(1).expand(bs,M,N) + veh_embeddings = veh_embeddings + self.veh_encoder_self_attention(veh_embeddings) + veh_embeddings = veh_embeddings + self.veh_encoder_cross_attention(veh_embeddings,node_kv,mask) + return veh_embeddings + + + def sample_many(self, input, batch_rep=1, iter_rep=1): + + return sample_many( + lambda input: self._inner(*input), # Need to unpack tuple into arguments + None, + self.pre_calculate_node(input), # Pack input with embeddings (additional input) + batch_rep, iter_rep + ) + diff --git a/nets/critic_network.py b/nets/critic_network.py new file mode 100644 index 0000000..73d8926 --- /dev/null +++ b/nets/critic_network.py @@ -0,0 +1,40 @@ +from torch import nn +from nets.graph_encoder import GraphAttentionEncoder + + +class CriticNetwork(nn.Module): + + def __init__( + self, + input_dim, + embedding_dim, + hidden_dim, + n_layers, + encoder_normalization + ): + super(CriticNetwork, self).__init__() + + self.hidden_dim = hidden_dim + + self.encoder = GraphAttentionEncoder( + node_dim=input_dim, + n_heads=8, + embed_dim=embedding_dim, + n_layers=n_layers, + normalization=encoder_normalization + ) + + self.value_head = nn.Sequential( + nn.Linear(embedding_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, inputs): + """ + + :param inputs: (batch_size, graph_size, input_dim) + :return: + """ + _, graph_embeddings = self.encoder(inputs) + return self.value_head(graph_embeddings) diff --git a/nets/graph_encoder.py b/nets/graph_encoder.py new file mode 100644 index 0000000..ff56cbd --- /dev/null +++ b/nets/graph_encoder.py @@ -0,0 +1,210 @@ +import torch +import torch.nn.functional as F +import numpy as np +from torch import nn +import math + + +class SkipConnection(nn.Module): + + def __init__(self, module): + super(SkipConnection, self).__init__() + self.module = module + + def forward(self, input): + return input + self.module(input) # skip connection + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + n_heads, + input_dim, + embed_dim=None, + val_dim=None, + key_dim=None + ): + super(MultiHeadAttention, self).__init__() + + if val_dim is None: + assert embed_dim is not None, "Provide either embed_dim or val_dim" + val_dim = embed_dim // n_heads + if key_dim is None: + key_dim = val_dim + + self.n_heads = n_heads + self.input_dim = input_dim + self.embed_dim = embed_dim + self.val_dim = val_dim + self.key_dim = key_dim + + self.norm_factor = 1 / math.sqrt(key_dim) # See Attention is all you need + + self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) + self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) + self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) + + if embed_dim is not None: + self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim)) + + self.init_parameters() + + def init_parameters(self): + + for param in self.parameters(): + stdv = 1. 
/ math.sqrt(param.size(-1)) + param.data.uniform_(-stdv, stdv) + + def forward(self, q, h=None, mask=None): + """ + + :param q: queries (batch_size, n_query, input_dim) + :param h: data (batch_size, graph_size, input_dim) + :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1) + Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency) + :return: + """ + if h is None: + h = q # compute self-attention + + # h should be (batch_size, graph_size, input_dim) + batch_size, graph_size, input_dim = h.size() # input_dim=embed_dim + n_query = q.size(1) # =graph_size+1 + assert q.size(0) == batch_size + assert q.size(2) == input_dim + assert input_dim == self.input_dim, "Wrong embedding dimension of input" + + hflat = h.contiguous().view(-1, input_dim) # [batch_size * graph_size, embed_dim] + qflat = q.contiguous().view(-1, input_dim) # [batch_size * n_query, embed_dim] + + # last dimension can be different for keys and values + shp = (self.n_heads, batch_size, graph_size, -1) + shp_q = (self.n_heads, batch_size, n_query, -1) + + # Calculate queries, (n_heads, batch_size, n_query, key/val_size) + Q = torch.matmul(qflat, self.W_query).view(shp_q) + # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size) + K = torch.matmul(hflat, self.W_key).view(shp) + V = torch.matmul(hflat, self.W_val).view(shp) + + # Calculate compatibility (n_heads, batch_size, n_query, graph_size) + compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3)) + + # Optionally apply mask to prevent attention + if mask is not None: + mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility) + compatibility[mask] = -np.inf + attn = F.softmax(compatibility, dim=-1) + + # If there are nodes with no neighbours then softmax returns nan so we fix them to 0 + if mask is not None: + attnc = attn.clone() + attnc[mask] = 0 + attn = attnc + + heads = torch.matmul(attn, V) # [n_heads, batrch_size, n_query, val_size] + + out = torch.mm( + heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim), # [batch_size, n_query, n_heads*val_size] + self.W_out.view(-1, self.embed_dim) # [n_head*key_dim, embed_dim] + ).view(batch_size, n_query, self.embed_dim) + + return out + + +class Normalization(nn.Module): + + def __init__(self, embed_dim, normalization='batch'): + super(Normalization, self).__init__() + + normalizer_class = { + 'batch': nn.BatchNorm1d, + 'instance': nn.InstanceNorm1d + }.get(normalization, None) + + self.normalizer = normalizer_class(embed_dim, affine=True) + + # Normalization by default initializes affine parameters with bias 0 and weight unif(0,1) which is too large! + # self.init_parameters() + + def init_parameters(self): + + for name, param in self.named_parameters(): + stdv = 1. 
/ math.sqrt(param.size(-1)) + param.data.uniform_(-stdv, stdv) + print('stdv', stdv) + + def forward(self, input): + + if isinstance(self.normalizer, nn.BatchNorm1d): + return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()) # [batch_size, graph_size+1, embed_dim] + elif isinstance(self.normalizer, nn.InstanceNorm1d): + return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1) + else: + assert self.normalizer is None, "Unknown normalizer type" + return input + + +# Encoder part, hi_hat and hi^l +class MultiHeadAttentionLayer(nn.Sequential): +# multihead attention -> skip connection, normalization -> feed forward -> skip connection, normalization + def __init__( + self, + n_heads, + embed_dim, + feed_forward_hidden=512, + normalization='batch', + ): + super(MultiHeadAttentionLayer, self).__init__( + SkipConnection( + MultiHeadAttention( + n_heads, + input_dim=embed_dim, + embed_dim=embed_dim + ) + ), + Normalization(embed_dim, normalization), + SkipConnection( + nn.Sequential( + nn.Linear(embed_dim, feed_forward_hidden), + nn.ReLU(), + nn.Linear(feed_forward_hidden, embed_dim) + ) if feed_forward_hidden > 0 else nn.Linear(embed_dim, embed_dim) + ), + Normalization(embed_dim, normalization) + ) + + +class GraphAttentionEncoder(nn.Module): + def __init__( + self, + n_heads, + embed_dim, + n_layers, + node_dim=None, + normalization='batch', + feed_forward_hidden=512 + ): + super(GraphAttentionEncoder, self).__init__() + + # To map input to embedding space + self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None + + self.layers = nn.Sequential(*( + MultiHeadAttentionLayer(n_heads, embed_dim, feed_forward_hidden, normalization) + for _ in range(n_layers) + )) + + def forward(self, x, mask=None): + + assert mask is None, "TODO mask not yet supported!" + + # Batch multiply to get initial embeddings of nodes h/x: [batch_size, graph_size+1, embed_dim] + h = self.init_embed(x.view(-1, x.size(-1))).view(*x.size()[:2], -1) if self.init_embed is not None else x + + h = self.layers(h) + + return ( + h, # (batch_size, graph_size, embed_dim) + h.mean(dim=1), # average to get embedding of graph, (batch_size, embed_dim) + ) diff --git a/nets/pointer_network.py b/nets/pointer_network.py new file mode 100644 index 0000000..8392f57 --- /dev/null +++ b/nets/pointer_network.py @@ -0,0 +1,357 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F +import math +import numpy as np + + +class Encoder(nn.Module): + """Maps a graph represented as an input sequence + to a hidden vector""" + + def __init__(self, input_dim, hidden_dim): + super(Encoder, self).__init__() + self.hidden_dim = hidden_dim + self.lstm = nn.LSTM(input_dim, hidden_dim) + self.init_hx, self.init_cx = self.init_hidden(hidden_dim) + + def forward(self, x, hidden): + output, hidden = self.lstm(x, hidden) + return output, hidden + + def init_hidden(self, hidden_dim): + """Trainable initial hidden state""" + std = 1. 
/ math.sqrt(hidden_dim) + enc_init_hx = nn.Parameter(torch.FloatTensor(hidden_dim)) + enc_init_hx.data.uniform_(-std, std) + + enc_init_cx = nn.Parameter(torch.FloatTensor(hidden_dim)) + enc_init_cx.data.uniform_(-std, std) + return enc_init_hx, enc_init_cx + + +class Attention(nn.Module): + """A generic attention module for a decoder in seq2seq""" + + def __init__(self, dim, use_tanh=False, C=10): + super(Attention, self).__init__() + self.use_tanh = use_tanh + self.project_query = nn.Linear(dim, dim) + self.project_ref = nn.Conv1d(dim, dim, 1, 1) + self.C = C # tanh exploration + self.tanh = nn.Tanh() + + self.v = nn.Parameter(torch.FloatTensor(dim)) + self.v.data.uniform_(-(1. / math.sqrt(dim)), 1. / math.sqrt(dim)) + + def forward(self, query, ref): + """ + Args: + query: is the hidden state of the decoder at the current + time step. batch x dim + ref: the set of hidden states from the encoder. + sourceL x batch x hidden_dim + """ + # ref is now [batch_size x hidden_dim x sourceL] + ref = ref.permute(1, 2, 0) + q = self.project_query(query).unsqueeze(2) # batch x dim x 1 + e = self.project_ref(ref) # batch_size x hidden_dim x sourceL + # expand the query by sourceL + # batch x dim x sourceL + expanded_q = q.repeat(1, 1, e.size(2)) + # batch x 1 x hidden_dim + v_view = self.v.unsqueeze(0).expand( + expanded_q.size(0), len(self.v)).unsqueeze(1) + # [batch_size x 1 x hidden_dim] * [batch_size x hidden_dim x sourceL] + u = torch.bmm(v_view, self.tanh(expanded_q + e)).squeeze(1) + if self.use_tanh: + logits = self.C * self.tanh(u) + else: + logits = u + return e, logits + + +class Decoder(nn.Module): + def __init__(self, + embedding_dim, + hidden_dim, + tanh_exploration, + use_tanh, + n_glimpses=1, + mask_glimpses=True, + mask_logits=True): + super(Decoder, self).__init__() + + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.n_glimpses = n_glimpses + self.mask_glimpses = mask_glimpses + self.mask_logits = mask_logits + self.use_tanh = use_tanh + self.tanh_exploration = tanh_exploration + self.decode_type = None # Needs to be set explicitly before use + + self.lstm = nn.LSTMCell(embedding_dim, hidden_dim) + self.pointer = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration) + self.glimpse = Attention(hidden_dim, use_tanh=False) + self.sm = nn.Softmax(dim=1) + + def update_mask(self, mask, selected): + return mask.clone().scatter_(1, selected.unsqueeze(-1), True) + + def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context): + + logit_mask = self.update_mask(prev_mask, prev_idxs) if prev_idxs is not None else prev_mask + + logits, h_out = self.calc_logits(x, h_in, logit_mask, context, self.mask_glimpses, self.mask_logits) + + # Calculate log_softmax for better numerical stability + log_p = F.log_softmax(logits, dim=1) + probs = log_p.exp() + + if not self.mask_logits: + # If self.mask_logits, this would be redundant, otherwise we must mask to make sure we don't resample + # Note that as a result the vector of probs may not sum to one (this is OK for .multinomial sampling) + # But practically by not masking the logits, a model is learned over all sequences (also infeasible) + # while only during sampling feasibility is enforced (a.k.a. by setting to 0. here) + probs[logit_mask] = 0. 
+ # For consistency we should also mask out in log_p, but the values set to 0 will not be sampled and + # Therefore not be used by the reinforce estimator + + return h_out, log_p, probs, logit_mask + + def calc_logits(self, x, h_in, logit_mask, context, mask_glimpses=None, mask_logits=None): + + if mask_glimpses is None: + mask_glimpses = self.mask_glimpses + + if mask_logits is None: + mask_logits = self.mask_logits + + hy, cy = self.lstm(x, h_in) + g_l, h_out = hy, (hy, cy) + + for i in range(self.n_glimpses): + ref, logits = self.glimpse(g_l, context) + # For the glimpses, only mask before softmax so we have always an L1 norm 1 readout vector + if mask_glimpses: + logits[logit_mask] = -np.inf + # [batch_size x h_dim x sourceL] * [batch_size x sourceL x 1] = + # [batch_size x h_dim x 1] + g_l = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2) + _, logits = self.pointer(g_l, context) + + # Masking before softmax makes probs sum to one + if mask_logits: + logits[logit_mask] = -np.inf + + return logits, h_out + + def forward(self, decoder_input, embedded_inputs, hidden, context, eval_tours=None): + """ + Args: + decoder_input: The initial input to the decoder + size is [batch_size x embedding_dim]. Trainable parameter. + embedded_inputs: [sourceL x batch_size x embedding_dim] + hidden: the prev hidden state, size is [batch_size x hidden_dim]. + Initially this is set to (enc_h[-1], enc_c[-1]) + context: encoder outputs, [sourceL x batch_size x hidden_dim] + """ + + batch_size = context.size(1) + outputs = [] + selections = [] + steps = range(embedded_inputs.size(0)) + idxs = None + mask = Variable( + embedded_inputs.data.new().byte().new(embedded_inputs.size(1), embedded_inputs.size(0)).zero_(), + requires_grad=False + ) + + for i in steps: + hidden, log_p, probs, mask = self.recurrence(decoder_input, hidden, mask, idxs, i, context) + # select the next inputs for the decoder [batch_size x hidden_dim] + idxs = self.decode( + probs, + mask + ) if eval_tours is None else eval_tours[:, i] + + idxs = idxs.detach() # Otherwise pytorch complains it want's a reward, todo implement this more properly? + + # Gather input embedding of selected + decoder_input = torch.gather( + embedded_inputs, + 0, + idxs.contiguous().view(1, batch_size, 1).expand(1, batch_size, *embedded_inputs.size()[2:]) + ).squeeze(0) + + # use outs to point to next object + outputs.append(log_p) + selections.append(idxs) + return (torch.stack(outputs, 1), torch.stack(selections, 1)), hidden + + def decode(self, probs, mask): + if self.decode_type == "greedy": + _, idxs = probs.max(1) + assert not mask.gather(1, idxs.unsqueeze(-1)).data.any(), \ + "Decode greedy: infeasible action has maximum probability" + elif self.decode_type == "sampling": + idxs = probs.multinomial(1).squeeze(1) + # Check if sampling went OK, can go wrong due to bug on GPU + while mask.gather(1, idxs.unsqueeze(-1)).data.any(): + print(' [!] 
resampling due to race condition') + idxs = probs.multinomial().squeeze(1) + else: + assert False, "Unknown decode type" + + return idxs + + +class CriticNetworkLSTM(nn.Module): + """Useful as a baseline in REINFORCE updates""" + + def __init__(self, + embedding_dim, + hidden_dim, + n_process_block_iters, + tanh_exploration, + use_tanh): + super(CriticNetworkLSTM, self).__init__() + + self.hidden_dim = hidden_dim + self.n_process_block_iters = n_process_block_iters + + self.encoder = Encoder(embedding_dim, hidden_dim) + + self.process_block = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration) + self.sm = nn.Softmax(dim=1) + self.decoder = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + def forward(self, inputs): + """ + Args: + inputs: [embedding_dim x batch_size x sourceL] of embedded inputs + """ + inputs = inputs.transpose(0, 1).contiguous() + + encoder_hx = self.encoder.init_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) + encoder_cx = self.encoder.init_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) + + # encoder forward pass + enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx)) + + # grab the hidden state and process it via the process block + process_block_state = enc_h_t[-1] + for i in range(self.n_process_block_iters): + ref, logits = self.process_block(process_block_state, enc_outputs) + process_block_state = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2) + # produce the final scalar output + out = self.decoder(process_block_state) + return out + + +class PointerNetwork(nn.Module): + + def __init__(self, + embedding_dim, + hidden_dim, + problem, + n_encode_layers=None, + tanh_clipping=10., + mask_inner=True, + mask_logits=True, + normalization=None, + **kwargs): + super(PointerNetwork, self).__init__() + + self.problem = problem + assert problem.NAME == "tsp", "Pointer Network only supported for TSP" + self.input_dim = 2 + + self.encoder = Encoder( + embedding_dim, + hidden_dim) + + self.decoder = Decoder( + embedding_dim, + hidden_dim, + tanh_exploration=tanh_clipping, + use_tanh=tanh_clipping > 0, + n_glimpses=1, + mask_glimpses=mask_inner, + mask_logits=mask_logits + ) + + # Trainable initial hidden states + std = 1. 
/ math.sqrt(embedding_dim) + self.decoder_in_0 = nn.Parameter(torch.FloatTensor(embedding_dim)) + self.decoder_in_0.data.uniform_(-std, std) + + self.embedding = nn.Parameter(torch.FloatTensor(self.input_dim, embedding_dim)) + self.embedding.data.uniform_(-std, std) + + def set_decode_type(self, decode_type): + self.decoder.decode_type = decode_type + + def forward(self, inputs, eval_tours=None, return_pi=False): + batch_size, graph_size, input_dim = inputs.size() + + embedded_inputs = torch.mm( + inputs.transpose(0, 1).contiguous().view(-1, input_dim), + self.embedding + ).view(graph_size, batch_size, -1) + + # query the actor net for the input indices + # making up the output, and the pointer attn + _log_p, pi = self._inner(embedded_inputs, eval_tours) + + cost, mask = self.problem.get_costs(inputs, pi) + # Log likelyhood is calculated within the model since returning it per action does not work well with + # DataParallel since sequences can be of different lengths + ll = self._calc_log_likelihood(_log_p, pi, mask) + if return_pi: + return cost, ll, pi + + return cost, ll + + def _calc_log_likelihood(self, _log_p, a, mask): + + # Get log_p corresponding to selected actions + log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) + + # Optional: mask out actions irrelevant to objective so they do not get reinforced + if mask is not None: + log_p[mask] = 0 + + assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!" + + # Calculate log_likelihood + return log_p.sum(1) + + def _inner(self, inputs, eval_tours=None): + + encoder_hx = encoder_cx = Variable( + torch.zeros(1, inputs.size(1), self.encoder.hidden_dim, out=inputs.data.new()), + requires_grad=False + ) + + # encoder forward pass + enc_h, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx)) + + dec_init_state = (enc_h_t[-1], enc_c_t[-1]) + + # repeat decoder_in_0 across batch + decoder_input = self.decoder_in_0.unsqueeze(0).repeat(inputs.size(1), 1) + + (pointer_probs, input_idxs), dec_hidden_t = self.decoder(decoder_input, + inputs, + dec_init_state, + enc_h, + eval_tours) + + return pointer_probs, input_idxs \ No newline at end of file diff --git a/options.py b/options.py new file mode 100644 index 0000000..cf9180a --- /dev/null +++ b/options.py @@ -0,0 +1,95 @@ +import os +import time +import argparse +import torch + + +def get_options(args=None): + parser = argparse.ArgumentParser( + description="Attention based model for solving the Travelling Salesman Problem with Reinforcement Learning") + + # Data + parser.add_argument('--problem', default='hcvrp', help="The problem to solve, default 'tsp'") + parser.add_argument('--graph_size', type=int, default=20, help="The size of the problem graph") + parser.add_argument('--veh_num', type=int, default=3, help="The number of the problem vehicles") + parser.add_argument('--batch_size', type=int, default=512, help='Number of instances per batch during training') + parser.add_argument('--epoch_size', type=int, default=1280000, help='Number of instances per epoch during training') + parser.add_argument('--val_size', type=int, default=10000, + help='Number of instances used for reporting validation performance') + parser.add_argument('--val_dataset', type=str, default=None, help='Dataset file to use for validation') + parser.add_argument('--obj', default=['min-max', 'min-sum']) + + # Model + parser.add_argument('--model', default='attention', help="Model, 'attention' (default) or 'pointer'") + parser.add_argument('--embedding_dim', type=int, 
default=128, help='Dimension of input embedding') + parser.add_argument('--hidden_dim', type=int, default=128, help='Dimension of hidden layers in Enc/Dec') + parser.add_argument('--n_encode_layers', type=int, default=3, + help='Number of layers in the encoder/critic network') + parser.add_argument('--n_heads', type=int, default=8, + help='Number of heads in multi-Attention') + + parser.add_argument('--tanh_clipping', type=float, default=10., + help='Clip the parameters to within +- this value using tanh. ' + 'Set to 0 to not perform any clipping.') + parser.add_argument('--normalization', default='batch', help="Normalization type, 'batch' (default) or 'instance'") + + # Training + parser.add_argument('--lr_model', type=float, default=1e-4, help="Set the learning rate for the actor network") + parser.add_argument('--lr_critic', type=float, default=1e-4, help="Set the learning rate for the critic network") + parser.add_argument('--lr_decay', type=float, default=0.995, help='Learning rate decay per epoch') + parser.add_argument('--eval_only', action='store_true', help='Set this value to only evaluate model') + parser.add_argument('--n_epochs', type=int, default=50, help='The number of epochs to train') + parser.add_argument('--seed', type=int, default=1234, help='Random seed to use') + parser.add_argument('--max_grad_norm', type=float, default=3.0, + help='Maximum L2 norm for gradient clipping, default 1.0 (0 to disable clipping)') + parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA') + parser.add_argument('--exp_beta', type=float, default=0.8, + help='Exponential moving average baseline decay (default 0.8)') + parser.add_argument('--baseline', default=None, + help="Baseline to use: 'rollout', 'critic' or 'exponential'. Defaults to no baseline.") + parser.add_argument('--bl_alpha', type=float, default=0.05, + help='Significance in the t-test for updating rollout baseline') + parser.add_argument('--bl_warmup_epochs', type=int, default=None, + help='Number of epochs to warmup the baseline, default None means 1 for rollout (exponential ' + 'used for warmup phase), 0 otherwise. 
Can only be used with rollout baseline.') + parser.add_argument('--eval_batch_size', type=int, default=1024, + help="Batch size to use during (baseline) evaluation") + parser.add_argument('--checkpoint_encoder', action='store_true', + help='Set to decrease memory usage by checkpointing encoder') + parser.add_argument('--shrink_size', type=int, default=None, + help='Shrink the batch size if at least this many instances in the batch are finished' + ' to save memory (default None means no shrinking)') + parser.add_argument('--data_distribution', type=str, default=None, + help='Data distribution to use during training, defaults and options depend on problem.') + + # Misc + parser.add_argument('--log_step', type=int, default=50, help='Log info every log_step steps') + parser.add_argument('--log_dir', default='logs', help='Directory to write TensorBoard information to') + parser.add_argument('--run_name', default='run', help='Name to identify the run') + parser.add_argument('--no_run_name_wrapper', action='store_true', + help='Do not add timestamp wrapper to run name') + parser.add_argument('--output_dir', default='outputs', help='Directory to write output models to') + parser.add_argument('--epoch_start', type=int, default=0, + help='Start at epoch # (relevant for learning rate decay)') + parser.add_argument('--checkpoint_epochs', type=int, default=1, + help='Save checkpoint every n epochs (default 1), 0 to save no checkpoints') + parser.add_argument('--load_path', help='Path to load model parameters and optimizer state from') + parser.add_argument('--resume', help='Resume from previous checkpoint file') + parser.add_argument('--no_tensorboard', action='store_true', help='Disable logging TensorBoard files') + parser.add_argument('--no_progress_bar', action='store_true', help='Disable progress bar') + + opts = parser.parse_args(args) + + opts.use_cuda = torch.cuda.is_available() and not opts.no_cuda + if not opts.no_run_name_wrapper: + opts.run_name = "{}_{}".format(opts.run_name, time.strftime("%Y%m%dT%H%M%S")) + opts.save_dir = os.path.join( + opts.output_dir, + "{}_v{}_{}".format(opts.problem, opts.veh_num, opts.graph_size), + opts.run_name + ) + if opts.bl_warmup_epochs is None: + opts.bl_warmup_epochs = 1 if opts.baseline == 'rollout' else 0 + assert (opts.bl_warmup_epochs == 0) or (opts.baseline == 'rollout') + assert opts.epoch_size % opts.batch_size == 0, "Epoch size must be integer multiple of batch size!" 
+ return opts \ No newline at end of file diff --git a/outputs/hcvrp_v5_60/args.json b/outputs/hcvrp_v5_60/args.json new file mode 100644 index 0000000..a54cfd1 --- /dev/null +++ b/outputs/hcvrp_v5_60/args.json @@ -0,0 +1,46 @@ +{ + "problem": "hcvrp", + "graph_size": 60, + "veh_num": 5, + "batch_size": 512, + "epoch_size": 1280000, + "val_size": 10000, + "val_dataset": null, + "obj": "min-max", + "model": "attention", + "embedding_dim": 128, + "hidden_dim": 128, + "n_encode_layers": 3, + "n_heads": 8, + "tanh_clipping": 10.0, + "normalization": "batch", + "lr_model": 0.0001, + "lr_critic": 0.0001, + "lr_decay": 0.995, + "eval_only": false, + "n_epochs": 50, + "seed": 1234, + "max_grad_norm": 3.0, + "no_cuda": false, + "exp_beta": 0.8, + "baseline": "rollout", + "bl_alpha": 0.05, + "bl_warmup_epochs": 1, + "eval_batch_size": 1024, + "checkpoint_encoder": false, + "shrink_size": null, + "data_distribution": null, + "log_step": 50, + "log_dir": "logs", + "run_name": "hcvrpv5_60_rollout_20230728T234313", + "no_run_name_wrapper": false, + "output_dir": "outputs", + "epoch_start": 0, + "checkpoint_epochs": 50, + "load_path": null, + "resume": null, + "no_tensorboard": false, + "no_progress_bar": false, + "use_cuda": true, + "save_dir": "outputs/hcvrp_v5_60/hcvrpv5_60_rollout_20230728T234313" +} \ No newline at end of file diff --git a/outputs/hcvrp_v5_60/epoch-49.pt b/outputs/hcvrp_v5_60/epoch-49.pt new file mode 100644 index 0000000..8a6ec3f Binary files /dev/null and b/outputs/hcvrp_v5_60/epoch-49.pt differ diff --git a/problems/__init__.py b/problems/__init__.py new file mode 100644 index 0000000..c19d0fb --- /dev/null +++ b/problems/__init__.py @@ -0,0 +1 @@ +from problems.hcvrp.problem_hcvrp import HCVRP \ No newline at end of file diff --git a/problems/hcvrp/.gitignore b/problems/hcvrp/.gitignore new file mode 100644 index 0000000..2fbcab9 --- /dev/null +++ b/problems/hcvrp/.gitignore @@ -0,0 +1 @@ +lkh/ \ No newline at end of file diff --git a/problems/hcvrp/__init__.py b/problems/hcvrp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/problems/hcvrp/hcvrp.py b/problems/hcvrp/hcvrp.py new file mode 100644 index 0000000..c311623 --- /dev/null +++ b/problems/hcvrp/hcvrp.py @@ -0,0 +1,193 @@ +import torch + +class HcvrpEnv: + def __init__(self,input,scale=(1,40,1)): + ''' + :param input: + input:{ + 'loc': batch_size, graph_size, 2 + 'demand': batch_size, graph_size + 'depot': batch_size, 2 + 'capacity': batch_size, vehicle_num + 'speed': batch_size, vehicle_num + } + :param scale: used to output normalized state (coords,demand,speed) + ''' + self.device = input['loc'].device + self.batch_size = input['loc'].shape[0] + self.bs_index = torch.arange(self.batch_size,device = self.device) + self.step = 0 + self.scale_coords,self.scale_demand,self.scale_speed = scale + self.initial_node_state(input['loc'],input['demand'],input['depot']) + self.initial_veh_state(input['capacity'], input['speed']) + def initial_node_state(self,loc,demand,depot): + ''' + :param loc: customer coordinates [batch_size, graph_size,2] + :param demand: customer demands [batch_size, graph_size] + :param depot: depot coordinates [batch_size, 2] + :return: + ''' + assert loc.shape[:2] == demand.shape, "The custumer's loc and demand shape do not match" + self.customer_num = loc.shape[1] + self.N = loc.shape[1]+1 # Let N represent the graph size + self.coords = torch.cat([depot.unsqueeze(1), + loc],dim=1) # batch_size, N, 2 + self.demand = torch.cat([torch.zeros_like(demand[:,[0]]), + demand],dim=1) # 
batch_size, N + self.visited = torch.zeros_like(self.demand).bool() # batch_size, N + self.visited[:,0] = True # start from depot, so depot is visited + def all_finished(self): + ''' + :return: Are all tasks finished? + ''' + return self.visited.all() + + def finished(self): + ''' + :return: [bs],true or false, is each task finished? + ''' + return self.visited.all(-1) + + def get_all_node_state(self): + ''' + :return: [bs,N+1,3], get node initial features + ''' + return torch.cat([self.coords/self.scale_coords, + self.demand.unsqueeze(-1)/self.scale_demand],dim = -1) # batch_size, N, 3 + + def initial_veh_state(self,capacity,speed): + ''' + :param capacity: batch_size, veh_num + :param speed: batch_size, veh_num + :return + ''' + assert capacity.size() == speed.size(), "The vehicle's speed and capacity shape do not match" + self.veh_capacity = capacity + self.veh_speed = speed + self.veh_num = capacity.shape[1] + self.veh_time = torch.zeros_like(capacity) # batch_size, veh_num + self.veh_cur_node = torch.zeros_like(capacity).long() # batch_size, veh_num + self.veh_used_capacity = torch.zeros_like(capacity) + # a util vector + self.veh_index = torch.arange(self.veh_num, device=self.device) + + def min_max_norm(self,data): + ''' + deprecated + :param data: + :return: + ''' + # bs,M + min_data = data.min(-1,keepdim=True)[0] + max_data = data.max(-1, keepdim=True)[0] + return (data-min_data)/(max_data-min_data) + def get_all_veh_state(self): + ''' + :return: [bs,M,4] + # time,capacity,usage capacity,speed + ''' + + veh_cur_coords = self.coords[self.bs_index.unsqueeze(-1), + self.veh_cur_node] # batch_size, veh_num, 2 + + return torch.cat([ + self.veh_time.unsqueeze(-1), + self.veh_capacity.unsqueeze(-1)/self.scale_demand, + self.veh_used_capacity.unsqueeze(-1)/self.scale_demand, + self.veh_speed.unsqueeze(-1)/self.scale_speed, + # veh_cur_coords/self.scale_coords + ],dim=-1) + + def get_veh_state(self,veh): + # deprecated + ''' + :param veh: veh_index,batch_size + :return: + ''' + all_veh_state = self.get_all_veh_state() # bs,veh_num,4 + return all_veh_state[self.bs_index,veh] # bs,4 + + + def action_is_legal(self,veh,next_node): + # deprecated + return self.demand[self.bs_index, next_node] <= (self.veh_capacity - self.veh_used_capacity)[self.bs_index, veh] + + def update(self, veh, next_node): + ''' + input action tuple and update the env + :param veh: [batch_size,] + :param next_node: [batch_size,] + :return: + ''' + # select node must be unvisited,except depot + assert not self.visited[self.bs_index,next_node][next_node!=0].any(),"Wrong solution: node has been selected !" + # Note that demand<=remaining_capacity==capacity-usage_capacity + assert (self.demand[self.bs_index,next_node] <= + (self.veh_capacity-self.veh_used_capacity)[self.bs_index,veh]).all(),"Wrong solution: the remaining capacity of the vehicle cannot satisfy the node !" 
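+        # Example of one step: if the chosen vehicle has speed 0.5 and the Euclidean
+        # distance from its current node to next_node is 0.2, its travel time below
+        # increases by 0.2 / 0.5 = 0.4; returning to the depot (next_node == 0) resets
+        # its used capacity to 0.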
+ + # update vehicle time, + last_node = self.veh_cur_node[self.bs_index,veh] + old_coords,new_coords = self.coords[self.bs_index,last_node],self.coords[self.bs_index,next_node] + length = torch.norm(new_coords-old_coords,p=2,dim=1) + time_add = length / self.veh_speed[self.bs_index,veh] + self.veh_time[self.bs_index,veh] += time_add + + # update the used_capacity + new_veh_used_capacity = self.veh_used_capacity[self.bs_index, veh] + self.demand[self.bs_index,next_node] + new_veh_used_capacity[next_node==0] = 0 # the vehicle is fully reloaded after returning to the depot + self.veh_used_capacity[self.bs_index, veh] = new_veh_used_capacity + + # update the node index where the vehicle stands + self.veh_cur_node[self.bs_index,veh] = next_node + self.step += 1 + # print(self.step) + + # update visited vector + self.visited[self.bs_index,next_node]=True + + def all_go_depot(self): + ''' + All vehicles go back to the depot + :return: + ''' + veh_list = torch.arange(self.veh_num,device = self.device) + depot = torch.zeros_like(self.bs_index) + for i in veh_list: + self.update(i.expand(self.batch_size),depot) + + def get_cost(self,obj): + self.all_go_depot() + if obj=='min-max': + return self.veh_time.max(-1)[0] + elif obj=='min-sum': + return self.veh_time.sum(-1) + def get_action_mask(self): + # cannot select a visited node except the depot + visited_mask = self.visited.clone() # bs,N+1 + visited_mask[:,0]=False + # Here, clone() is important for avoiding the bug from expand() + visited_mask = visited_mask.unsqueeze(1).expand(self.batch_size, self.veh_num, self.N).clone() # bs,M,N+1 + # Vehicle cannot stay in place to avoid visiting the depot twice, + # otherwise an infinite loop will easily occur + visited_mask[self.bs_index.unsqueeze(-1),self.veh_index.unsqueeze(0),self.veh_cur_node]=True + # capacity constraints + demand_mask = (self.veh_capacity - self.veh_used_capacity).unsqueeze(-1) < self.demand.unsqueeze(1) # bs,M,N+1 + mask = visited_mask | demand_mask + # Special setting for batch processing, + # because the finished task will have a full mask and raise an error + mask[self.finished(),0,0]=False + return mask + + @staticmethod + def caculate_cost(input,solution,obj): + ''' + :param input: equal to __init__ + :param solution: (veh,next_node): [total_step, batch_size],[total_step, batch_size] + :param obj: 'min-max' or 'min-sum' + :return: cost : batch_size + ''' + + env = HcvrpEnv(input) + for veh,next_node in zip(*solution): + env.update(veh,next_node) + return env.get_cost(obj) \ No newline at end of file diff --git a/problems/hcvrp/problem_hcvrp.py b/problems/hcvrp/problem_hcvrp.py new file mode 100644 index 0000000..2e284dc --- /dev/null +++ b/problems/hcvrp/problem_hcvrp.py @@ -0,0 +1,64 @@ +from torch.utils.data import Dataset +import torch +import os +import pickle + +class HCVRP: + NAME = 'hcvrp' + @staticmethod + def make_dataset(*args,**kwargs): + return HCVRPDataset(*args,**kwargs) + +def make_instance(args): + depot, loc, demand, capacity, *args = args + grid_size = 1 + if len(args) > 0: + depot_types, customer_types, grid_size = args + return { + 'loc': torch.tensor(loc, dtype=torch.float) / grid_size, + 'demand': torch.tensor(demand, dtype=torch.float), # scale demand + 'depot': torch.tensor(depot, dtype=torch.float) / grid_size, + 'capacity': torch.tensor(capacity, dtype=torch.float) + } + +class HCVRPDataset(Dataset): + def __init__(self, filename=None, size=50, veh_num=3, num_samples=10000, offset=0, distribution=None): + super(HCVRPDataset, self).__init__() + + # self.data_set = [] + if filename is not None: + assert
os.path.splitext(filename)[1] == '.pkl' + + with open(filename, 'rb') as f: + data = pickle.load(f) + self.data = [] + for i in range(data['depot'].shape[0]): + self.data.append({ + 'depot': torch.from_numpy(data['depot'][i]).float(), + 'loc': torch.from_numpy(data['loc'][i]).float(), + 'demand': torch.from_numpy(data['demand'][i]).float(), + 'capacity': torch.from_numpy(data['capacity'][i]).float(), + 'speed': torch.from_numpy(data['speed'][i]).float() + }) + else: + self.data = [ + { + 'loc': torch.FloatTensor(size, 2).uniform_(0, 1), + # Uniform 1 - 9, scaled by capacities + 'demand': (torch.FloatTensor(size).uniform_(0, 9).int() + 1).float(), + 'depot': torch.FloatTensor(2).uniform_(0, 1), + # Uniform 20 - 40, scaled by capacities + 'capacity': (torch.FloatTensor(veh_num).uniform_(19, 40).int() + 1).float(), + 'speed': torch.FloatTensor(veh_num).uniform_(0.5, 1) + } + for i in range(num_samples) + ] + + self.size = len(self.data) # num_samples + + def __len__(self): + return self.size + + def __getitem__(self, idx): + return self.data[idx] # index of sampled data + diff --git a/reinforce_baselines.py b/reinforce_baselines.py new file mode 100644 index 0000000..ed9ae25 --- /dev/null +++ b/reinforce_baselines.py @@ -0,0 +1,248 @@ +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from scipy.stats import ttest_rel +import copy +from train import rollout, get_inner_model + +class Baseline(object): + + def wrap_dataset(self, dataset): + return dataset + + def unwrap_batch(self, batch): + return batch, None + + def eval(self, x, c): + raise NotImplementedError("Override this method") + + def get_learnable_parameters(self): + return [] + + def epoch_callback(self, model, epoch): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + +class WarmupBaseline(Baseline): + + def __init__(self, baseline, n_epochs=1, warmup_exp_beta=0.8, ): + super(Baseline, self).__init__() + + self.baseline = baseline + assert n_epochs > 0, "n_epochs to warmup must be positive" + self.warmup_baseline = ExponentialBaseline(warmup_exp_beta) + self.alpha = 0 + self.n_epochs = n_epochs + + def wrap_dataset(self, dataset): + if self.alpha > 0: + return self.baseline.wrap_dataset(dataset) + return self.warmup_baseline.wrap_dataset(dataset) + + def unwrap_batch(self, batch): + if self.alpha > 0: + return self.baseline.unwrap_batch(batch) + return self.warmup_baseline.unwrap_batch(batch) + + def eval(self, x, c): + + if self.alpha == 1: + return self.baseline.eval(x, c) + if self.alpha == 0: + return self.warmup_baseline.eval(x, c) + v, l = self.baseline.eval(x, c) + vw, lw = self.warmup_baseline.eval(x, c) + # Return convex combination of baseline and of loss + return self.alpha * v + (1 - self.alpha) * vw, self.alpha * l + (1 - self.alpha * lw) + + def epoch_callback(self, model, epoch): + # Need to call epoch callback of inner model (also after first epoch if we have not used it) + self.baseline.epoch_callback(model, epoch) + self.alpha = (epoch + 1) / float(self.n_epochs) + if epoch < self.n_epochs: + print("Set warmup alpha = {}".format(self.alpha)) + + def state_dict(self): + # Checkpointing within warmup stage makes no sense, only save inner baseline + return self.baseline.state_dict() + + def load_state_dict(self, state_dict): + # Checkpointing within warmup stage makes no sense, only load inner baseline + self.baseline.load_state_dict(state_dict) + + +class NoBaseline(Baseline): + + def eval(self, x, c): + return 0, 0 # No baseline, no 
loss + + +class ExponentialBaseline(Baseline): + + def __init__(self, beta): + super(Baseline, self).__init__() + + self.beta = beta + self.v = None + + def eval(self, x, c): # x is data and c is cost in actor network + + if self.v is None: + v = c.mean() + else: + v = self.beta * self.v + (1. - self.beta) * c.mean() + + self.v = v.detach() # Detach since we never want to backprop + return self.v, 0 # No loss + + def state_dict(self): + return { + 'v': self.v + } + + def load_state_dict(self, state_dict): + self.v = state_dict['v'] + + +class CriticBaseline(Baseline): + + def __init__(self, critic): + super(Baseline, self).__init__() + + self.critic = critic + + def eval(self, x, c): + v = self.critic(x) + # Detach v since actor should not backprop through baseline, only for loss + return v.detach(), F.mse_loss(v, c.detach()) + + def get_learnable_parameters(self): + return list(self.critic.parameters()) + + def epoch_callback(self, model, epoch): + pass + + def state_dict(self): + return { + 'critic': self.critic.state_dict() + } + + def load_state_dict(self, state_dict): + critic_state_dict = state_dict.get('critic', {}) + if not isinstance(critic_state_dict, dict): # backwards compatibility + critic_state_dict = critic_state_dict.state_dict() + self.critic.load_state_dict({**self.critic.state_dict(), **critic_state_dict}) + + +class RolloutBaseline(Baseline): + + def __init__(self, model, problem, opts, epoch=0): + super(Baseline, self).__init__() + + self.problem = problem + self.opts = opts + + self._update_model(model, epoch) + + def _update_model(self, model, epoch, dataset=None): + self.model = copy.deepcopy(model) + # Always generate baseline dataset when updating model to prevent overfitting to the baseline dataset + + if dataset is not None: + if len(dataset) != self.opts.val_size: + print("Warning: not using saved baseline dataset since val_size does not match") + dataset = None + elif (dataset[0] if self.problem.NAME == 'tsp' else dataset[0]['loc']).size(0) != self.opts.graph_size: + print("Warning: not using saved baseline dataset since graph_size does not match") + dataset = None + + if dataset is None: + self.dataset = self.problem.make_dataset( + size = self.opts.graph_size, veh_num = self.opts.veh_num, num_samples=self.opts.val_size, distribution=self.opts.data_distribution) + else: + self.dataset = dataset + print("Evaluating baseline model on evaluation dataset") + self.bl_vals = rollout(self.model, self.dataset, self.opts).cpu().numpy() + self.mean = self.bl_vals.mean() + self.epoch = epoch + + def wrap_dataset(self, dataset): + print("Evaluating baseline on dataset...") + # Need to convert baseline to 2D to prevent converting to double, see + # https://discuss.pytorch.org/t/dataloader-gives-double-instead-of-float/717/3 + return BaselineDataset(dataset, rollout(self.model, dataset, self.opts).view(-1, 1)) # [epoch_size, 1] (num_samples) + + def unwrap_batch(self, batch): + return batch['data'], batch['baseline'].view(-1) # Flatten result to undo wrapping as 2D + + def eval(self, x, c): + # Use volatile mode for efficient inference (single batch so we do not use rollout function) + with torch.no_grad(): + v, _, _ = self.model(x) # return baseline, cost + + # There is no loss + return v, 0 + + def epoch_callback(self, model, epoch): + """ + Challenges the current baseline with the model and replaces the baseline model if it is improved. 
+ :param model: The model to challenge the baseline by + :param epoch: The current epoch + """ + print("Evaluating candidate model on evaluation dataset") + candidate_vals = rollout(model, self.dataset, self.opts).cpu().numpy() + + candidate_mean = candidate_vals.mean() + + print("Epoch {} candidate mean {}, baseline epoch {} mean {}, difference {}".format( + epoch, candidate_mean, self.epoch, self.mean, candidate_mean - self.mean)) + + # if candidate model have smaller cost than current baseline model + if candidate_mean - self.mean < 0: + # Calc p value + t, p = ttest_rel(candidate_vals, self.bl_vals) + + p_val = p / 2 # one-sided + assert t < 0, "T-statistic should be negative" + print("p-value: {}".format(p_val)) + if p_val < self.opts.bl_alpha: + print('Update baseline') + self._update_model(model, epoch) + + def state_dict(self): + return { + 'model': self.model, + 'dataset': self.dataset, + 'epoch': self.epoch + } + + def load_state_dict(self, state_dict): + # We make it such that it works whether model was saved as data parallel or not + load_model = copy.deepcopy(self.model) + get_inner_model(load_model).load_state_dict(get_inner_model(state_dict['model']).state_dict()) + self._update_model(load_model, state_dict['epoch'], state_dict['dataset']) + + +class BaselineDataset(Dataset): + + def __init__(self, dataset=None, baseline=None): + super(BaselineDataset, self).__init__() + + self.dataset = dataset + self.baseline = baseline + assert (len(self.dataset) == len(self.baseline)) + + def __getitem__(self, item): + return { + 'data': self.dataset[item], + 'baseline': self.baseline[item] + } + + def __len__(self): + return len(self.dataset) diff --git a/run.py b/run.py new file mode 100644 index 0000000..3daee04 --- /dev/null +++ b/run.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python + +import os +import json +import pprint as pp + +import torch +import torch.optim as optim +from tensorboard_logger import Logger as TbLogger + +from nets.critic_network import CriticNetwork +from options import get_options +from train import train_epoch, validate, get_inner_model +from reinforce_baselines import NoBaseline, ExponentialBaseline, CriticBaseline, RolloutBaseline, WarmupBaseline +from nets.attention_model import AttentionModel +#from nets.attention_model_minsum import AttentionModel +from nets.pointer_network import PointerNetwork, CriticNetworkLSTM +from utils import torch_load_cpu, load_problem +import warnings + +# from problems.vrp import CVRP +from problems import HCVRP + +def run(opts): + + # Pretty print the run args + pp.pprint(vars(opts)) + + # Set the random seed + torch.manual_seed(opts.seed) + + # Optionally configure tensorboard + tb_logger = None + if not opts.no_tensorboard: + tb_logger = TbLogger(os.path.join(opts.log_dir, "{}_v{}_c{}".format(opts.problem,opts.veh_num,opts.graph_size), opts.run_name)) + + # save model to outputs dir + os.makedirs(opts.save_dir) + # Save arguments so exact configuration can always be found + with open(os.path.join(opts.save_dir, "args.json"), 'w') as f: + json.dump(vars(opts), f, indent=True) + + # Set the device + opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu") + + # Figure out what's the problem + problem = load_problem(opts.problem) + # problem = HCVRP(opts.graph_size,opts.veh_num,opts.obj) + + # Load data from load_path + # if u have run the model before, u can continue from resume path + load_data = {} + assert opts.load_path is None or opts.resume is None, "Only one of load path and resume can be given" + load_path = 
opts.load_path if opts.load_path is not None else opts.resume + if load_path is not None: + print(' [*] Loading data from {}'.format(load_path)) + load_data = torch_load_cpu(load_path) + + # Initialize model + model_class = { + 'attention': AttentionModel, + 'pointer': PointerNetwork + }.get(opts.model, None) + assert model_class is not None, "Unknown model: {}".format(model_class) + model = model_class( + opts.embedding_dim, + opts.hidden_dim, + opts.obj, + problem, + n_heads=opts.n_heads, + n_encode_layers=opts.n_encode_layers, + mask_inner=True, + mask_logits=True, + normalization=opts.normalization, + tanh_clipping=opts.tanh_clipping, + checkpoint_encoder=opts.checkpoint_encoder, + shrink_size=opts.shrink_size + ).to(opts.device) + + # multi-gpu + if opts.use_cuda and torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + + # Overwrite model parameters by parameters to load + model_ = get_inner_model(model) + model_.load_state_dict({**model_.state_dict(), **load_data.get('model', {})}) + + # Initialize baseline + if opts.baseline == 'exponential': + baseline = ExponentialBaseline(opts.exp_beta) + elif opts.baseline == 'critic' or opts.baseline == 'critic_lstm': + assert problem.NAME == 'tsp', "Critic only supported for TSP" + baseline = CriticBaseline( + ( + CriticNetworkLSTM( + 2, + opts.embedding_dim, + opts.hidden_dim, + opts.n_encode_layers, + opts.tanh_clipping + ) + if opts.baseline == 'critic_lstm' + else + CriticNetwork( + 2, + opts.embedding_dim, + opts.hidden_dim, + opts.n_encode_layers, + opts.normalization + ) + ).to(opts.device) + ) + elif opts.baseline == 'rollout': + baseline = RolloutBaseline(model, problem, opts) + else: + assert opts.baseline is None, "Unknown baseline: {}".format(opts.baseline) + baseline = NoBaseline() + + if opts.bl_warmup_epochs > 0: + baseline = WarmupBaseline(baseline, opts.bl_warmup_epochs, warmup_exp_beta=opts.exp_beta) + + # Load baseline from data, make sure script is called with same type of baseline + if 'baseline' in load_data: + baseline.load_state_dict(load_data['baseline']) + + # Initialize optimizer + optimizer = optim.Adam( + [{'params': model.parameters(), 'lr': opts.lr_model}] + + ( + [{'params': baseline.get_learnable_parameters(), 'lr': opts.lr_critic}] + if len(baseline.get_learnable_parameters()) > 0 + else [] + ) + ) + + # Load optimizer state from trained model + if 'optimizer' in load_data: + optimizer.load_state_dict(load_data['optimizer']) + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.to(opts.device) + + # Initialize learning rate scheduler, decay by lr_decay once per epoch! 
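+     # Added note: LambdaLR multiplies the base learning rate by lr_decay ** epoch, so with the
+     # shipped configuration (lr_model=1e-4, lr_decay=0.995, n_epochs=50) the final epoch (49)
+     # trains at roughly 1e-4 * 0.995**49 ≈ 7.8e-5.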
+ lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: opts.lr_decay ** epoch) + + # Start the actual training loop + val_dataset = problem.make_dataset( + size=opts.graph_size, veh_num=opts.veh_num, num_samples=opts.val_size, filename=opts.val_dataset, distribution=opts.data_distribution) + + if opts.resume: + epoch_resume = int(os.path.splitext(os.path.split(opts.resume)[-1])[0].split("-")[1]) + + torch.set_rng_state(load_data['rng_state']) + if opts.use_cuda: + torch.cuda.set_rng_state_all(load_data['cuda_rng_state']) + # Set the random states + # Dumping of state was done before epoch callback, so do that now (model is loaded) + baseline.epoch_callback(model, epoch_resume) + print("Resuming after {}".format(epoch_resume)) + opts.epoch_start = epoch_resume + 1 + + if opts.eval_only: + validate(model, val_dataset, opts) + else: + for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs): + train_epoch( + model, + optimizer, + baseline, + lr_scheduler, + epoch, + val_dataset, + problem, + tb_logger, + opts + ) + + +if __name__ == "__main__": + warnings.filterwarnings('ignore') + # os.environ["CUDA_VISIBLE_DEVICES"] = "0" + run(get_options()) diff --git a/train.py b/train.py new file mode 100644 index 0000000..ebb6b3b --- /dev/null +++ b/train.py @@ -0,0 +1,174 @@ +import os +import time +from tqdm import tqdm +import torch +import math + +from torch.utils.data import DataLoader +from torch.nn import DataParallel + +from nets.attention_model import set_decode_type +from utils.log_utils import log_values +from utils import move_to + + +def get_inner_model(model): + return model.module if isinstance(model, DataParallel) else model + + +def validate(model, dataset, opts): + # Validate + print('Validating...') + # multi batch + cost = rollout(model, dataset, opts) + avg_cost = cost.mean() + print('Validation overall avg_cost: {} +- {}'.format( + avg_cost, torch.std(cost) / math.sqrt(len(cost)))) + + return avg_cost + + +def rollout(model, dataset, opts): + # Put in greedy evaluation mode! 
+ set_decode_type(model, "greedy") + model.eval() + + def eval_model_bat(bat): + # do not need backpropogation + with torch.no_grad(): + cost, _ = model(move_to(bat, opts.device)) + return cost.data.cpu() + + # tqdm is a function to show the progress bar + return torch.cat([ + eval_model_bat(bat) + for bat + in tqdm(DataLoader(dataset, batch_size=opts.eval_batch_size), disable=opts.no_progress_bar) + ], 0) + + +def clip_grad_norms(param_groups, max_norm=math.inf): + """ + Clips the norms for all param groups to max_norm and returns gradient norms before clipping + :param optimizer: + :param max_norm: + :param gradient_norms_log: + :return: grad_norms, clipped_grad_norms: list with (clipped) gradient norms per group + """ + grad_norms = [ + torch.nn.utils.clip_grad_norm_( + group['params'], + max_norm if max_norm > 0 else math.inf, # Inf so no clipping but still call to calc + norm_type=2 + ) + for group in param_groups + ] + grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms] if max_norm > 0 else grad_norms + return grad_norms, grad_norms_clipped + + + +def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, problem, tb_logger, opts): + print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name)) + step = epoch * (opts.epoch_size // opts.batch_size) + start_time = time.time() + lr_scheduler.step(epoch) + + if not opts.no_tensorboard: # need tensorboard + tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step) + + # Generate new training data for each epoch + training_dataset = baseline.wrap_dataset(problem.make_dataset( + size=opts.graph_size, + veh_num=opts.veh_num, + num_samples=opts.epoch_size, + distribution=opts.data_distribution)) # data, baseline (cost of data) + training_dataloader = DataLoader(training_dataset, batch_size=opts.batch_size, num_workers=0) + + # Put model in train mode! 
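+     # Added note: training decodes with "sampling" so that REINFORCE receives stochastic actions and
+     # their log-likelihoods, whereas rollout()/validation above switches the model to "greedy" decoding.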
+ model.train() + set_decode_type(model, "sampling") + + for batch_id, batch in enumerate(tqdm(training_dataloader, disable=opts.no_progress_bar)): + train_batch( + model, + optimizer, + baseline, + epoch, + batch_id, + step, + batch, + tb_logger, + opts + ) + step += 1 + + epoch_duration = time.time() - start_time + print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(epoch_duration)))) + + # save results every checkpoint_epoches, saving memory + if (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) or epoch == opts.n_epochs - 1: + print('Saving model and state...') + torch.save( + { + 'model': get_inner_model(model).state_dict(), + 'optimizer': optimizer.state_dict(), + # rng_state is the state of random generator + 'rng_state': torch.get_rng_state(), + 'cuda_rng_state': torch.cuda.get_rng_state_all(), + 'baseline': baseline.state_dict() + }, + # save state of runned model in outputs + os.path.join(opts.save_dir, 'epoch-{}.pt'.format(epoch)) + ) + + avg_reward = validate(model, val_dataset, opts) + + if not opts.no_tensorboard: + tb_logger.log_value('val_avg_reward', avg_reward, step) + + baseline.epoch_callback(model, epoch) + + +def train_batch( + model, + optimizer, + baseline, + epoch, + batch_id, + step, + batch, + tb_logger, + opts +): + x, bl_val = baseline.unwrap_batch(batch) # data, baseline(cost of data) + x = move_to(x, opts.device) + bl_val = move_to(bl_val, opts.device) if bl_val is not None else None + + + # Evaluate proposed model, get costs and log probabilities + # cost, log_likelihood, log_veh = model(x) # both [batch_size] + cost, log_likelihood = model(x) # both [batch_size] + + # Evaluate baseline, get baseline loss if any (only for critic) + bl_val, bl_loss = baseline.eval(x, cost) if bl_val is None else (bl_val, 0) + + # Calculate loss + # reinforce_loss = ((cost - bl_val) * (log_likelihood + log_veh)).mean() + reinforce_loss = ((cost - bl_val) * log_likelihood).mean() + loss = reinforce_loss + bl_loss + #print('bl_val', bl_val) + + # Perform backward pass and optimization step + optimizer.zero_grad() + loss.backward() + # Clip gradient norms and get (clipped) gradient norms for logging + grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm) + + optimizer.step() + + # Logging + if step % int(opts.log_step) == 0: + log_values(cost, grad_norms, epoch, batch_id, step, + log_likelihood, reinforce_loss, bl_loss, tb_logger, opts) + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e7d4dc0 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .functions import * \ No newline at end of file diff --git a/utils/beam_search.py b/utils/beam_search.py new file mode 100644 index 0000000..e564ca4 --- /dev/null +++ b/utils/beam_search.py @@ -0,0 +1,222 @@ +import time +import torch +from typing import NamedTuple +from utils.lexsort import torch_lexsort + + +def beam_search(*args, **kwargs): + beams, final_state = _beam_search(*args, **kwargs) + return get_beam_search_results(beams, final_state) + + +def get_beam_search_results(beams, final_state): + beam = beams[-1] # Final beam + if final_state is None: + return None, None, None, None, beam.batch_size + + # First state has no actions/parents and should be omitted when backtracking + actions = [beam.action for beam in beams[1:]] + parents = [beam.parent for beam in beams[1:]] + + solutions = final_state.construct_solutions(backtrack(parents, actions)) + return beam.score, solutions, final_state.get_final_cost()[:, 0], 
final_state.ids.view(-1), beam.batch_size + + +def _beam_search(state, beam_size, propose_expansions=None, + keep_states=False): + + beam = BatchBeam.initialize(state) + + # Initial state + beams = [beam if keep_states else beam.clear_state()] + + # Perform decoding steps + while not beam.all_finished(): + + # Use the model to propose and score expansions + parent, action, score = beam.propose_expansions() if propose_expansions is None else propose_expansions(beam) + if parent is None: + return beams, None + + # Expand and update the state according to the selected actions + beam = beam.expand(parent, action, score=score) + + # Get topk + beam = beam.topk(beam_size) + + # Collect output of step + beams.append(beam if keep_states else beam.clear_state()) + + # Return the final state separately since beams may not keep state + return beams, beam.state + + +class BatchBeam(NamedTuple): + """ + Class that keeps track of a beam for beam search in batch mode. + Since the beam size of different entries in the batch may vary, the tensors are not (batch_size, beam_size, ...) + but rather (sum_i beam_size_i, ...), i.e. flattened. This makes some operations a bit cumbersome. + """ + score: torch.Tensor # Current heuristic score of each entry in beam (used to select most promising) + state: None # To track the state + parent: torch.Tensor + action: torch.Tensor + batch_size: int # Can be used for optimizations if batch_size = 1 + device: None # Track on which device + + # Indicates for each row to which batch it belongs (0, 0, 0, 1, 1, 2, ...), managed by state + @property + def ids(self): + return self.state.ids.view(-1) # Need to flat as state has steps dimension + + def __getitem__(self, key): + assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor: + # if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor: + return self._replace( + # ids=self.ids[key], + score=self.score[key] if self.score is not None else None, + state=self.state[key], + parent=self.parent[key] if self.parent is not None else None, + action=self.action[key] if self.action is not None else None + ) + # return super(BatchBeam, self).__getitem__(key) + + # Do not use __len__ since this is used by namedtuple internally and should be number of fields + # def __len__(self): + # return len(self.ids) + + @staticmethod + def initialize(state): + batch_size = len(state.ids) + device = state.ids.device + return BatchBeam( + score=torch.zeros(batch_size, dtype=torch.float, device=device), + state=state, + parent=None, + action=None, + batch_size=batch_size, + device=device + ) + + def propose_expansions(self): + mask = self.state.get_mask() + # Mask always contains a feasible action + expansions = torch.nonzero(mask[:, 0, :] == 0) + parent, action = torch.unbind(expansions, -1) + return parent, action, None + + def expand(self, parent, action, score=None): + return self._replace( + score=score, # The score is cleared upon expanding as it is no longer valid, or it must be provided + state=self.state[parent].update(action), # Pass ids since we replicated state + parent=parent, + action=action + ) + + def topk(self, k): + idx_topk = segment_topk_idx(self.score, k, self.ids) + return self[idx_topk] + + def all_finished(self): + return self.state.all_finished() + + def cpu(self): + return self.to(torch.device('cpu')) + + def to(self, device): + if device == self.device: + return self + return self._replace( + score=self.score.to(device) if self.score is not None 
else None, + state=self.state.to(device), + parent=self.parent.to(device) if self.parent is not None else None, + action=self.action.to(device) if self.action is not None else None + ) + + def clear_state(self): + return self._replace(state=None) + + def size(self): + return self.state.ids.size(0) + + +def segment_topk_idx(x, k, ids): + """ + Finds the topk per segment of data x given segment ids (0, 0, 0, 1, 1, 2, ...). + Note that there may be fewer than k elements in a segment so the returned length index can vary. + x[result], ids[result] gives the sorted elements per segment as well as corresponding segment ids after sorting. + :param x: + :param k: + :param ids: + :return: + """ + assert x.dim() == 1 + assert ids.dim() == 1 + + # Since we may have varying beam size per batch entry we cannot reshape to (batch_size, beam_size) + # And use default topk along dim -1, so we have to be creative + # Now we have to get the topk per segment which is really annoying :( + # we use lexsort on (ids, score), create array with offset per id + # offsets[ids] then gives offsets repeated and only keep for which arange(len) < offsets + k + splits_ = torch.nonzero(ids[1:] - ids[:-1]) + + if len(splits_) == 0: # Only one group + _, idx_topk = x.topk(min(k, x.size(0))) + return idx_topk + + splits = torch.cat((ids.new_tensor([0]), splits_[:, 0] + 1)) + # Make a new array in which we store for each id the offset (start) of the group + # This way ids does not need to be increasing or adjacent, as long as each group is a single range + group_offsets = splits.new_zeros((splits.max() + 1,)) + group_offsets[ids[splits]] = splits + offsets = group_offsets[ids] # Look up offsets based on ids, effectively repeating for the repetitions per id + + # We want topk so need to sort x descending so sort -x (be careful with unsigned data type!) 
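+     # Added note: np.lexsort treats the LAST key as the primary key, so torch_lexsort((-x, ids))
+     # groups entries by segment id first and sorts each segment by descending score.
+     # Example: x = [1., 3., 2.], ids = [0, 0, 1] -> idx_sorted = [1, 0, 2].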
+ idx_sorted = torch_lexsort((-(x if x.dtype != torch.uint8 else x.int()).detach(), ids)) + + # This will filter first k per group (example k = 2) + # ids = [0, 0, 0, 1, 1, 1, 1, 2] + # splits = [0, 3, 7] + # offsets = [0, 0, 0, 3, 3, 3, 3, 7] + # offs+2 = [2, 2, 2, 5, 5, 5, 5, 9] + # arange = [0, 1, 2, 3, 4, 5, 6, 7] + # filter = [1, 1, 0, 1, 1, 0, 0, 1] + # Use filter to get only topk of sorting idx + return idx_sorted[torch.arange(ids.size(0), out=ids.new()) < offsets + k] + + +def backtrack(parents, actions): + + # Now backtrack to find aligned action sequences in reversed order + cur_parent = parents[-1] + reversed_aligned_sequences = [actions[-1]] + for parent, sequence in reversed(list(zip(parents[:-1], actions[:-1]))): + reversed_aligned_sequences.append(sequence.gather(-1, cur_parent)) + cur_parent = parent.gather(-1, cur_parent) + + return torch.stack(list(reversed(reversed_aligned_sequences)), -1) + + +class CachedLookup(object): + + def __init__(self, data): + self.orig = data + self.key = None + self.current = None + + def __getitem__(self, key): + assert not isinstance(key, slice), "CachedLookup does not support slicing, " \ + "you can slice the result of an index operation instead" + + if torch.is_tensor(key): # If tensor, idx all tensors by this tensor: + + if self.key is None: + self.key = key + self.current = self.orig[key] + elif len(key) != len(self.key) or (key != self.key).any(): + self.key = key + self.current = self.orig[key] + + return self.current + + return super(CachedLookup, self).__getitem__(key) diff --git a/utils/boolmask.py b/utils/boolmask.py new file mode 100644 index 0000000..4764745 --- /dev/null +++ b/utils/boolmask.py @@ -0,0 +1,68 @@ +import torch +import torch.nn.functional as F + + +def _pad_mask(mask): + # By taking -size % 8, we get 0 if exactly divisible by 8 + # and required padding otherwise (i.e. 
-1 % 8 = 7 pad) + pad = -mask.size(-1) % 8 + if pad != 0: + mask = F.pad(mask, [0, pad]) + return mask, mask.size(-1) // 8 + + +def _mask_bool2byte(mask): + assert mask.dtype == torch.uint8 + # assert (mask <= 1).all() # Precondition, disabled for efficiency + mask, d = _pad_mask(mask) + return (mask.view(*mask.size()[:-1], d, 8) << torch.arange(8, out=mask.new())).sum(-1, dtype=torch.uint8) + + +def _mask_byte2long(mask): + assert mask.dtype == torch.uint8 + mask, d = _pad_mask(mask) + # Note this corresponds to a temporary factor 8 + # memory overhead by converting to long before summing + # Alternatively, aggregate using for loop + return (mask.view(*mask.size()[:-1], d, 8).long() << (torch.arange(8, dtype=torch.int64, device=mask.device) * 8)).sum(-1) + + +def mask_bool2long(mask): + assert mask.dtype == torch.uint8 + return _mask_byte2long(_mask_bool2byte(mask)) + + +def _mask_long2byte(mask, n=None): + if n is None: + n = 8 * mask.size(-1) + return (mask[..., None] >> (torch.arange(8, out=mask.new()) * 8))[..., :n].to(torch.uint8).view(*mask.size()[:-1], -1)[..., :n] + + +def _mask_byte2bool(mask, n=None): + if n is None: + n = 8 * mask.size(-1) + return (mask[..., None] & (mask.new_ones(8) << torch.arange(8, out=mask.new()) * 1)).view(*mask.size()[:-1], -1)[..., :n] > 0 + + +def mask_long2bool(mask, n=None): + assert mask.dtype == torch.int64 + return _mask_byte2bool(_mask_long2byte(mask), n=n) + + +def mask_long_scatter(mask, values, check_unset=True): + """ + Sets values in mask in dimension -1 with arbitrary batch dimensions + If values contains -1, nothing is set + Note: does not work for setting multiple values at once (like normal scatter) + """ + assert mask.size()[:-1] == values.size() + rng = torch.arange(mask.size(-1), out=mask.new()) + values_ = values[..., None] # Need to broadcast up do mask dim + # This indicates in which value of the mask a bit should be set + where = (values_ >= (rng * 64)) & (values_ < ((rng + 1) * 64)) + # Optional: check that bit is not already set + assert not (check_unset and ((mask & (where.long() << (values_ % 64))) > 0).any()) + # Set bit by shifting a 1 to the correct position + # (% not strictly necessary as bitshift is cyclic) + # since where is 0 if no value needs to be set, the bitshift has no effect + return mask | (where.long() << (values_ % 64)) diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000..54a3ee8 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,25 @@ +import os +import pickle + + +def check_extension(filename): + if os.path.splitext(filename)[1] != ".pkl": + return filename + ".pkl" + return filename + + +def save_dataset(dataset, filename): + + filedir = os.path.split(filename)[0] + + if not os.path.isdir(filedir): + os.makedirs(filedir) + + with open(check_extension(filename), 'wb') as f: + pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL) + + +def load_dataset(filename): + + with open(check_extension(filename), 'rb') as f: + return pickle.load(f) \ No newline at end of file diff --git a/utils/functions.py b/utils/functions.py new file mode 100644 index 0000000..1061568 --- /dev/null +++ b/utils/functions.py @@ -0,0 +1,214 @@ +import warnings + +import torch +import numpy as np +import os +import json +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import Pool +import torch.nn.functional as F + + + + +def load_problem(name): + from problems import HCVRP + problem = { + 'hcvrp': HCVRP + }.get(name, None) + assert problem is not None, 
"Currently unsupported problem: {}!".format(name) + return problem + + +def torch_load_cpu(load_path): + return torch.load(load_path, map_location=lambda storage, loc: storage) # Load on CPU + + +def move_to(var, device): + if isinstance(var, dict): + return {k: move_to(v, device) for k, v in var.items()} + return var.to(device) + + +def _load_model_file(load_path, model): + """Loads the model with parameters from the file and returns optimizer state dict if it is in the file""" + + # Load the model parameters from a saved state + load_optimizer_state_dict = None + print(' [*] Loading model from {}'.format(load_path)) + + load_data = torch.load( + os.path.join( + os.getcwd(), + load_path + ), map_location=lambda storage, loc: storage) + + if isinstance(load_data, dict): + load_optimizer_state_dict = load_data.get('optimizer', None) + load_model_state_dict = load_data.get('model', load_data) + else: + load_model_state_dict = load_data.state_dict() + + state_dict = model.state_dict() + + state_dict.update(load_model_state_dict) + + model.load_state_dict(state_dict) + + return model, load_optimizer_state_dict + + +def load_args(filename): + with open(filename, 'r') as f: + args = json.load(f) + + # Backwards compatibility + if 'data_distribution' not in args: + args['data_distribution'] = None + probl, *dist = args['problem'].split("_") + if probl == "op": + args['problem'] = probl + args['data_distribution'] = dist[0] + return args + + +def load_model(path, obj, epoch=None): + from problems import HCVRP + from nets.attention_model import AttentionModel + from nets.pointer_network import PointerNetwork + + if os.path.isfile(path): + model_filename = path + path = os.path.dirname(model_filename) + elif os.path.isdir(path): + if epoch is None: + epoch = max( + int(os.path.splitext(filename)[0].split("-")[1]) + for filename in os.listdir(path) + if os.path.splitext(filename)[1] == '.pt' + ) + model_filename = os.path.join(path, 'epoch-{}.pt'.format(epoch)) + else: + assert False, "{} is not a valid directory or file".format(path) + + args = load_args(os.path.join(path, 'args.json')) + + problem = load_problem(args['problem']) + # problem = HCVRP(graph_size=120,veh_num=5,obj='min-max') + model_class = { + 'attention': AttentionModel, + 'pointer': PointerNetwork + }.get(args.get('model', 'attention'), None) + assert model_class is not None, "Unknown model: {}".format(model_class) + + model = model_class( + args['embedding_dim'], + args['hidden_dim'], + obj, + problem, + n_encode_layers=args['n_encode_layers'], + mask_inner=True, + mask_logits=True, + normalization=args['normalization'], + tanh_clipping=args['tanh_clipping'], + checkpoint_encoder=args.get('checkpoint_encoder', False), + shrink_size=args.get('shrink_size', None) + ) + # Overwrite model parameters by parameters to load + load_data = torch_load_cpu(model_filename) + model.load_state_dict({**model.state_dict(), **load_data.get('model', {})}) + + model, *_ = _load_model_file(model_filename, model) + + model.eval() # Put in eval mode + + return model, args + + +def parse_softmax_temperature(raw_temp): + # Load from file + if os.path.isfile(raw_temp): + return np.loadtxt(raw_temp)[-1, 0] + return float(raw_temp) + + +def run_all_in_pool(func, directory, dataset, opts, use_multiprocessing=True): + # # Test + # res = func((directory, 'test', *dataset[0])) + # return [res] + + num_cpus = os.cpu_count() if opts.cpus is None else opts.cpus + + w = len(str(len(dataset) - 1)) + offset = getattr(opts, 'offset', None) + if offset is None: + offset 
= 0 + ds = dataset[offset:(offset + opts.n if opts.n is not None else len(dataset))] + pool_cls = (Pool if use_multiprocessing and num_cpus > 1 else ThreadPool) + with pool_cls(num_cpus) as pool: + results = list(tqdm(pool.imap( + func, + [ + ( + directory, + str(i + offset).zfill(w), + *problem + ) + for i, problem in enumerate(ds) + ] + ), total=len(ds), mininterval=opts.progress_bar_mininterval)) + + failed = [str(i + offset) for i, res in enumerate(results) if res is None] + assert len(failed) == 0, "Some instances failed: {}".format(" ".join(failed)) + return results, num_cpus + + +def do_batch_rep(v, n): + if isinstance(v, dict): + return {k: do_batch_rep(v_, n) for k, v_ in v.items()} + elif isinstance(v, list): + return [do_batch_rep(v_, n) for v_ in v] + elif isinstance(v, tuple): + return tuple(do_batch_rep(v_, n) for v_ in v) + + return v[None, ...].expand(n, *v.size()).contiguous().view(-1, *v.size()[1:]) + + +def sample_many(inner_func, get_cost_func, input, batch_rep=1, iter_rep=1): + """ + :param input: (batch_size, graph_size, node_dim) input node features + :return: + """ + + input = do_batch_rep(input, batch_rep) + costs = [] + veh_lists = [] + pis = [] + for i in range(iter_rep): + _log_p, pi, veh_list, cost = inner_func(input) + # #print('pi', pi.size()) + # + # # cost, mask = get_cost_func(input, pi, veh_list, tour_1, tour_2, tour_3) + # cost = get_cost_func(input, pi.T, veh_list.T) + costs.append(cost.view(batch_rep, -1).t()) # [1, (num_samples, batch_rep)] + veh_lists.append(veh_list.view(batch_rep, -1, veh_list.size(-1)).transpose(0, 1)) # [1, (num_samples, batch_rep, solu_len)] + pis.append(pi.view(batch_rep, -1, pi.size(-1)).transpose(0, 1)) # [1, (num_samples, batch_rep, solu_len)] + + max_length = max(pi.size(-1) for pi in pis) + # (batch_size * batch_rep, iter_rep, max_length) => (batch_size, batch_rep * iter_rep, max_length) + pis = torch.cat( # (num_samples, batch_rep, solu_len) + [F.pad(pi, (0, max_length - pi.size(-1))) for pi in pis], + 1 + ) # .view(embeddings.size(0), batch_rep * iter_rep, max_length) + + costs = torch.cat(costs, 1) # [num_samples, batch_rep] + veh_lists = torch.cat(veh_lists, 1) # [num_samples, batch_rep, solu_len] + + # (batch_size) + mincosts, argmincosts = costs.min(-1) # [num_samples] + # (batch_size, minlength) + minpis = pis[torch.arange(pis.size(0), out=argmincosts.new()), argmincosts] + minveh = veh_lists[torch.arange(veh_lists.size(0), out=argmincosts.new()), argmincosts] + + return minpis, mincosts, minveh diff --git a/utils/lexsort.py b/utils/lexsort.py new file mode 100644 index 0000000..c6a943e --- /dev/null +++ b/utils/lexsort.py @@ -0,0 +1,55 @@ +import torch +import numpy as np + + +def torch_lexsort(keys, dim=-1): + if keys[0].is_cuda: + return _torch_lexsort_cuda(keys, dim) + else: + # Use numpy lex sort + return torch.from_numpy(np.lexsort([k.numpy() for k in keys], axis=dim)) + + +def _torch_lexsort_cuda(keys, dim=-1): + """ + Function calculates a lexicographical sort order on GPU, similar to np.lexsort + Relies heavily on undocumented behavior of torch.sort, namely that when sorting more than + 2048 entries in the sorting dim, it performs a sort using Thrust and it uses a stable sort + https://github.com/pytorch/pytorch/blob/695fd981924bd805704ecb5ccd67de17c56d7308/aten/src/THC/generic/THCTensorSort.cu#L330 + """ + + MIN_NUMEL_STABLE_SORT = 2049 # Minimum number of elements for stable sort + + # Swap axis such that sort dim is last and reshape all other dims to a single (batch) dimension + reordered_keys = 
tuple(key.transpose(dim, -1).contiguous() for key in keys) + flat_keys = tuple(key.view(-1) for key in keys) + d = keys[0].size(dim) # Sort dimension size + numel = flat_keys[0].numel() + batch_size = numel // d + batch_key = torch.arange(batch_size, dtype=torch.int64, device=keys[0].device)[:, None].repeat(1, d).view(-1) + + flat_keys = flat_keys + (batch_key,) + + # We rely on undocumented behavior that the sort is stable provided that + if numel < MIN_NUMEL_STABLE_SORT: + n_rep = (MIN_NUMEL_STABLE_SORT + numel - 1) // numel # Ceil + rep_key = torch.arange(n_rep, dtype=torch.int64, device=keys[0].device)[:, None].repeat(1, numel).view(-1) + flat_keys = tuple(k.repeat(n_rep) for k in flat_keys) + (rep_key,) + + idx = None # Identity sorting initially + for k in flat_keys: + if idx is None: + _, idx = k.sort(-1) + else: + # Order data according to idx and then apply + # found ordering to current idx (so permutation of permutation) + # such that we can order the next key according to the current sorting order + _, idx_ = k[idx].sort(-1) + idx = idx[idx_] + + # In the end gather only numel and strip of extra sort key + if numel < MIN_NUMEL_STABLE_SORT: + idx = idx[:numel] + + # Get only numel (if we have replicated), swap axis back and shape results + return idx[:numel].view(*reordered_keys[0].size()).transpose(dim, -1) % d diff --git a/utils/log_utils.py b/utils/log_utils.py new file mode 100644 index 0000000..68c4ad2 --- /dev/null +++ b/utils/log_utils.py @@ -0,0 +1,24 @@ +def log_values(cost, grad_norms, epoch, batch_id, step, + log_likelihood, reinforce_loss, bl_loss, tb_logger, opts): + avg_cost = cost.mean().item() + grad_norms, grad_norms_clipped = grad_norms + + # Log values to screen + print('epoch: {}, train_batch_id: {}, avg_cost: {}'.format(epoch, batch_id, avg_cost)) + + print('grad_norm: {}, clipped: {}'.format(grad_norms[0], grad_norms_clipped[0])) + + # Log values to tensorboard + if not opts.no_tensorboard: + tb_logger.log_value('avg_cost', avg_cost, step) + + tb_logger.log_value('actor_loss', reinforce_loss.item(), step) + tb_logger.log_value('nll', -log_likelihood.mean().item(), step) + + tb_logger.log_value('grad_norm', grad_norms[0], step) + tb_logger.log_value('grad_norm_clipped', grad_norms_clipped[0], step) + + if opts.baseline == 'critic': + tb_logger.log_value('critic_loss', bl_loss.item(), step) + tb_logger.log_value('critic_grad_norm', grad_norms[1], step) + tb_logger.log_value('critic_grad_norm_clipped', grad_norms_clipped[1], step) diff --git a/utils/monkey_patch.py b/utils/monkey_patch.py new file mode 100644 index 0000000..5d90614 --- /dev/null +++ b/utils/monkey_patch.py @@ -0,0 +1,70 @@ +import torch +from itertools import chain +from collections import defaultdict, Iterable +from copy import deepcopy + + +def load_state_dict(self, state_dict): + """Loads the optimizer state. + Arguments: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = deepcopy(state_dict) + # Validate the state_dict + groups = self.param_groups + saved_groups = state_dict['param_groups'] + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " + "parameter groups") + param_lens = (len(g['params']) for g in groups) + saved_lens = (len(g['params']) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError("loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group") + + # Update the state + id_map = {old_id: p for old_id, p in + zip(chain(*(g['params'] for g in saved_groups)), + chain(*(g['params'] for g in groups)))} + + def cast(param, value): + """Make a deep copy of value, casting all tensors to device of param.""" + if torch.is_tensor(value): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + if any(tp in type(param.data).__name__ for tp in {'Half', 'Float', 'Double'}): + value = value.type_as(param.data) + value = value.to(param.device) + return value + elif isinstance(value, dict): + return {k: cast(param, v) for k, v in value.items()} + elif isinstance(value, Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state = defaultdict(dict) + for k, v in state_dict['state'].items(): + if k in id_map: + param = id_map[k] + state[param] = cast(param, v) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group, new_group): + new_group['params'] = group['params'] + return new_group + + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({'state': state, 'param_groups': param_groups}) + + +torch.optim.Optimizer.load_state_dict = load_state_dict \ No newline at end of file diff --git a/utils/tensor_functions.py b/utils/tensor_functions.py new file mode 100644 index 0000000..1e09f75 --- /dev/null +++ b/utils/tensor_functions.py @@ -0,0 +1,34 @@ +import torch + + +def compute_in_batches(f, calc_batch_size, *args, n=None): + """ + Computes memory heavy function f(*args) in batches + :param n: the total number of elements, optional if it cannot be determined as args[0].size(0) + :param f: The function that is computed, should take only tensors as arguments and return tensor or tuple of tensors + :param calc_batch_size: The batch size to use when computing this function + :param args: Tensor arguments with equally sized first batch dimension + :return: f(*args), this should be one or multiple tensors with equally sized first batch dimension + """ + if n is None: + n = args[0].size(0) + n_batches = (n + calc_batch_size - 1) // calc_batch_size # ceil + if n_batches == 1: + return f(*args) + + # Run all batches + # all_res = [f(*batch_args) for batch_args in zip(*[torch.chunk(arg, n_batches) for arg in args])] + # We do not use torch.chunk such that it also works for other classes that support slicing + all_res = [f(*(arg[i * calc_batch_size:(i + 1) * calc_batch_size] for arg in args)) for i in range(n_batches)] + + # Allow for functions that return None + def safe_cat(chunks, dim=0): + if chunks[0] is None: + assert all(chunk is None for chunk in chunks) + return None + 
return torch.cat(chunks, dim) + + # Depending on whether the function returned a tuple we need to concatenate each element or only the result + if isinstance(all_res[0], tuple): + return tuple(safe_cat(res_chunks, 0) for res_chunks in zip(*all_res)) + return safe_cat(all_res, 0)
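+ 
+ if __name__ == "__main__":
+     # Minimal usage sketch (added; not part of the original module): evaluate a memory-heavy
+     # function in chunks with compute_in_batches. The pairwise-distance helper and the sizes
+     # below are illustrative assumptions, not values used elsewhere in the repository.
+     def pairwise_dist(a, b):
+         # (bs, n, 2) x (bs, m, 2) -> (bs, n, m) Euclidean distance matrix
+         return torch.cdist(a, b)
+ 
+     a, b = torch.rand(512, 50, 2), torch.rand(512, 60, 2)
+     # Same result as pairwise_dist(a, b), but computed 128 instances at a time
+     d = compute_in_batches(pairwise_dist, 128, a, b)
+     assert d.shape == (512, 50, 60)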