From e1426837be7374d5babcfb82d4de4474c4b3e000 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Wed, 22 Nov 2023 13:54:24 -0500 Subject: [PATCH] pass aug args add timm kwargs option also need to clone timm (temp?) --- only_for_me/narval/make_webdataset_script.py | 20 +++++++----- only_for_me/narval/train.py | 14 ++++++--- only_for_me/narval/train.sh | 2 +- zoobot/pytorch/datasets/webdatamodule.py | 31 +++++++++++++++++-- .../training/train_with_pytorch_lightning.py | 11 +++++-- 5 files changed, 61 insertions(+), 17 deletions(-) diff --git a/only_for_me/narval/make_webdataset_script.py b/only_for_me/narval/make_webdataset_script.py index 67b37e72..efd33e91 100644 --- a/only_for_me/narval/make_webdataset_script.py +++ b/only_for_me/narval/make_webdataset_script.py @@ -62,19 +62,23 @@ def main(): # desi pipeline shreds sources. Be careful to deduplicate. columns = ['id_str'] + label_cols - votes = pd.concat([ - pd.read_parquet(f'/media/walml/beta/galaxy_zoo/decals/dr8/catalogs/training_catalogs/{campaign}_ortho_v5_labelled_catalog.parquet', columns=columns) - for campaign in ['dr12', 'dr5', 'dr8'] - ], axis=0) - assert votes['id_str'].value_counts().max() == 1, votes['id_str'].value_counts() - votes['dr8_id'] = votes['id_str'] - df = pd.merge(df, votes[['dr8_id']], on='dr8_id', how='inner') + # votes = pd.concat([ + # pd.read_parquet(f'/media/walml/beta/galaxy_zoo/decals/dr8/catalogs/training_catalogs/{campaign}_ortho_v5_labelled_catalog.parquet', columns=columns) + # for campaign in ['dr12', 'dr5', 'dr8'] + # ], axis=0) + # assert votes['id_str'].value_counts().max() == 1, votes['id_str'].value_counts() + # votes['dr8_id'] = votes['id_str'] + + # name = 'labelled' + # merge_strategy = {'labelled': 'inner', 'all': 'left'} + # df = pd.merge(df, votes[['dr8_id']], on='dr8_id', how=merge_strategy[name]) df['relative_file_loc'] = df.apply(lambda x: f"{x['brickid']}/{x['brickid']}_{x['objid']}.jpg", axis=1) df['file_loc'] = '/home/walml/data/desi/jpg/' + df['relative_file_loc'] df_dedup = remove_close_sky_matches(df) print(len(df_dedup)) + df_dedup.to_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet') exit() # df_dedup2 = remove_close_sky_matches(df_dedup) # print(len(df_dedup2)) @@ -103,6 +107,7 @@ def remove_close_sky_matches(df, seplimit=20*u.arcsec, col_to_prioritise='ra'): search_coords = catalog + logging.info('Beginning search for nearby galaxies') idxc, idxcatalog, d2d, _ = catalog.search_around_sky(search_coords, seplimit=seplimit) # idxc is index in search coords # idxcatalog is index in catalog @@ -114,6 +119,7 @@ def remove_close_sky_matches(df, seplimit=20*u.arcsec, col_to_prioritise='ra'): idxcatalog = idxcatalog[d2d > 0] d2d = d2d[d2d > 0] + logging.info('Beginning drop prioritisation') indices_to_drop = [] for search_index_val in pd.unique(idxc): matched_indices = idxcatalog[idxc == search_index_val] diff --git a/only_for_me/narval/train.py b/only_for_me/narval/train.py index 415bf24f..e9de4a0c 100644 --- a/only_for_me/narval/train.py +++ b/only_for_me/narval/train.py @@ -21,15 +21,16 @@ See zoobot/pytorch/examples/minimal_examples.py for a friendlier example """ parser = argparse.ArgumentParser() - parser.add_argument('--save-dir', dest='save_dir', type=str) + parser.add_argument('--save-dir', dest='save_dir', type=str, default='local_debug') # parser.add_argument('--data-dir', dest='data_dir', type=str) # parser.add_argument('--dataset', dest='dataset', type=str, help='dataset to use, either "gz_decals_dr5" or "gz_evo"') parser.add_argument('--architecture', dest='architecture_name', default='efficientnet_b0', type=str) parser.add_argument('--resize-after-crop', dest='resize_after_crop', type=int, default=224) parser.add_argument('--color', default=False, action='store_true') + parser.add_argument('--compile-encoder', dest='compile_encoder', default=False, action='store_true') parser.add_argument('--batch-size', dest='batch_size', - default=256, type=int) + default=16, type=int) parser.add_argument('--num-features', dest='num_features', default=1280, type=int) parser.add_argument('--gpus', dest='gpus', default=1, type=int) @@ -62,10 +63,13 @@ # logging.info([(x, y) for (x, y) in os.environ.items() if 'SLURM' in x]) if os.path.isdir('/home/walml/repos/zoobot'): - search_str = '/home/walml/repos/zoobot/gz_decals_5_train_*.tar' + logging.warning('local mode') + search_str = '/home/walml/data/wds/desi_labelled_2048/desi_labelled_train_*.tar' + cache_dir = None else: search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/desi_labelled_2048/desi_labelled_train_*.tar' + cache_dir = os.environ['SLURM_TMPDIR'] + '/cache' all_urls = glob.glob(search_str) assert len(all_urls) > 0, search_str @@ -115,10 +119,10 @@ wandb_logger=wandb_logger, prefetch_factor=1, # TODO num_workers=args.num_workers, - compile_encoder=True, # NEW + compile_encoder=args.compile_encoder, # NEW random_state=random_state, learning_rate=1e-3, - cache_dir=os.environ['SLURM_TMPDIR'] + '/cache' + cache_dir=cache_dir # cache_dir='/tmp/cache' # /tmp for ramdisk (400GB total, vs 4TB total for nvme) ) diff --git a/only_for_me/narval/train.sh b/only_for_me/narval/train.sh index 08689882..e4e217d6 100644 --- a/only_for_me/narval/train.sh +++ b/only_for_me/narval/train.sh @@ -26,7 +26,7 @@ srun $PYTHON $REPO_DIR/only_for_me/narval/train.py \ --num-features 128 \ --gpus 2 \ --num-workers 10 \ - --color --wandb --mixed-precision + --color --wandb --mixed-precision --compile-encoder # srun python $SLURM_TMPDIR/zoobot/only_for_me/narval/finetune.py diff --git a/zoobot/pytorch/datasets/webdatamodule.py b/zoobot/pytorch/datasets/webdatamodule.py index bec1c7f4..8eec8071 100644 --- a/zoobot/pytorch/datasets/webdatamodule.py +++ b/zoobot/pytorch/datasets/webdatamodule.py @@ -12,7 +12,23 @@ # https://github.com/webdataset/webdataset-lightning/blob/main/train.py class WebDataModule(pl.LightningDataModule): - def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4, prefetch_factor=4, cache_dir=None): + def __init__( + self, + train_urls, + val_urls, + label_cols=None, + train_size=None, + val_size=None, + # hardware + batch_size=64, + num_workers=4, + prefetch_factor=4, + cache_dir=None, + color=False, + crop_scale_bounds=(0.7, 0.8), + crop_ratio_bounds=(0.9, 1.1), + resize_after_crop=224 + ): super().__init__() # if isinstance(train_urls, types.GeneratorType): @@ -39,6 +55,12 @@ def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_c self.cache_dir = cache_dir + # could use mixin + self.color = color + self.resize_after_crop = resize_after_crop + self.crop_scale_bounds = crop_scale_bounds + self.crop_ratio_bounds = crop_ratio_bounds + logging.info(f'Creating webdatamodule with WORLD_SIZE: {os.environ.get("WORLD_SIZE")}, RANK: {os.environ.get("RANK")}') @@ -52,7 +74,12 @@ def make_image_transform(self, mode="train"): # if mode == "train": # elif mode == "val": - augmentation_transform = default_transforms() # A.Compose object + augmentation_transform = default_transforms( + crop_scale_bounds=self.crop_scale_bounds, + crop_ratio_bounds=self.crop_ratio_bounds, + resize_after_crop=self.resize_after_crop, + pytorch_greyscale=not self.color + ) # A.Compose object def do_transform(img): return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32) return do_transform diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py index e46133c1..3f17fd65 100644 --- a/zoobot/pytorch/training/train_with_pytorch_lightning.py +++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py @@ -234,10 +234,17 @@ def train_default_zoobot_from_scratch( datamodule = webdatamodule.WebDataModule( train_urls=train_urls, val_urls=val_urls, + label_cols=schema.label_cols, + # hardware batch_size=batch_size, num_workers=num_workers, - label_cols=schema.label_cols, - cache_dir=cache_dir + prefetch_factor=prefetch_factor, + cache_dir=cache_dir, + # augmentation args + color=color, + crop_scale_bounds=crop_scale_bounds, + crop_ratio_bounds=crop_ratio_bounds, + resize_after_crop=resize_after_crop, # TODO pass through the rest )