pass aug args
add timm kwargs option
also need to clone timm (temp?)
mwalmsley committed Nov 22, 2023
1 parent 032e52a commit e142683
Showing 5 changed files with 61 additions and 17 deletions.
20 changes: 13 additions & 7 deletions only_for_me/narval/make_webdataset_script.py
@@ -62,19 +62,23 @@ def main():
# desi pipeline shreds sources. Be careful to deduplicate.

columns = ['id_str'] + label_cols
-votes = pd.concat([
-    pd.read_parquet(f'/media/walml/beta/galaxy_zoo/decals/dr8/catalogs/training_catalogs/{campaign}_ortho_v5_labelled_catalog.parquet', columns=columns)
-    for campaign in ['dr12', 'dr5', 'dr8']
-], axis=0)
-assert votes['id_str'].value_counts().max() == 1, votes['id_str'].value_counts()
-votes['dr8_id'] = votes['id_str']
-df = pd.merge(df, votes[['dr8_id']], on='dr8_id', how='inner')
+# votes = pd.concat([
+#     pd.read_parquet(f'/media/walml/beta/galaxy_zoo/decals/dr8/catalogs/training_catalogs/{campaign}_ortho_v5_labelled_catalog.parquet', columns=columns)
+#     for campaign in ['dr12', 'dr5', 'dr8']
+# ], axis=0)
+# assert votes['id_str'].value_counts().max() == 1, votes['id_str'].value_counts()
+# votes['dr8_id'] = votes['id_str']
+
+# name = 'labelled'
+# merge_strategy = {'labelled': 'inner', 'all': 'left'}
+# df = pd.merge(df, votes[['dr8_id']], on='dr8_id', how=merge_strategy[name])

df['relative_file_loc'] = df.apply(lambda x: f"{x['brickid']}/{x['brickid']}_{x['objid']}.jpg", axis=1)
df['file_loc'] = '/home/walml/data/desi/jpg/' + df['relative_file_loc']

df_dedup = remove_close_sky_matches(df)
print(len(df_dedup))
df_dedup.to_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
+exit()
# df_dedup2 = remove_close_sky_matches(df_dedup)
# print(len(df_dedup2))
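Aside: the df.apply above builds per-row paths like "{brickid}/{brickid}_{objid}.jpg". An equivalent vectorized form (a sketch, assuming brickid and objid are plain integer columns in df) avoids the per-row lambda on large catalogs:

df['relative_file_loc'] = (
    df['brickid'].astype(str) + '/'
    + df['brickid'].astype(str) + '_' + df['objid'].astype(str) + '.jpg'
)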
@@ -103,6 +107,7 @@ def remove_close_sky_matches(df, seplimit=20*u.arcsec, col_to_prioritise='ra'):

search_coords = catalog

+logging.info('Beginning search for nearby galaxies')
idxc, idxcatalog, d2d, _ = catalog.search_around_sky(search_coords, seplimit=seplimit)
# idxc is index in search coords
# idxcatalog is index in catalog
@@ -114,6 +119,7 @@ def remove_close_sky_matches(df, seplimit=20*u.arcsec, col_to_prioritise='ra'):
idxcatalog = idxcatalog[d2d > 0]
d2d = d2d[d2d > 0]

+logging.info('Beginning drop prioritisation')
indices_to_drop = []
for search_index_val in pd.unique(idxc):
matched_indices = idxcatalog[idxc == search_index_val]
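For reference, the core of remove_close_sky_matches is an astropy self-match. Below is a minimal sketch of the pattern, assuming 'ra'/'dec' columns in degrees; the keep-the-lower-index drop rule is illustrative only, whereas the repo prioritises which member to keep via col_to_prioritise:

import astropy.units as u
import numpy as np
import pandas as pd
from astropy.coordinates import SkyCoord

def drop_close_sky_matches(df: pd.DataFrame, seplimit=20 * u.arcsec) -> pd.DataFrame:
    # assumes 'ra' and 'dec' columns in degrees
    catalog = SkyCoord(ra=df['ra'].values * u.deg, dec=df['dec'].values * u.deg)
    # self-match: idx_search indexes the search coords, idx_catalog the catalog
    idx_search, idx_catalog, d2d, _ = catalog.search_around_sky(catalog, seplimit=seplimit)
    # every source matches itself at zero separation; discard those trivial pairs
    nontrivial = d2d > 0
    idx_search, idx_catalog = idx_search[nontrivial], idx_catalog[nontrivial]
    # each close pair appears twice, as (i, j) and (j, i); keep the lower index
    to_drop = np.unique(idx_catalog[idx_catalog > idx_search])
    return df.drop(df.index[to_drop]).reset_index(drop=True)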
14 changes: 9 additions & 5 deletions only_for_me/narval/train.py
@@ -21,15 +21,16 @@
See zoobot/pytorch/examples/minimal_examples.py for a friendlier example
"""
parser = argparse.ArgumentParser()
-parser.add_argument('--save-dir', dest='save_dir', type=str)
+parser.add_argument('--save-dir', dest='save_dir', type=str, default='local_debug')
# parser.add_argument('--data-dir', dest='data_dir', type=str)
# parser.add_argument('--dataset', dest='dataset', type=str, help='dataset to use, either "gz_decals_dr5" or "gz_evo"')
parser.add_argument('--architecture', dest='architecture_name', default='efficientnet_b0', type=str)
parser.add_argument('--resize-after-crop', dest='resize_after_crop',
type=int, default=224)
parser.add_argument('--color', default=False, action='store_true')
+parser.add_argument('--compile-encoder', dest='compile_encoder', default=False, action='store_true')
parser.add_argument('--batch-size', dest='batch_size',
-    default=256, type=int)
+    default=16, type=int)
parser.add_argument('--num-features', dest='num_features',
default=1280, type=int)
parser.add_argument('--gpus', dest='gpus', default=1, type=int)
@@ -62,10 +63,13 @@
# logging.info([(x, y) for (x, y) in os.environ.items() if 'SLURM' in x])

if os.path.isdir('/home/walml/repos/zoobot'):
-    search_str = '/home/walml/repos/zoobot/gz_decals_5_train_*.tar'
+    logging.warning('local mode')
+    search_str = '/home/walml/data/wds/desi_labelled_2048/desi_labelled_train_*.tar'
+    cache_dir = None

else:
search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/desi_labelled_2048/desi_labelled_train_*.tar'
+    cache_dir = os.environ['SLURM_TMPDIR'] + '/cache'

all_urls = glob.glob(search_str)
assert len(all_urls) > 0, search_str
@@ -115,10 +119,10 @@
wandb_logger=wandb_logger,
prefetch_factor=1, # TODO
num_workers=args.num_workers,
-    compile_encoder=True, # NEW
+    compile_encoder=args.compile_encoder, # NEW
random_state=random_state,
learning_rate=1e-3,
-    cache_dir=os.environ['SLURM_TMPDIR'] + '/cache'
+    cache_dir=cache_dir
# cache_dir='/tmp/cache'
# /tmp for ramdisk (400GB total, vs 4TB total for nvme)
)
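The --compile-encoder flag now flows through from the CLI instead of being hard-coded True. As an illustration only (not the repo's exact implementation), such a flag typically gates PyTorch 2.x graph compilation of the encoder:

import torch

def maybe_compile(encoder: torch.nn.Module, compile_encoder: bool) -> torch.nn.Module:
    # torch.compile (PyTorch >= 2.0) returns an optimised wrapper around the module
    return torch.compile(encoder) if compile_encoder else encoder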
2 changes: 1 addition & 1 deletion only_for_me/narval/train.sh
@@ -26,7 +26,7 @@ srun $PYTHON $REPO_DIR/only_for_me/narval/train.py \
--num-features 128 \
--gpus 2 \
--num-workers 10 \
-    --color --wandb --mixed-precision
+    --color --wandb --mixed-precision --compile-encoder

# srun python $SLURM_TMPDIR/zoobot/only_for_me/narval/finetune.py

31 changes: 29 additions & 2 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -12,7 +12,23 @@

# https://github.com/webdataset/webdataset-lightning/blob/main/train.py
class WebDataModule(pl.LightningDataModule):
-def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4, prefetch_factor=4, cache_dir=None):
+def __init__(
+    self,
+    train_urls,
+    val_urls,
+    label_cols=None,
+    train_size=None,
+    val_size=None,
+    # hardware
+    batch_size=64,
+    num_workers=4,
+    prefetch_factor=4,
+    cache_dir=None,
+    color=False,
+    crop_scale_bounds=(0.7, 0.8),
+    crop_ratio_bounds=(0.9, 1.1),
+    resize_after_crop=224
+):
super().__init__()

# if isinstance(train_urls, types.GeneratorType):
@@ -39,6 +55,12 @@ def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_c

self.cache_dir = cache_dir

+    # could use mixin
+    self.color = color
+    self.resize_after_crop = resize_after_crop
+    self.crop_scale_bounds = crop_scale_bounds
+    self.crop_ratio_bounds = crop_ratio_bounds


logging.info(f'Creating webdatamodule with WORLD_SIZE: {os.environ.get("WORLD_SIZE")}, RANK: {os.environ.get("RANK")}')
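The "# could use mixin" note above suggests hoisting these shared augmentation attributes out of WebDataModule. A minimal sketch of that refactor, with hypothetical names not present in the repo:

class AugmentationParamsMixin:
    # defaults match the WebDataModule signature above
    def store_augmentation_params(self, color=False, crop_scale_bounds=(0.7, 0.8),
                                  crop_ratio_bounds=(0.9, 1.1), resize_after_crop=224):
        self.color = color
        self.crop_scale_bounds = crop_scale_bounds
        self.crop_ratio_bounds = crop_ratio_bounds
        self.resize_after_crop = resize_after_crop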

@@ -52,7 +74,12 @@ def make_image_transform(self, mode="train"):
# if mode == "train":
# elif mode == "val":

-augmentation_transform = default_transforms() # A.Compose object
+augmentation_transform = default_transforms(
+    crop_scale_bounds=self.crop_scale_bounds,
+    crop_ratio_bounds=self.crop_ratio_bounds,
+    resize_after_crop=self.resize_after_crop,
+    pytorch_greyscale=not self.color
+) # A.Compose object
def do_transform(img):
return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
return do_transform
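Usage sketch for the expanded constructor; shard paths and label columns are placeholders. The four augmentation kwargs flow through to default_transforms() as wired in make_image_transform above:

from zoobot.pytorch.datasets import webdatamodule

datamodule = webdatamodule.WebDataModule(
    train_urls=['shards/train_0.tar'],  # placeholder shard lists
    val_urls=['shards/val_0.tar'],
    label_cols=['label'],
    batch_size=16,
    num_workers=4,
    cache_dir=None,
    # augmentation args
    color=False,  # greyscale transforms by default
    crop_scale_bounds=(0.7, 0.8),
    crop_ratio_bounds=(0.9, 1.1),
    resize_after_crop=224,
)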
11 changes: 9 additions & 2 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -234,10 +234,17 @@ def train_default_zoobot_from_scratch(
datamodule = webdatamodule.WebDataModule(
train_urls=train_urls,
val_urls=val_urls,
+    label_cols=schema.label_cols,
+    # hardware
    batch_size=batch_size,
    num_workers=num_workers,
-    label_cols=schema.label_cols,
-    cache_dir=cache_dir
+    prefetch_factor=prefetch_factor,
+    cache_dir=cache_dir,
+    # augmentation args
+    color=color,
+    crop_scale_bounds=crop_scale_bounds,
+    crop_ratio_bounds=crop_ratio_bounds,
+    resize_after_crop=resize_after_crop,
+    # TODO pass through the rest
)

