Skip to content

Commit

Permalink
it doesn't split! add my logging
Browse files (browse the repository at this point in the history)
  • Loading branch information
mwalmsley committed Nov 19, 2023
1 parent 5ac6be9 commit f2c89b2
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
8 changes: 4 additions & 4 deletions only_for_me/narval/gpu_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def main():
cache_dir=None
# TODO pass through the rest
)
use_distributed_sampler=False
# use_distributed_sampler=False

trainer = pl.Trainer(
# log_every_n_steps=16, # at batch 512 (A100 MP max), DR5 has ~161 train steps
Expand All @@ -104,7 +104,7 @@ def main():
max_epochs=1,
default_root_dir=save_dir,
# plugins=plugins,
use_distributed_sampler=use_distributed_sampler
# use_distributed_sampler=use_distributed_sampler
)

# logging.info((trainer.strategy, trainer.world_size,
Expand All @@ -115,8 +115,8 @@ def main():
trainer.fit(lightning_model, datamodule) # uses batch size of datamodule

# batch size 16
# shard size 16, 10 shards with 8 being assigned as training shards so 8*32 train images, 8*2 train batches
# shard size 16, 10 shards with 8 being assigned as training shards so 8*32 train images, 8*2=16 train batches


if __name__=='__main__':
main()
main()
6 changes: 3 additions & 3 deletions only_for_me/narval/gpu_split.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash
#SBATCH --time=0:15:0
#SBATCH --time=0:10:0
#SBATCH --nodes=1
#SBATCH --ntasks=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu 4G
#SBATCH --gres=gpu:v100:1
#SBATCH --gres=gpu:v100:2

nvidia-smi

Expand All @@ -22,5 +22,5 @@ export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use the NCCL backend for inter-GPU communication in PyTorch.

echo 'Running script'
REPO_DIR=/project/def-bovy/walml/zoobot
srun $PYTHON $REPO_DIR/only_for_me/narval/gpu_split.py --gpus 1
srun $PYTHON $REPO_DIR/only_for_me/narval/gpu_split.py --gpus 2

3 changes: 1 addition & 2 deletions zoobot/pytorch/datasets/webdatamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def make_loader(self, urls, mode="train"):
dataset = (
# https://webdataset.github.io/webdataset/multinode/
# WDS 'knows' which worker it is running on and selects a subset of urls accordingly
wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0
# , nodesplitter=nodesplitter_func
wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func
)
.shuffle(shuffle)
.decode("rgb")
Expand Down

0 comments on commit f2c89b2

Please sign in to comment.