diff --git a/create_parameter_weights.py b/create_parameter_weights.py index acb4084e..68d25c1b 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -19,27 +19,14 @@ def get_rank(): """Get the rank of the current process in the distributed group.""" if "SLURM_PROCID" in os.environ: return int(os.environ["SLURM_PROCID"]) - parser = ArgumentParser() - parser.add_argument( - "--rank", type=int, default=0, help="Rank of the current process" - ) - args, _ = parser.parse_known_args() - return args.rank + return 0 def get_world_size(): """Get the number of processes in the distributed group.""" if "SLURM_NTASKS" in os.environ: return int(os.environ["SLURM_NTASKS"]) - parser = ArgumentParser() - parser.add_argument( - "--world_size", - type=int, - default=1, - help="Number of processes in the distributed group", - ) - args, _ = parser.parse_known_args() - return args.world_size + return 1 def setup(rank, world_size): # pylint: disable=redefined-outer-name @@ -63,6 +50,11 @@ def setup(rank, world_size): # pylint: disable=redefined-outer-name dist.init_process_group("nccl", rank=rank, world_size=world_size) else: dist.init_process_group("gloo", rank=rank, world_size=world_size) + print( + f"Initialized {dist.get_backend()} process group with " + f"world size " + f"{world_size}." + ) def cleanup(): @@ -70,6 +62,24 @@ def cleanup(): dist.destroy_process_group() +def adjust_dataset_size(ds, world_size, batch_size): + # pylint: disable=redefined-outer-name + """Adjust the dataset size to be divisible by world_size * batch_size.""" + total_samples = len(ds) + subset_samples = (total_samples // (world_size * batch_size)) * ( + world_size * batch_size + ) + + if subset_samples != total_samples: + ds = torch.utils.data.Subset(ds, range(subset_samples)) + print( + f"Dataset size adjusted from {total_samples} to " + f"{subset_samples} to be divisible by (world_size * batch_size)." + ) + + return ds + + def main(rank, world_size): # pylint: disable=redefined-outer-name """Compute the mean and standard deviation of the input data.""" setup(rank, world_size) @@ -100,11 +110,6 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name ) args = parser.parse_args() - if args.subset % (world_size * args.batch_size) != 0: - raise ValueError( - "Subset size must be divisible by (world_size * batch_size)" - ) - device = torch.device( f"cuda:{rank % torch.cuda.device_count()}" if torch.cuda.is_available() @@ -140,6 +145,8 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name standardize=False, ) # Without standardization + ds = adjust_dataset_size(ds, world_size, args.batch_size) + train_sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank) loader = torch.utils.data.DataLoader( ds, @@ -202,6 +209,9 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name pred_length=63, standardize=True, ) # Re-load with standardization + + ds_standard = adjust_dataset_size(ds_standard, world_size, args.batch_size) + sampler_standard = DistributedSampler( ds_standard, num_replicas=world_size, rank=rank ) @@ -217,9 +227,7 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name diff_means = [] diff_squares = [] - for init_batch, target_batch, _, _ in tqdm( - loader_standard, disable=rank != 0 - ): + for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0): batch = torch.cat((init_batch, target_batch), dim=1).to(device) # Note: batch contains only 1h-steps stepped_batch = torch.cat(