diff --git a/only_for_me/narval/gpu_split.py b/only_for_me/narval/gpu_split.py
index 9e7802d3..79aad30b 100644
--- a/only_for_me/narval/gpu_split.py
+++ b/only_for_me/narval/gpu_split.py
@@ -90,7 +90,7 @@ def main():
         cache_dir=None
         # TODO pass through the rest
     )
-    use_distributed_sampler=False
+    # use_distributed_sampler=False
 
     trainer = pl.Trainer(
         # log_every_n_steps=16,  # at batch 512 (A100 MP max), DR5 has ~161 train steps
@@ -104,7 +104,7 @@ def main():
         max_epochs=1,
         default_root_dir=save_dir,
         # plugins=plugins,
-        use_distributed_sampler=use_distributed_sampler
+        # use_distributed_sampler=use_distributed_sampler
     )
 
     # logging.info((trainer.strategy, trainer.world_size,
@@ -115,8 +115,8 @@ def main():
 
     trainer.fit(lightning_model, datamodule)  # uses batch size of datamodule
 
     # batch size 16
-    # shard size 16, 10 shards with 8 being assigned as training shards so 8*32 train images, 8*2 train batches
+    # shard size 16, 10 shards with 8 being assigned as training shards so 8*32 train images, 8*2=16 train batches
 
 if __name__=='__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/only_for_me/narval/gpu_split.sh b/only_for_me/narval/gpu_split.sh
index 0c5d27b3..715ca5e3 100644
--- a/only_for_me/narval/gpu_split.sh
+++ b/only_for_me/narval/gpu_split.sh
@@ -1,11 +1,11 @@
 #!/bin/bash
-#SBATCH --time=0:15:0
+#SBATCH --time=0:10:0
 #SBATCH --nodes=1
 #SBATCH --ntasks=2
 #SBATCH --ntasks-per-node=2
 #SBATCH --cpus-per-task=4
 #SBATCH --mem-per-cpu 4G
-#SBATCH --gres=gpu:v100:1
+#SBATCH --gres=gpu:v100:2
 
 nvidia-smi
 
@@ -22,5 +22,5 @@ export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use t
 
 echo 'Running script'
 REPO_DIR=/project/def-bovy/walml/zoobot
-srun $PYTHON $REPO_DIR/only_for_me/narval/gpu_split.py --gpus 1
+srun $PYTHON $REPO_DIR/only_for_me/narval/gpu_split.py --gpus 2
 
diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py
index 27e0c408..bec1c7f4 100644
--- a/zoobot/pytorch/datasets/webdatamodule.py
+++ b/zoobot/pytorch/datasets/webdatamodule.py
@@ -81,8 +81,7 @@ def make_loader(self, urls, mode="train"):
         dataset = (
             # https://webdataset.github.io/webdataset/multinode/
             # WDS 'knows' which worker it is running on and selects a subset of urls accordingly
-            wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0
-                # , nodesplitter=nodesplitter_func
+            wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func
             )
             .shuffle(shuffle)
             .decode("rgb")
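
Note on the webdatamodule.py change: passing nodesplitter makes WebDataset itself partition the shard list across distributed ranks. This is also why use_distributed_sampler is commented out in gpu_split.py: a WebDataset is an IterableDataset, so Lightning cannot wrap it in a DistributedSampler, and without a nodesplitter the shard list is not partitioned across ranks. The definition of nodesplitter_func is not part of this diff; below is a minimal sketch of a rank-based splitter, assuming torch.distributed has been initialised by the DDP launch (the library-provided wds.split_by_node plays the same role):

import torch.distributed as dist

def nodesplitter_func(urls):
    # Yield only the shards assigned to this rank, so the two srun
    # tasks (one per V100) iterate disjoint subsets of the data.
    if dist.is_available() and dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
        yield from (url for i, url in enumerate(urls) if i % world_size == rank)
    else:
        # Single-process run: keep every shard.
        yield from urls

With the 8 training shards above and a world size of 2, each rank then iterates 4 shards.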
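
Note on the gpu_split.sh change: the job now requests two V100s with one srun task per GPU, so Lightning should launch one DDP process per task. The Trainer arguments behind the --gpus flag are not shown in this diff; a hedged sketch of the wiring implied by the SLURM resources (argument names as in recent pytorch_lightning; the real script parses --gpus itself):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator='gpu',
    devices=2,       # matches --gres=gpu:v100:2 and --ntasks-per-node=2
    num_nodes=1,     # matches --nodes=1
    strategy='ddp',  # one process per GPU; srun supplies the processes
    max_epochs=1,
)

Under srun, Lightning's SLURM cluster environment reads the rank and world size from the SLURM environment variables rather than spawning worker processes itself, so --ntasks must match the total GPU count.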