Skip to content

Commit

Permalink
try maxvit on desi only
Browse files: browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Nov 23, 2023
1 parent 751b5d6 commit afae520
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
18 changes: 12 additions & 6 deletions only_for_me/narval/train.sh
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,14 +20,20 @@ export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use t
# echo "r$SLURM_NODEID Launching python script"

REPO_DIR=/project/def-bovy/walml/zoobot
# srun $PYTHON $REPO_DIR/only_for_me/narval/train.py \
# --save-dir $REPO_DIR/only_for_me/narval/desi_300px_f128_1gpu \
# --batch-size 256 \
# --num-features 128 \
# --gpus 1 \
# --num-workers 10 \
# --color --wandb --mixed-precision --compile-encoder

srun $PYTHON $REPO_DIR/only_for_me/narval/train.py \
--save-dir $REPO_DIR/only_for_me/narval/desi_300px_f128_1gpu \
--batch-size 256 \
--num-features 128 \
--save-dir $REPO_DIR/only_for_me/narval/desi_300px_maxvittiny_rw_224_1gpu \
--batch-size 64 \
--gpus 1 \
--num-workers 10 \
--architecture maxvit_tiny_rw_224 \
--color --wandb --mixed-precision --compile-encoder

# srun python $SLURM_TMPDIR/zoobot/only_for_me/narval/finetune.py

# --architecture maxvit_small_tf_224 \
# maxvit_small_tf_224 \
8 changes: 5 additions & 3 deletions zoobot/pytorch/datasets/webdataset_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -82,11 +82,13 @@ def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse
# transform = None

for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)):
if sparse_label_df is not None:
shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge')) # auto-merge
shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
if overwrite or not(os.path.isfile(shard_save_loc)):
logging.info(shard_save_loc)

if sparse_label_df is not None:
shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge')) # auto-merge

# logging.info(shard_save_loc)
sink = wds.TarWriter(shard_save_loc)
for _, galaxy in shard_df.iterrows():
sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform))
Expand Down

0 comments on commit afae520

Please sign in to comment.