prediction adjustments
mwalmsley committed Nov 23, 2023
1 parent 7d81c92 commit 751b5d6
Showing 7 changed files with 87 additions and 99 deletions.
14 changes: 7 additions & 7 deletions only_for_me/narval/make_webdataset_script.py
@@ -33,7 +33,7 @@ def catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog

save_loc = f"/home/walml/data/wds/{dataset_name}/{dataset_name}_{catalog_name}.tar" # .tar replace automatically

- webdataset_utils.df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df)
+ webdataset_utils.df_to_wds(catalog, label_cols, save_loc, n_shards=n_shards, sparse_label_df=sparse_label_df, overwrite=False)

# webdataset_utils.load_wds_directly(save_loc)

@@ -53,8 +53,8 @@ def main():


# for converting other catalogs e.g. DESI
- dataset_name = 'desi_labelled_300px_2048'
- # dataset_name = 'desi_all_2048'
+ # dataset_name = 'desi_labelled_300px_2048'
+ dataset_name = 'desi_all_300px_2048'
label_cols = label_metadata.decals_all_campaigns_ortho_label_cols
columns = [
'dr8_id', 'brickid', 'objid', 'ra', 'dec'
@@ -85,8 +85,8 @@ def main():
# print(len(df_dedup2))
# df_dedup.to_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')

- df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')
- # df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
+ # df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_labelled_dedup_20arcsec.parquet')
+ df_dedup = pd.read_parquet('/home/walml/data/desi/master_all_file_index_all_dedup_20arcsec.parquet')
df_dedup['id_str'] = df_dedup['dr8_id']

# columns = ['id_str', 'smooth-or-featured-dr12_total-votes', 'smooth-or-featured-dr5_total-votes', 'smooth-or-featured-dr8_total-votes']
@@ -95,8 +95,8 @@ def main():
# df_dedup_with_votes = pd.merge(df_dedup, votes, how='left', on='dr8_id')

train_catalog, test_catalog = train_test_split(df_dedup, test_size=0.2, random_state=42)
- train_catalog.to_parquet('/home/walml/data/wds/desi_labelled_300px_2048/train_catalog_v1.parquet', index=False)
- test_catalog.to_parquet('/home/walml/data/wds/desi_labelled_300px_2048/test_catalog_v1.parquet', index=False)
+ train_catalog.to_parquet(f'/home/walml/data/wds/{dataset_name}/train_catalog_v1.parquet', index=False)
+ test_catalog.to_parquet(f'/home/walml/data/wds/{dataset_name}/test_catalog_v1.parquet', index=False)

catalogs_to_webdataset(dataset_name, label_cols, train_catalog, test_catalog, divisor=2048, sparse_label_df=votes)

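In the hunks above, save_loc ends in '.tar' and df_to_wds (changed below in webdataset_utils.py) replaces that suffix with the shard index and shard length, so each tar name carries its own galaxy count. A rough sketch of the names this produces, assuming three shards of 2048 galaxies (illustrative values, not taken from the actual catalog):

# illustrative reconstruction of the shard naming implied by save_loc and df_to_wds
dataset_name = 'desi_all_300px_2048'
save_loc = f"/home/walml/data/wds/{dataset_name}/{dataset_name}_train.tar"  # '.tar' replaced per shard
n_shards, shard_size = 3, 2048  # hypothetical values
shard_locs = [save_loc.replace('.tar', f'_{shard_n}_{shard_size}.tar') for shard_n in range(n_shards)]
# e.g. .../desi_all_300px_2048_train_0_2048.tar, ..._train_1_2048.tar, ...

This size-in-filename convention is what the datamodule below relies on to infer dataset sizes.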
2 changes: 1 addition & 1 deletion only_for_me/narval/train.py
@@ -64,7 +64,7 @@

if os.path.isdir('/home/walml/repos/zoobot'):
logging.warning('local mode')
- search_str = '/home/walml/data/wds/desi_labelled_300px_2048/desi_labelled_train_*.tar'
+ search_str = '/home/walml/data/wds/desi_labelled_300px_2048/desi_labelled_300px_2048_train_*.tar'
cache_dir = None

else:
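The glob now matches the {dataset_name}_{split}_{shard}_{size}.tar names produced above. The rest of train.py is not shown in this hunk; a minimal sketch of one way the matching shards could be expanded and handed to the datamodule (module paths assumed from the repository layout, everything else illustrative):

import glob
from zoobot.shared import label_metadata
from zoobot.pytorch.datasets import webdatamodule

search_str = '/home/walml/data/wds/desi_labelled_300px_2048/desi_labelled_300px_2048_train_*.tar'
train_urls = sorted(glob.glob(search_str))  # shard paths, 'urls' in webdataset terms
val_urls = sorted(glob.glob(search_str.replace('_train_', '_val_')))  # assumes val shards were also written
datamodule = webdatamodule.WebDataModule(
    train_urls=train_urls,
    val_urls=val_urls,
    label_cols=label_metadata.decals_all_campaigns_ortho_label_cols,
    batch_size=64,
)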
132 changes: 53 additions & 79 deletions zoobot/pytorch/datasets/webdatamodule.py
@@ -14,11 +14,11 @@
class WebDataModule(pl.LightningDataModule):
def __init__(
self,
- train_urls,
- val_urls,
+ train_urls=None,
+ val_urls=None,
test_urls=None,
+ predict_urls=None,
label_cols=None,
- train_size=None,
- val_size=None,
# hardware
batch_size=64,
num_workers=4,
@@ -31,21 +31,20 @@
):
super().__init__()

- # if isinstance(train_urls, types.GeneratorType):
- # train_urls = list(train_urls)
- # if isinstance(val_urls, types.GeneratorType):
- # val_urls = list(val_urls)
self.train_urls = train_urls
self.val_urls = val_urls
self.test_urls = test_urls
+ self.predict_urls = predict_urls

- if train_size is None:
+ if train_urls is not None:
# assume the size of each shard is encoded in the filename as ..._{size}.tar
- train_size = sum([int(url.rstrip('.tar').split('_')[-1]) for url in train_urls])
- if val_size is None:
- val_size = sum([int(url.rstrip('.tar').split('_')[-1]) for url in val_urls])

- self.train_size = train_size
- self.val_size = val_size
+ self.train_size = interpret_dataset_size_from_urls(train_urls)
+ if val_urls is not None:
+ self.val_size = interpret_dataset_size_from_urls(val_urls)
+ if test_urls is not None:
+ self.test_size = interpret_dataset_size_from_urls(test_urls)
+ if predict_urls is not None:
+ self.predict_size = interpret_dataset_size_from_urls(predict_urls)

self.label_cols = label_cols

@@ -61,18 +60,14 @@ def __init__(
self.crop_scale_bounds = crop_scale_bounds
self.crop_ratio_bounds = crop_ratio_bounds

+ for url_name in ['train', 'val', 'test', 'predict']:
+ urls = getattr(self, f'{url_name}_urls')
+ if urls is not None:
+ logging.info(f"{url_name} (before hardware splits) = {len(urls)} e.g. {urls[0]}", )

logging.info(f'Creating webdatamodule with WORLD_SIZE: {os.environ.get("WORLD_SIZE")}, RANK: {os.environ.get("RANK")}')

logging.info(f"train_urls (before hardware splits) = {len(self.train_urls)} e.g. {self.train_urls[0]}", )
logging.info(f"val_urls (before hardware splits) = {len(self.val_urls)} e.g. {self.val_urls[0]}", )
# logging.info("train_size (before hardware splits) = ", self.train_size)
# logging.info("val_size (before hardware splits) = ", self.val_size)
logging.info(f"batch_size: {self.batch_size}, num_workers: {self.num_workers}")

def make_image_transform(self, mode="train"):
# if mode == "train":
# elif mode == "val":

augmentation_transform = transforms.default_transforms(
crop_scale_bounds=self.crop_scale_bounds,
@@ -102,11 +97,11 @@ def label_transform(label_dict)


def make_loader(self, urls, mode="train"):
+ dataset_size = getattr(self, f'{mode}_size')
if mode == "train":
- dataset_size = self.train_size
shuffle = min(self.train_size, 5000)
- elif mode == "val":
- dataset_size = self.val_size
- shuffle = min(dataset_size, 5000)
else:
+ assert mode in ['val', 'test', 'predict'], mode
shuffle = 0

transform_image = self.make_image_transform(mode=mode)
@@ -120,21 +115,20 @@ def make_loader(self, urls, mode="train"):
)
.shuffle(shuffle)
.decode("rgb")
- .to_tuple('image.jpg', 'labels.json')
- .map_tuple(transform_image, transform_label)
- # torch collate stacks dicts nicely while webdataset only lists them
- # so use the torch collate instead
- .batched(self.batch_size, torch.utils.data.default_collate, partial=False)
- # .repeat(5)
)
+ if mode == 'predict':
+ # dataset = dataset.extract_keys('image.jpg').map(transform_image)
+ dataset = dataset.to_tuple('image.jpg').map_tuple(transform_image) # (im,) tuple. But map applied to all elements
+ # .map(get_first)
+ else:
+ dataset = (
+ dataset.to_tuple('image.jpg', 'labels.json')
+ .map_tuple(transform_image, transform_label)
+ )

- # from itertools import islice
- # for batch in islice(dataset, 0, 3):
- # images, labels = batch
- # # print(len(sample))
- # print(images.shape)
- # print(len(labels)) # list of dicts
- # # exit()
+ # torch collate stacks dicts nicely while webdataset only lists them
+ # so use the torch collate instead
+ dataset = dataset.batched(self.batch_size, torch.utils.data.default_collate, partial=False)

loader = wds.WebLoader(
dataset,
@@ -145,17 +139,13 @@ def make_loader(self, urls, mode="train"):
prefetch_factor=self.prefetch_factor
)

- # print('sampling')
- # for sample in islice(loader, 0, 3):
- # images, labels = sample
- # print(images.shape)
- # print(len(labels)) # list of dicts
- # exit()

loader.length = dataset_size // self.batch_size

# temp hack instead
- assert dataset_size % self.batch_size == 0, (dataset_size, self.batch_size, dataset_size % self.batch_size)
+ if mode in ['train', 'val']:
+ assert dataset_size % self.batch_size == 0, (dataset_size, self.batch_size, dataset_size % self.batch_size)
+ # for test/predict, always single GPU anyway

# if mode == "train":
# ensure same number of batches in all clients
# loader = loader.ddp_equalize(dataset_size // self.batch_size)
@@ -168,32 +158,14 @@ def train_dataloader(self):

def val_dataloader(self):
return self.make_loader(self.val_urls, mode="val")

- # @staticmethod
- # def add_loader_specific_args(parser):
- # parser.add_argument("-b", "--batch-size", type=int, default=128)
- # parser.add_argument("--workers", type=int, default=6)
- # parser.add_argument("--bucket", default="./shards")
- # parser.add_argument("--shards", default="imagenet-train-{000000..001281}.tar")
- # parser.add_argument("--valshards", default="imagenet-val-{000000..000006}.tar")
- # return parser

- # def nodesplitter_func(urls): # SimpleShardList
- # # print(urls)
- # try:
- # node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
- # urls_to_use = list(urls)[node_id::node_count]
- # logging.info(f'id: {node_id}, of count {node_count}. \nURLS: {len(urls_to_use)} of {len(urls)} ({urls_to_use})\n\n')
- # return urls_to_use
- # except RuntimeError:
- # # print('Distributed not initialised. Hopefully single node.')
- # return urls

+ def predict_dataloader(self):
+ return self.make_loader(self.predict_urls, mode="predict")

def identity(x):
return x

def nodesplitter_func(urls):
# num_urls = len(list(urls.copy()))
urls_to_use = list(wds.split_by_node(urls)) # rely on WDS for the hard work
rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info()
logging.info(
@@ -205,14 +177,16 @@ def nodesplitter_func(urls):
)
return urls_to_use

+ def interpret_shard_size_from_url(url):
+ return int(url.rstrip('.tar').split('_')[-1])

+ def interpret_dataset_size_from_urls(urls):
+ return sum([interpret_shard_size_from_url(url) for url in urls])

+ def get_first(x):
+ return x[0]

- # def split_by_worker(urls):
- # rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info()
- # if num_workers > 1:
- # logging.info(f'Slicing urls for rank {rank}, world_size {world_size}, worker {worker}')
- # for s in islice(urls, worker, None, num_workers):
- # yield s
- # else:
- # logging.warning('only one worker?!')
- # for s in urls:
- # yield s
+ def custom_collate(x):
+ if isinstance(x, list) and len(x) == 1:
+ x = x[0]
+ return torch.utils.data.default_collate(x)
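The constructor no longer needs explicit train_size/val_size arguments: every provided split infers its size by summing the trailing _{size} token of each shard name, and any split can be omitted. A short sketch of the parsing plus a predict-only datamodule (shard paths are placeholders, names taken from this module):

# the shard size is the last underscore-separated token before '.tar'
interpret_shard_size_from_url('desi_all_300px_2048_predict_0_2048.tar')  # -> 2048

predict_urls = [
    'desi_all_300px_2048_predict_0_2048.tar',  # placeholder shard paths
    'desi_all_300px_2048_predict_1_2048.tar',
]
datamodule = WebDataModule(predict_urls=predict_urls, batch_size=64)  # train/val/test urls may now be omitted
loader = datamodule.predict_dataloader()
# each batch arrives as a length-1 list [images]; see the predict_step changes below for the unwrapping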
13 changes: 7 additions & 6 deletions zoobot/pytorch/datasets/webdataset_utils.py
@@ -40,7 +40,7 @@ def make_mock_wds(save_dir: str, label_cols: List, n_shards: int, shard_size: in



- def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse_label_df=None):
+ def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse_label_df=None, overwrite=False):

assert '.tar' in save_loc
df['id_str'] = df['id_str'].astype(str).str.replace('.', '_')
@@ -85,11 +85,12 @@ def df_to_wds(df: pd.DataFrame, label_cols, save_loc: str, n_shards: int, sparse
if sparse_label_df is not None:
shard_df = pd.merge(shard_df, sparse_label_df, how='left', validate='one_to_one', suffixes=('', '_badlabelmerge')) # auto-merge
shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
- logging.info(shard_save_loc)
- sink = wds.TarWriter(shard_save_loc)
- for _, galaxy in shard_df.iterrows():
- sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform))
- sink.close()
+ if overwrite or not(os.path.isfile(shard_save_loc)):
+ logging.info(shard_save_loc)
+ sink = wds.TarWriter(shard_save_loc)
+ for _, galaxy in shard_df.iterrows():
+ sink.write(galaxy_to_wds(galaxy, label_cols, transform=transform))
+ sink.close()


def galaxy_to_wds(galaxy: pd.Series, label_cols, transform=None):
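With the new overwrite flag (default False, and make_webdataset_script.py above passes overwrite=False explicitly), a shard whose tar already exists on disk is skipped, so an interrupted packing run can simply be re-run and only the missing shards get written. The pattern, shown as a standalone hypothetical helper rather than the actual function:

import os

def write_shard_if_missing(shard_save_loc, write_fn, overwrite=False):
    # hypothetical helper mirroring the skip-if-exists check in df_to_wds
    if overwrite or not os.path.isfile(shard_save_loc):
        write_fn(shard_save_loc)  # the expensive part: load images, write the tar
    # otherwise the shard was packed on a previous run and is left untouched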
5 changes: 5 additions & 0 deletions zoobot/pytorch/estimators/define_model.py
@@ -105,6 +105,11 @@ def on_test_batch_end(self, outputs, *args):


def predict_step(self, batch, batch_idx, dataloader_idx=0):
+ # I can't work out how to get webdataset to return a single item im, not a tuple (im,).
+ # this is fine for training but annoying for predict
+ # help welcome. meanwhile, this works around it
+ if isinstance(batch, list) and len(batch) == 1:
+ return self(batch[0])
# https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#inference
# this calls forward, while avoiding the need for e.g. model.eval(), torch.no_grad()
# x, y = batch # would be usual format, but here, batch does not include labels
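In predict mode the pipeline built in webdatamodule.py returns each sample as a one-element tuple (im,), so after collation the batch reaches predict_step as a length-1 list rather than a bare image tensor; the new lines unwrap it before calling forward. A toy, webdataset-free illustration of the same shape problem and workaround (model and sizes invented for the example):

import torch

images = torch.randn(8, 3, 224, 224)  # a collated image batch
batch = [images]                      # what the (im,) tuple pipeline delivers per predict batch

def predict_step(model, batch):
    # same workaround as above: unwrap [images] -> images before the forward pass
    if isinstance(batch, list) and len(batch) == 1:
        return model(batch[0])
    return model(batch)

toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(2))  # stand-in, not Zoobot
print(predict_step(toy_model, batch).shape)  # torch.Size([8, 2])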
16 changes: 12 additions & 4 deletions zoobot/pytorch/training/finetune.py
@@ -2,6 +2,7 @@
# https://github.com/inigoval/finetune/blob/main/finetune.py
import logging
import os
+ from typing import Any
import warnings
from functools import partial

@@ -182,10 +183,6 @@ def configure_optimizers(self):
"lr": lr * (self.lr_decay**i)
})

- # TODO this actually breaks training because the generator only iterates once!
- # total_params = sum(p.numel() for param_set in params.copy() for p in param_set['params'])
- # logging.info('Total params to fit: {}'.format(total_params))

# Initialize AdamW optimizer
opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict

@@ -219,6 +216,14 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):

def test_step(self, batch, batch_idx, dataloader_idx=0):
return self.make_step(batch)

+ def predict_step(self, batch, batch_idx) -> Any:
+ # I can't work out how to get webdataset to return a single item im, not a tuple (im,).
+ # this is fine for training but annoying for predict
+ # help welcome. meanwhile, this works around it
+ if isinstance(batch, list) and len(batch) == 1:
+ return self(batch[0])
+ return self(batch)

def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
# v2 docs currently do not show dataloader_idx as train argument so unclear if this will value be updated properly
@@ -355,6 +360,9 @@ def on_test_batch_end(self, step_output, *args) -> None:


def predict_step(self, x, batch_idx):
+ # see Abstract version
+ if isinstance(x, list) and len(x) == 1:
+ return self(x[0])
x = self.forward(x) # logits from LinearClassifier
# then applies softmax
return F.softmax(x, dim=1)
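The same unwrapping added to the finetuning classes means trainer.predict can consume a webdataset predict loader directly. A usage sketch, assuming a saved finetuned classifier checkpoint and existing predict shards (paths are placeholders, and it assumes WebDataModule's setup hook tolerates the 'predict' stage):

import pytorch_lightning as pl
from zoobot.pytorch.training import finetune
from zoobot.pytorch.datasets import webdatamodule

model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint('finetuned.ckpt')  # placeholder path
datamodule = webdatamodule.WebDataModule(
    predict_urls=['desi_predict_0_2048.tar'],  # placeholder shard
    batch_size=64,
)
trainer = pl.Trainer(accelerator='auto', devices=1)
preds = trainer.predict(model, datamodule=datamodule)  # list of per-batch class probabilities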
4 changes: 2 additions & 2 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -235,6 +235,7 @@ def train_default_zoobot_from_scratch(
datamodule = webdatamodule.WebDataModule(
train_urls=train_urls,
val_urls=val_urls,
+ test_urls=test_urls,
label_cols=schema.label_cols,
# hardware
batch_size=batch_size,
@@ -245,8 +246,7 @@
color=color,
crop_scale_bounds=crop_scale_bounds,
crop_ratio_bounds=crop_ratio_bounds,
- resize_after_crop=resize_after_crop,
- # TODO pass through the rest
+ resize_after_crop=resize_after_crop
)

datamodule.setup(stage='fit')
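Passing test_urls through means the datamodule built here can also back a test pass after fitting, with test_size inferred from the shard names just like the other splits. A sketch of the intended flow (variable names assumed; the surrounding function defines the actual trainer and model objects):

# sketch only: reuse the same datamodule for fit and test
trainer.fit(lightning_model, datamodule=datamodule)
if test_urls is not None:
    trainer.test(lightning_model, datamodule=datamodule)  # assumes WebDataModule exposes a test_dataloader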
