From 1f4821cbb41b225957bcaffae357afd15ea93d75 Mon Sep 17 00:00:00 2001
From: RoyalSkye
Date: Thu, 20 Oct 2022 20:51:25 +0800
Subject: [PATCH] update curriculum strategy

---
 POMO/TSP/TSPTester.py       |  4 +-
 POMO/TSP/TSPTrainer_Meta.py | 85 ++++++++++++++++++++++++++++---------
 POMO/TSP/TSPTrainer_pomo.py | 13 +++---
 POMO/TSP/TSP_baseline.py    | 15 ++++---
 POMO/TSP/TSP_gurobi.py      |  2 +-
 POMO/TSP/TSProblemDef.py    |  5 +--
 POMO/TSP/test.py            |  3 +-
 POMO/TSP/train.py           | 53 +++++++++++------------
 POMO/utils/functions.py     | 13 ++++--
 9 files changed, 118 insertions(+), 75 deletions(-)

diff --git a/POMO/TSP/TSPTester.py b/POMO/TSP/TSPTester.py
index 072eafb..1bc8ffe 100644
--- a/POMO/TSP/TSPTester.py
+++ b/POMO/TSP/TSPTester.py
@@ -8,7 +8,7 @@
 from TSPEnv import TSPEnv as Env
 from TSPModel import TSPModel as Model
-from baselines import solve_all_gurobi
+from TSP_gurobi import solve_all_gurobi
 from utils.utils import *
 from utils.functions import load_dataset, save_dataset
@@ -60,6 +60,8 @@ def __init__(self,
         checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
         checkpoint = torch.load(checkpoint_fullname, map_location=self.device)
         self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # TODO: check whether performance is better with or without loading the optimizer state
+        self.logger.info(">> Model loaded from {}".format(checkpoint_fullname))

         # utility
         self.time_estimator = TimeEstimator()
diff --git a/POMO/TSP/TSPTrainer_Meta.py b/POMO/TSP/TSPTrainer_Meta.py
index f41152d..b30eb5a 100644
--- a/POMO/TSP/TSPTrainer_Meta.py
+++ b/POMO/TSP/TSPTrainer_Meta.py
@@ -12,34 +12,33 @@
 from torch.optim import Adam as Optimizer
 # from torch.optim import SGD as Optimizer
-# from torch.optim.lr_scheduler import MultiStepLR as Scheduler
-from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts as Scheduler

 from TSProblemDef import get_random_problems, generate_task_set
 from utils.utils import *
 from utils.functions import *
+from TSP_baseline import *


 class TSPTrainer:
     """
-    Implementation of POMO with MAML / FOMAML / Reptile.
+    Implementation of POMO with MAML / FOMAML / Reptile on TSP.
     For MAML & FOMAML, ref to "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks";
     For Reptile, ref to "On First-Order Meta-Learning Algorithms".
     Refer to "https://lilianweng.github.io/posts/2018-11-30-meta-learning"
     MAML's time and space complexity (i.e., GPU memory usage) is high, so we only update the decoder in the inner-loop, which yields similar performance.
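+    The inner-loop adapts (a copy of) the meta-model to each sampled task with k gradient steps;
+    the outer-loop then updates the meta-model itself: an Adam step for MAML / FOMAML, or an
+    interpolation with rate alpha for Reptile.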
""" def __init__(self, env_params, model_params, optimizer_params, - trainer_params): + trainer_params, + meta_params): # save arguments self.env_params = env_params self.model_params = model_params self.optimizer_params = optimizer_params self.trainer_params = trainer_params - self.meta_params = trainer_params['meta_params'] + self.meta_params = meta_params # result folder, logger self.logger = getLogger(name='trainer') @@ -60,9 +59,11 @@ def __init__(self, # Main Components self.meta_model = Model(**self.model_params) self.meta_optimizer = Optimizer(self.meta_model.parameters(), **self.optimizer_params['optimizer']) - # self.scheduler = Scheduler(self.meta_optimizer, **self.optimizer_params['scheduler']) self.alpha = self.meta_params['alpha'] # for reptile self.task_set = generate_task_set(self.meta_params) + self.min_n, self.max_n, self.task_interval = self.task_set[0][0], self.task_set[-1][0], 5 # [20, 150] / [0, 100] + # self.task_w = {start: 1/(len(self.task_set)//5) for start in range(self.min_n, self.max_n, self.task_interval)} + self.task_w = torch.full((len(self.task_set)//self.task_interval,), 1/(len(self.task_set)//self.task_interval)) self.ema_est = {i[0]: 1 for i in self.task_set} # Restore @@ -75,8 +76,7 @@ def __init__(self, self.start_epoch = 1 + model_load['epoch'] self.result_log.set_raw_data(checkpoint['result_log']) self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - # self.scheduler.last_epoch = model_load['epoch']-1 - self.logger.info('Saved Model Loaded !!') + self.logger.info(">> Model loaded from {}".format(checkpoint_fullname)) # utility self.time_estimator = TimeEstimator() @@ -90,7 +90,6 @@ def run(self): # Train train_score, train_loss = self._train_one_epoch(epoch) - # self.scheduler.step() self.result_log.append('train_score', epoch, train_score) self.result_log.append('train_loss', epoch, train_loss) # Val @@ -98,7 +97,7 @@ def run(self): if self.meta_params["data_type"] == "size": paths = ["tsp50_uniform.pkl", "tsp100_uniform.pkl", "tsp200_uniform.pkl"] elif self.meta_params["data_type"] == "distribution": - paths = ["tsp100_uniform.pkl", "tsp100_gaussian.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl", "tsp100_tsplib.pkl"] + paths = ["tsp100_uniform.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl"] elif self.meta_params["data_type"] == "size_distribution": pass for val_path in paths: @@ -135,7 +134,6 @@ def run(self): 'epoch': epoch, 'model_state_dict': self.meta_model.state_dict(), 'optimizer_state_dict': self.meta_optimizer.state_dict(), - # 'scheduler_state_dict': self.scheduler.state_dict(), 'result_log': self.result_log.get_raw_data() } torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch)) @@ -162,16 +160,23 @@ def _train_one_epoch(self, epoch): loss_AM = AverageMeter() """ - Curriculum learning: + Curriculum learning / Adaptive task scheduler: for size: gradually increase the problem size - for distribution: gradually increase adversarial budgets (i.e., \epsilon) + for distribution: adversarial budgets (i.e., \epsilon) may not be correlated with the hardness of constructed + data distribution. Instead, we evaluate the relative gaps (w.r.t. LKH3) of dist/eps sampled + from each interval every X iters. Hopefully, it can indicate the hardness of its neighbor. 
""" - if self.meta_params["data_type"] in ["size", "distribution"]: - self.min_n, self.max_n = self.task_set[0][0], self.task_set[-1][0] # [20, 150] / [0, 130] + if self.meta_params["data_type"] == "size": # start = self.min_n + int(epoch/self.meta_params['epochs'] * (self.max_n - self.min_n)) # linear start = self.min_n + int(1/2 * (1-math.cos(math.pi * min(epoch/self.meta_params['epochs'], 1))) * (self.max_n - self.min_n)) # cosine end = min(start + 10, self.max_n) # 10 is the size of the sliding window if self.meta_params["curriculum"]: print(">> training task {}".format((start, end))) + elif self.meta_params["data_type"] == "distribution": + # Every X iters, evaluating 50 instances for each interval (e.g., [1, 6) / [6, 11) / ...) using LKH3 + if epoch != 0 and epoch % self.meta_params['update_weight'] == 0: + self._update_task_weight() + start = torch.multinomial(self.task_w, 1).item() * self.task_interval + end = min(start + self.task_interval, self.max_n) elif self.meta_params["data_type"] == "size_distribution": pass @@ -215,6 +220,7 @@ def _train_one_epoch(self, epoch): loss_AM.update(avg_loss.item(), batch_size) val_data = self._get_val_data(batch_size, task_params) + self.meta_model.train() if self.meta_params['meta_method'] == 'maml': val_loss = self._fast_val(fast_weight, data=val_data, mode="maml") val_loss /= self.meta_params['B'] @@ -340,7 +346,7 @@ def _train_one_batch_maml(self, fast_weight, data, env, optimizer=None): return score_mean, loss_mean, fast_weight - def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval"): + def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval", return_all=False): aug_factor = 1 data = torch.Tensor(load_dataset(path)[: val_episodes]) if data is None else data env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)}) @@ -395,9 +401,13 @@ def _fast_val(self, model, data=None, path=None, val_episodes=32, mode="eval"): max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo # shape: (augmentation, batch) no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value + print(no_aug_score) if mode == "eval": - return no_aug_score.detach().item() + if return_all: + return -max_pomo_reward[0, :].float() + else: + return no_aug_score.detach().item() else: return loss_mean @@ -460,11 +470,14 @@ def _get_data(self, batch_size, task_params): return data def _get_val_data(self, batch_size, task_params): - if self.meta_params["data_type"] in ["size", "distribution"]: + if self.meta_params["data_type"] == "size": start1, end1 = min(task_params[0] + 10, self.max_n), min(task_params[0] + 20, self.max_n) + val_size = random.sample(range(start1, end1 + 1), 1)[0] + elif self.meta_params["data_type"] == "distribution": + val_size = task_params[0] elif self.meta_params["data_type"] == "size_distribution": pass - val_size = random.sample(range(start1, end1 + 1), 1)[0] + val_data = self._get_data(batch_size, (val_size,)) return val_data @@ -487,7 +500,6 @@ def minmax(xy_): if eps == 0: return data # generate x_adv - print(">> Warning! Generating x_adv!") self.meta_model.eval() aug_factor, batch_size = 1, data.size(0) env = Env(**{'problem_size': data.size(1), 'pomo_size': data.size(1)}) @@ -519,3 +531,34 @@ def minmax(xy_): # return data, opt_sol return data + + def _update_task_weight(self): + """ + Update the weights of tasks. 
+ """ + gap = torch.zeros(len(self.task_set)//self.task_interval) + for i in range(gap.size(0)): + start = i * self.task_interval + end = min(start + self.task_interval, self.max_n) + selected = random.sample([j for j in range(start, end+1)], 1)[0] + data = self._get_data(batch_size=50, task_params=(selected, )) + model_score = self._fast_val(self.meta_model, data=data, mode="eval", return_all=True) + model_score = model_score.tolist() + + # get results from LKH3 (~14s) + # start_t = time.time() + opts = argparse.ArgumentParser() + opts.cpus, opts.n, opts.progress_bar_mininterval = None, None, 0.1 + dataset = [(instance.cpu().numpy(),) for instance in data] + executable = get_lkh_executable() + global run_func + def run_func(args): + return solve_lkh_log(executable, *args, runs=1, disable_cache=True) # otherwise it directly loads data from dir + results, _ = run_all_in_pool(run_func, "./LKH3_result", dataset, opts, use_multiprocessing=False) + gap_list = [(model_score[j]-results[j][0])/results[j][0]*100 for j in range(len(results))] + gap[i] = sum(gap_list)/len(gap_list) + # print(">> LKH3 finished within {}s".format(time.time()-start_t)) + print(gap) + print(">> Old task weights: {}".format(self.task_w)) + self.task_w = torch.softmax(gap, dim=0) + print(">> New task weights: {}".format(self.task_w)) diff --git a/POMO/TSP/TSPTrainer_pomo.py b/POMO/TSP/TSPTrainer_pomo.py index 9c3c970..7b836f2 100644 --- a/POMO/TSP/TSPTrainer_pomo.py +++ b/POMO/TSP/TSPTrainer_pomo.py @@ -27,14 +27,15 @@ def __init__(self, env_params, model_params, optimizer_params, - trainer_params): + trainer_params, + meta_params): # save arguments self.env_params = env_params self.model_params = model_params self.optimizer_params = optimizer_params self.trainer_params = trainer_params - self.meta_params = trainer_params['meta_params'] + self.meta_params = meta_params # result folder, logger self.logger = getLogger(name='trainer') @@ -66,7 +67,7 @@ def __init__(self, self.meta_model.load_state_dict(checkpoint['model_state_dict']) self.start_epoch = 1 + model_load['epoch'] self.result_log.set_raw_data(checkpoint['result_log']) - # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # self.scheduler.last_epoch = model_load['epoch']-1 self.logger.info('Saved Model Loaded !!') @@ -89,13 +90,13 @@ def run(self): if self.meta_params["data_type"] == "size": paths = ["tsp50_uniform.pkl", "tsp100_uniform.pkl", "tsp200_uniform.pkl"] elif self.meta_params["data_type"] == "distribution": - paths = ["tsp100_uniform.pkl", "tsp100_gaussian.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl", "tsp100_tsplib.pkl"] + paths = ["tsp100_uniform.pkl", "tsp100_cluster.pkl", "tsp100_diagonal.pkl"] elif self.meta_params["data_type"] == "size_distribution": pass for val_path in paths: no_aug_score = self._fast_val(self.meta_model, path=os.path.join(dir, val_path), val_episodes=64) no_aug_score_list.append(round(no_aug_score, 4)) - self.result_log.append('val_score', epoch, no_aug_score_list[-1]) + self.result_log.append('val_score', epoch, no_aug_score_list) # Logs & Checkpoint elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.meta_params['epochs']) @@ -122,7 +123,7 @@ def run(self): checkpoint_dict = { 'epoch': epoch, 'model_state_dict': self.meta_model.state_dict(), - # 'optimizer_state_dict': self.optimizer.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), # 'scheduler_state_dict': self.scheduler.state_dict(), 
                     'result_log': self.result_log.get_raw_data()
                 }
diff --git a/POMO/TSP/TSP_baseline.py b/POMO/TSP/TSP_baseline.py
index a2ae178..a38870d 100644
--- a/POMO/TSP/TSP_baseline.py
+++ b/POMO/TSP/TSP_baseline.py
@@ -1,12 +1,13 @@
 import argparse
 import numpy as np
-import os, re
+import os, re, sys
 import time
 from datetime import timedelta
 from scipy.spatial import distance_matrix
 from subprocess import check_call, check_output, CalledProcessError
 import torch
 from torch.utils.data import Dataset
+from urllib.parse import urlparse
 from tqdm import tqdm
 os.chdir(os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, "..")  # for utils
@@ -99,7 +100,7 @@ def solve_concorde_log(executable, directory, name, loc, disable_cache=False):

 def get_lkh_executable(url="http://www.akira.ruc.dk/~keld/research/LKH-3/LKH-3.0.7.tgz"):

-    cwd = os.path.abspath(os.path.join("problems", "vrp", "lkh"))
+    cwd = os.path.abspath(os.path.join("lkh"))
     os.makedirs(cwd, exist_ok=True)

     file = os.path.join(cwd, os.path.split(urlparse(url).path)[-1])
@@ -132,7 +133,7 @@ def solve_lkh_log(executable, directory, name, loc, runs=1, disable_cache=False)
     try:
         # May have already been run
         if os.path.isfile(output_filename) and not disable_cache:
-            tour, duration = load_dataset(output_filename)
+            tour, duration = load_dataset(output_filename, disable_print=True)
         else:
             write_tsplib(problem_filename, loc, name=name)
@@ -371,8 +372,8 @@ def solve_all_nn(dataset_path, eval_batch_size=1024, no_cuda=False, dataset_n=No

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("method", type=str, default='concorde', choices=['nn', "gurobi", "gurobigap", "gurobit", "concorde", "lkh", "random_insertion", "nearest_insertion", "farthest_insertion"])
-    parser.add_argument("datasets", nargs='+', help="Filename of the dataset(s) to evaluate")
+    parser.add_argument("--method", type=str, default='lkh', choices=['nn', "gurobi", "gurobigap", "gurobit", "concorde", "lkh", "random_insertion", "nearest_insertion", "farthest_insertion"])
+    parser.add_argument("--datasets", nargs='+', default=["../../data/TSP/tsp50_uniform.pkl", ], help="Filename of the dataset(s) to evaluate")
     parser.add_argument("-f", action='store_true', help="Set true to overwrite")
     parser.add_argument("-o", default=None, help="Name of the results file to write")
     parser.add_argument("--cpus", type=int, help="Number of CPUs to use, defaults to all cores")
     parser.add_argument('--disable_cache', action='store_true', help='Disable caching')
     parser.add_argument('--max_calc_batch_size', type=int, default=1000, help='Size for subbatches')
     parser.add_argument('--progress_bar_mininterval', type=float, default=0.1, help='Minimum interval')
-    parser.add_argument('-n', type=int, help="Number of instances to process")
-    parser.add_argument('--offset', type=int, help="Offset where to start processing")
+    parser.add_argument('-n', type=int, default=1000, help="Number of instances to process")
+    parser.add_argument('--offset', type=int, default=0, help="Offset where to start processing")
     parser.add_argument('--results_dir', default='results', help="Name of results directory")

     opts = parser.parse_args()
diff --git a/POMO/TSP/TSP_gurobi.py b/POMO/TSP/TSP_gurobi.py
index 7480eb8..916358b 100644
--- a/POMO/TSP/TSP_gurobi.py
+++ b/POMO/TSP/TSP_gurobi.py
@@ -120,7 +120,7 @@ def solve_all_gurobi(dataset):

     for i, instance in enumerate(dataset):
         print("Solving instance {}".format(i))
         # some hard instances may take a prohibitively long time and ultimately kill the solver, so we set tl=1800s for TSP100 to avoid that.
-        result = solve_euclidian_tsp(instance, timeout=1800)
+        result = solve_euclidian_tsp(instance)
         results.append(result)

     return results
diff --git a/POMO/TSP/TSProblemDef.py b/POMO/TSP/TSProblemDef.py
index 1c17d5a..7645c26 100644
--- a/POMO/TSP/TSProblemDef.py
+++ b/POMO/TSP/TSProblemDef.py
@@ -10,13 +10,10 @@ def generate_task_set(meta_params):
     if meta_params['data_type'] == "distribution":  # focus on TSP100 with different distributions
-        # task_set = [(m, l) for l in [1, 10, 20, 30, 50] for m in range(1, 1 + meta_params['num_task'] // 5)] + [(0, 0)]
         task_set = [(eps,) for eps in range(0, 0 + meta_params['num_task'] + 1)]
     elif meta_params['data_type'] == "size":  # focus on uniform distribution with different sizes
         task_set = [(n,) for n in range(20, 20 + meta_params['num_task'] + 1)]
     elif meta_params['data_type'] == "size_distribution":
-        # task_set = [(m, l) for l in [1, 10, 20, 30, 50] for m in range(1, 11)] + [(0, 0)]
-        # task_set = [(n, m, l) for n in [25, 50, 75, 100, 125, 150] for (m, l) in task_set]
         task_set = [(n, eps) for n in range(20, 20 + meta_params['num_task'] + 1) for eps in range(0, 0 + meta_params['num_task'] + 1, 10)]
     else:
         raise NotImplementedError
@@ -115,7 +112,7 @@ def generate_tsp_dist(n_samples, n_nodes, distribution):
         for i in range(n_samples):
             print(n_nodes, i)
             loc = []
-            n_cluster = np.random.randint(low=3, high=9)
+            n_cluster = np.random.randint(low=2, high=9)
             loc.append(np.random.randint(1000, size=[1, n_cluster, 2]))
             prob = np.zeros((1000, 1000))
             coord = np.concatenate([np.tile(np.arange(1000).reshape(-1, 1, 1), [1, 1000, 1]),
diff --git a/POMO/TSP/test.py b/POMO/TSP/test.py
index 8bef724..4f31031 100644
--- a/POMO/TSP/test.py
+++ b/POMO/TSP/test.py
@@ -47,7 +47,7 @@
     'test_set_path': '../../data/TSP/tsp100_uniform.pkl',
     'test_set_opt_sol_path': '../../data/TSP/gurobi/tsp100_uniform.pkl',
     'fine_tune_params': {
-        'enable': True,  # evaluate few-shot generalization
+        'enable': False,  # evaluate few-shot generalization
         'fine_tune_episodes': 3000,  # how much data is used to fine-tune the pretrained model
         'k': 20,  # gradient descent steps in the inner-loop optimization of meta-learning method
         'fine_tune_batch_size': 64,  # the batch size of the inner-loop optimization
@@ -101,5 +101,4 @@ def _print_config():


 if __name__ == "__main__":
-    # TODO: 1. why not use test dataset to fine-tune the model?
     main()
diff --git a/POMO/TSP/train.py b/POMO/TSP/train.py
index 3fd1786..5f60f89 100644
--- a/POMO/TSP/train.py
+++ b/POMO/TSP/train.py
@@ -5,7 +5,6 @@
 import logging
 from utils.utils import create_logger, copy_all_src
 from utils.functions import seed_everything
-from TSPTrainer import TSPTrainer as Trainer
 from TSPTrainer_pomo import TSPTrainer as Trainer_Pomo
 from TSPTrainer_Meta import TSPTrainer as Trainer_Meta
@@ -42,10 +41,6 @@
     # 'milestones': [3001, ],
     # 'gamma': 0.1
     # },
-    'scheduler': {
-        'T_0': 5000,
-        'T_mult': 2,
-    },
 }

 trainer_params = {
@@ -73,22 +68,23 @@
         'enable': False,  # enable loading pre-trained model
         # 'path': './result/saved_tsp20_model',  # directory path of pre-trained model and log files saved.
         # 'epoch': 510,  # epoch version of pre-trained model to load.
-    },
-    'meta_params': {
-        'enable': True,  # whether use meta-learning or not
-        'curriculum': True,
-        'meta_method': 'maml',  # choose from ['maml', 'fomaml', 'reptile']
-        'bootstrap_steps': 0,
-        'data_type': 'distribution',  # choose from ["size", "distribution", "size_distribution"]
-        'epochs': 50000,  # the number of meta-model updates: (250*100000) / (1*5*64)
-        'B': 1,  # the number of tasks in a mini-batch
-        'k': 1,  # gradient decent steps in the inner-loop optimization of meta-learning method
-        'meta_batch_size': 64,  # will be divided by 2 if problem_size >= 100
-        'num_task': 130,  # the number of tasks in the training task set: e.g., [20, 150] / [0, 130]
-        'alpha': 0.99,  # params for the outer-loop optimization of reptile
-        'alpha_decay': 0.999,  # params for the outer-loop optimization of reptile
-    }
+}
+
+meta_params = {
+    'enable': True,  # whether to use meta-learning or not
+    'curriculum': True,  # adaptively sample tasks
+    'meta_method': 'maml',  # choose from ['maml', 'fomaml', 'reptile']
+    'bootstrap_steps': 0,
+    'data_type': 'distribution',  # choose from ["size", "distribution", "size_distribution"]
+    'epochs': 50000,  # the number of meta-model updates: (250*100000) / (1*5*64)
+    'B': 1,  # the number of tasks in a mini-batch
+    'k': 1,  # gradient descent steps in the inner-loop optimization of meta-learning method
+    'meta_batch_size': 64,  # will be divided by 2 if problem_size >= 100
+    'num_task': 100,  # the number of tasks in the training task set: e.g., [20, 150] / [0, 100]
+    'update_weight': 1000,  # update the weight of each task every X iters
+    'alpha': 0.99,  # params for the outer-loop optimization of reptile
+    'alpha_decay': 0.999,  # params for the outer-loop optimization of reptile
 }

 logger_params = {
@@ -108,13 +104,13 @@ def main():
     seed_everything(trainer_params['seed'])
-    if not trainer_params['meta_params']['enable']:
+    if not meta_params['enable']:
         print(">> Start POMO Training.")
         # trainer = Trainer(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params)
-        trainer = Trainer_Pomo(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params)
-    elif trainer_params['meta_params']['meta_method'] in ['maml', 'fomaml', 'reptile']:
-        print(">> Start POMO-{} Training.".format(trainer_params['meta_params']['meta_method']))
-        trainer = Trainer_Meta(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params)
+        trainer = Trainer_Pomo(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params, meta_params=meta_params)
+    elif meta_params['meta_method'] in ['maml', 'fomaml', 'reptile']:
+        print(">> Start POMO-{} Training.".format(meta_params['meta_method']))
+        trainer = Trainer_Meta(env_params=env_params, model_params=model_params, optimizer_params=optimizer_params, trainer_params=trainer_params, meta_params=meta_params)
     else:
         raise NotImplementedError
@@ -148,13 +144,12 @@ def occumpy_mem(cuda_device):
     total, used = check_mem(cuda_device)
     total = int(total)
     used = int(used)
-    max_mem = int(total * 0.85)
-    block_mem = max_mem - used
+    block_mem = int((total-used) * 0.85)
     x = torch.cuda.FloatTensor(256, 1024, block_mem)
     del x


 if __name__ == "__main__":
-    if trainer_params["meta_params"]["data_type"] in ["size", "size_distribution"]:
-        occumpy_mem(CUDA_DEVICE_NUM)
+    if meta_params["data_type"] in ["size", "size_distribution"]:
+        occumpy_mem(CUDA_DEVICE_NUM)  # reserve GPU memory for large size instances
     main()
diff --git a/POMO/utils/functions.py b/POMO/utils/functions.py
index aa928c7..24db419 100644
--- a/POMO/utils/functions.py
+++ b/POMO/utils/functions.py
@@ -7,6 +7,9 @@
 import json
 import pickle
 import matplotlib.pyplot as plt
+from tqdm import tqdm
+from multiprocessing import Pool
+from multiprocessing.dummy import Pool as ThreadPool


 def check_extension(filename):
@@ -26,11 +29,12 @@ def save_dataset(dataset, filename):
         pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)


-def load_dataset(filename):
+def load_dataset(filename, disable_print=False):
     with open(check_extension(filename), 'rb') as f:
         data = pickle.load(f)
-        print(">> Load {} data ({}) from {}".format(len(data), type(data), filename))
+        if not disable_print:
+            print(">> Load {} data ({}) from {}".format(len(data), type(data), filename))
     return data
@@ -70,11 +74,12 @@ def clip_grad_norms(param_groups, max_norm=math.inf):
     return grad_norms, grad_norms_clipped


-def run_all_in_pool(func, directory, dataset, opts, use_multiprocessing=True):
+def run_all_in_pool(func, directory, dataset, opts, use_multiprocessing=True, disable_tqdm=True):
     # # Test
     # res = func((directory, 'test', *dataset[0]))
     # return [res]
+    os.makedirs(directory, exist_ok=True)
     num_cpus = os.cpu_count() if opts.cpus is None else opts.cpus

     w = len(str(len(dataset) - 1))
@@ -94,7 +99,7 @@ def run_all_in_pool(func, directory, dataset, opts, use_multiprocessing=True):
                 ) for i, problem in enumerate(ds)
             ]
-        ), total=len(ds), mininterval=opts.progress_bar_mininterval))
+        ), total=len(ds), mininterval=opts.progress_bar_mininterval, disable=disable_tqdm))
     failed = [str(i + offset) for i, res in enumerate(results) if res is None]
     assert len(failed) == 0, "Some instances failed: {}".format(" ".join(failed))
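
Below is a minimal, self-contained sketch of the adaptive task scheduler that this patch wires into _train_one_epoch() and _update_task_weight(). It is illustrative only: the random `gaps` tensor stands in for the LKH3-based gap estimates computed in _update_task_weight(), and the interval width and range follow the task_interval = 5 and num_task = 100 settings above.

import torch

def update_task_weight(gaps):
    # gaps[i]: average relative gap (%) w.r.t. LKH3 on interval i;
    # softmax turns the gaps into sampling probabilities, so harder
    # intervals (larger gap) are drawn more often
    return torch.softmax(gaps, dim=0)

def sample_task(task_w, task_interval=5, max_n=100):
    # draw an interval index proportionally to its weight, then map it
    # back to a [start, end) range of sizes / adversarial budgets
    start = torch.multinomial(task_w, 1).item() * task_interval
    end = min(start + task_interval, max_n)
    return start, end

n_intervals = 100 // 5
task_w = torch.full((n_intervals,), 1 / n_intervals)  # uniform init, as in __init__
task_w = update_task_weight(torch.rand(n_intervals))  # stand-in for measured gaps
print(sample_task(task_w))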