Commit

feat: ⚡️ Add weighted random sampling.
rhoadesScholar committed Apr 1, 2024
1 parent 2dbf137 commit 85fd727
Showing 6 changed files with 153 additions and 80 deletions.
src/cellmap_data/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -15,8 +15,8 @@
__author__ = "Jeff Rhoades"
__email__ = "[email protected]"

-from .dataloader import CellMapDataLoader
from .multidataset import CellMapMultiDataset
+from .dataloader import CellMapDataLoader
from .datasplit import CellMapDataSplit
from .dataset import CellMapDataset
from .image import CellMapImage
src/cellmap_data/dataloader.py (81 changes: 51 additions & 30 deletions)
@@ -1,51 +1,72 @@
-from torch.utils.data import DataLoader
-
+import torch
+from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
+from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
-from .datasplit import CellMapDataSplit

from typing import Callable, Iterable, Optional


class CellMapDataLoader:
    # TODO: This class may be unnecessary
+    # TODO: docstring corrections
    """This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `validate_loader` respectively."""

-    datasplit: CellMapDataSplit
-    train_datasets: CellMapMultiDataset
-    validate_datasets: CellMapMultiDataset
-    train_loader: DataLoader
-    validate_loader: DataLoader
+    dataset: CellMapMultiDataset | CellMapDataset
+    classes: Iterable[str]
+    loader = DataLoader
+    batch_size: int
+    num_workers: int
+    weighted_sampler: bool
+    is_train: bool
+    rng: Optional[torch.Generator] = None

    def __init__(
        self,
-        datasplit: CellMapDataSplit,
-        batch_size: int,
-        num_workers: int,
-        is_train: bool,
+        dataset: CellMapMultiDataset | CellMapDataset,
+        classes: Iterable[str],
+        batch_size: int = 1,
+        num_workers: int = 0,
+        weighted_sampler: bool = False,
+        is_train: bool = True,
+        rng: Optional[torch.Generator] = None,
    ):
-        self.datasplit = datasplit
+        self.dataset = dataset
+        self.classes = classes
        self.batch_size = batch_size
        self.num_workers = num_workers
+        self.weighted_sampler = weighted_sampler
        self.is_train = is_train
+        self.rng = rng
        self.construct()

    # TODO: could keep dataloaders separate

    def construct(self):
-        self.train_datasets = self.datasplit.train_datasets_combined
-        self.validate_datasets = self.datasplit.validate_datasets_combined
-        self.train_loader = DataLoader(
-            self.train_datasets,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=self.is_train,
-        )
-        self.validate_loader = DataLoader(
-            self.validate_datasets,
-            batch_size=1,
-            num_workers=self.num_workers,
-            shuffle=False,
-        )
+        if self.weighted_sampler:
+            assert isinstance(
+                self.dataset, CellMapMultiDataset
+            ), "Weighted sampler only relevant for CellMapMultiDataset"
+            self.sampler = self.dataset.weighted_sampler(self.batch_size, self.rng)
+        else:
+            self.sampler = None
+        kwargs = {
+            "dataset": self.dataset,
+            "batch_size": self.batch_size,
+            "num_workers": self.num_workers,
+            "collate_fn": self.collate_fn,
+        }
+        if self.weighted_sampler:
+            kwargs["sampler"] = self.sampler
+        elif self.is_train:
+            kwargs["shuffle"] = True
+        else:
+            kwargs["shuffle"] = False
+        self.loader = DataLoader(**kwargs)
+
+    def collate_fn(self, batch):
+        outputs = {}
+        for b in batch:
+            for key, value in b.items():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(value)
+        for key, value in outputs.items():
+            outputs[key] = torch.stack(value)
+        return outputs
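
A usage sketch of the new API (not from this commit; the multidataset.py diff containing weighted_sampler's implementation is not rendered on this page, and the dataset and class names below are invented). Note that iteration goes through loader.loader, and that collate_fn stacks each dictionary key across the batch:

import torch

# multi_dataset: a CellMapMultiDataset spanning several crops (hypothetical)
loader = CellMapDataLoader(
    dataset=multi_dataset,
    classes=["mito", "er"],
    batch_size=4,
    num_workers=8,
    weighted_sampler=True,  # requires a CellMapMultiDataset, per the assert above
    rng=torch.Generator().manual_seed(42),
)

for batch in loader.loader:  # iterate loader.loader, not loader itself
    inputs = batch["0_input"]  # keys follow the example array names in datasplit.py
    ...

Given the WeightedRandomSampler import and the new per-class "totals" added to CellMapDataset.class_counts below, weighted_sampler presumably draws sample indices via torch.utils.data.WeightedRandomSampler with inverse class-frequency weights, but that is an inference, not this commit's code.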
src/cellmap_data/dataset.py (45 changes: 31 additions & 14 deletions)
@@ -1,20 +1,26 @@
# %%
import csv
-from typing import Callable, Dict, Generator, Sequence, Optional
+import math
+import os
+from typing import Callable, Dict, Sequence, Optional
import numpy as np
import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, get_worker_info
import tensorstore
from fibsem_tools.io.core import read, read_xarray
from .image import CellMapImage, EmptyImage


def split_gt_path(path: str) -> tuple[str, list[str]]:
    """Splits a path to groundtruth data into the main path string, and the classes supplied for it."""
-    path_prefix, path_rem = path.split("[")
-    classes, path_suffix = path_rem.split("]")
-    classes = classes.split(",")
-    path_string = path_prefix + "{label}" + path_suffix
+    try:
+        path_prefix, path_rem = path.split("[")
+        classes, path_suffix = path_rem.split("]")
+        classes = classes.split(",")
+        path_string = path_prefix + "{label}" + path_suffix
+    except ValueError:
+        path_string = path
+        classes = [path.split(os.path.sep)[-1]]
    return path_string, classes
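
For reference, a sketch of how the revised function behaves (paths invented for illustration): a bracketed path still expands into a {label} template plus its class list, while a plain path, which previously raised ValueError on unpacking, now falls through to the except branch:

split_gt_path("/data/crop1/gt.zarr/[mito,er]")
# -> ("/data/crop1/gt.zarr/{label}", ["mito", "er"])

split_gt_path("/data/crop1/gt.zarr/mito")
# -> ("/data/crop1/gt.zarr/mito", ["mito"])  (with a POSIX path separator)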


@@ -60,7 +66,7 @@ def __init__(
        is_train: bool = False,
        axis_order: str = "zyx",
        context: Optional[tensorstore.Context] = None,  # type: ignore
-        rng: Optional[Generator] = None,
+        rng: Optional[np.random.Generator] = None,
    ):
        """Initializes the CellMapDataset class.
@@ -91,7 +97,7 @@ def __init__(
            gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
            is_train (bool, optional): Whether the dataset is for training. Defaults to False.
            context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
-            rng (Optional[Generator], optional): A random number generator. Defaults to None.
+            rng (Optional[np.random.Generator], optional): A random number generator. Defaults to None.
        """
        self.raw_path = raw_path
        self.gt_paths = gt_path
@@ -148,11 +154,21 @@ def __getitem__(self, idx):

    def __iter__(self):
        """Iterates over the dataset, covering each section of the bounding box. For instance, for calculating validation scores."""
-        # TODO
-        raise NotImplementedError
-        if self._iter_coords is None:
-            self._iter_coords = ...
-        yield self.__getitem__(self._iter_coords)
+        # TODO : determine if this is right
+        worker_info = get_worker_info()
+        start = 0
+        end = len(self) - 1
+        # single-process data loading, return the full iterator
+        if worker_info is None:
+            iter_start = start
+            iter_end = end
+        else:  # in a worker process
+            # split workload
+            per_worker = int(math.ceil((end - start) / float(worker_info.num_workers)))
+            worker_id = worker_info.id
+            iter_start = start + worker_id * per_worker
+            iter_end = min(iter_start + per_worker, end)
+        return iter(range(iter_start, iter_end))

    def construct(self):
        """Constructs the input and target sources for the dataset."""
@@ -333,11 +349,12 @@ def sampling_box_shape(self):
    def class_counts(self) -> Dict[str, Dict[str, int]]:
        """Returns the number of pixels for each class in the ground truth data, normalized by the resolution."""
        if self._class_counts is None:
-            class_counts = {}
+            class_counts = {"totals": {c: 0 for c in self.classes}}
            for array_name, sources in self.target_sources.items():
                class_counts[array_name] = {}
                for label, source in sources.items():
                    class_counts[array_name][label] = source.class_counts
+                    class_counts["totals"][label] += source.class_counts
            self._class_counts = class_counts
        return self._class_counts

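To make the new "totals" bucket concrete: for a dataset with one target array and two classes, the cached dictionary would now look something like this (array and class names hypothetical, counts invented; the totals are what a weighted sampler can use to balance classes across arrays):

{
    "totals": {"mito": 153000.0, "er": 42000.0},
    "0_target": {"mito": 153000.0, "er": 42000.0},
}
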
src/cellmap_data/datasplit.py (45 changes: 26 additions & 19 deletions)
@@ -1,9 +1,8 @@
import csv
from typing import Callable, Dict, Iterable, Optional, Sequence
import tensorstore

-from .multidataset import CellMapMultiDataset
from .dataset import CellMapDataset
+from .multidataset import CellMapMultiDataset

class CellMapDataSplit:
@@ -20,11 +19,12 @@ class CellMapDataSplit:
    validate_datasets: Iterable[CellMapDataset]
    train_datasets_combined: CellMapMultiDataset
    validate_datasets_combined: CellMapMultiDataset
-    spatial_transforms: Optional[Sequence[dict[str, any]]]
+    spatial_transforms: Optional[dict[str, any]] = None
    raw_value_transforms: Optional[Callable] = None
    gt_value_transforms: Optional[
        Callable | Sequence[Callable] | dict[str, Callable]
    ] = None
+    force_has_data: bool = False
    context: Optional[tensorstore.Context] = None  # type: ignore

    # TODO: may want different transforms for different arrays
@@ -34,9 +34,9 @@ def __init__(
        target_arrays: dict[str, dict[str, Sequence[int | float]]],
        classes: Sequence[str],
        datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None,
-        dataset_dict: Optional[Dict[str, Dict[str, str]]] = None,
+        dataset_dict: Optional[Dict[str, Sequence[Dict[str, str]]]] = None,
        csv_path: Optional[str] = None,
-        spatial_transforms: Optional[Sequence[dict[str, any]]] = None,
+        spatial_transforms: Optional[dict[str, any]] = None,
        raw_value_transforms: Optional[Callable] = None,
        gt_value_transforms: Optional[
            Callable | Sequence[Callable] | dict[str, Callable]
@@ -71,12 +71,12 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
"train": Iterable[CellMapDataset],
"validate": Iterable[CellMapDataset],
}. Defaults to None.
dataset_dict (Optional[Dict[str, Dict[str, str]]], optional): A dictionary containing the dataset data. The dictionary should have the following structure:
dataset_dict (Optional[Dict[str, Sequence[Dict[str, str]]]], optional): A dictionary containing the dataset data. The dictionary should have the following structure:
{
"train" | "validate": {
"train" | "validate": [{
"raw": str (path to raw data),
"gt": str (path to ground truth data),
},
}],
...
}. Defaults to None.
csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
@@ -115,21 +115,21 @@ def from_csv(self, csv_path):
            reader = csv.reader(f)
            for row in reader:
                if row[0] not in dataset_dict:
-                    dataset_dict[row[0]] = {"raw": [], "gt": []}
-                dataset_dict[row[0]]["raw"].append(row[1])
-                dataset_dict[row[0]]["gt"].append(row[2])
+                    dataset_dict[row[0]] = []
+                dataset_dict[row[0]].append({"raw": row[1], "gt": row[2]})

        self.dataset_dict = dataset_dict
        self.construct(dataset_dict)

    def construct(self, dataset_dict):
+        self._class_counts = None
        self.train_datasets = []
        self.validate_datasets = []
-        for raw, gt in zip(dataset_dict["train"]["raw"], dataset_dict["train"]["gt"]):
+        for data_paths in dataset_dict["train"]:
            self.train_datasets.append(
                CellMapDataset(
-                    raw,
-                    gt,
+                    data_paths["raw"],
+                    data_paths["gt"],
                    self.classes,
                    self.input_arrays,
                    self.target_arrays,
@@ -142,13 +142,11 @@ def construct(self, dataset_dict):

        # TODO: probably want larger arrays for validation

-        for raw, gt in zip(
-            dataset_dict["validate"]["raw"], dataset_dict["validate"]["gt"]
-        ):
+        for data_paths in dataset_dict["validate"]:
            self.validate_datasets.append(
                CellMapDataset(
-                    raw,
-                    gt,
+                    data_paths["raw"],
+                    data_paths["gt"],
                    self.classes,
                    self.input_arrays,
                    self.target_arrays,
@@ -168,6 +166,15 @@ def construct(self, dataset_dict):
            [ds for ds in self.validate_datasets if self.force_has_data or ds.has_data],
        )

+    @property
+    def class_counts(self):
+        if self._class_counts is None:
+            self._class_counts = {
+                "train": self.train_datasets_combined.class_counts,
+                "validate": self.validate_datasets_combined.class_counts,
+            }
+        return self._class_counts
+

# Example input arrays:
# {'0_input': {'shape': (90, 90, 90), 'scale': (32, 32, 32)},
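To make the new list-of-dicts layout concrete, here is a hypothetical two-row CSV (paths invented; note that a bracketed multi-class gt path contains commas and would need CSV quoting to survive csv.reader) and the dataset_dict that from_csv would now build from it:

train,/data/crop1/em/raw,/data/crop1/gt/[mito]
validate,/data/crop2/em/raw,/data/crop2/gt/[mito]

{
    "train": [{"raw": "/data/crop1/em/raw", "gt": "/data/crop1/gt/[mito]"}],
    "validate": [{"raw": "/data/crop2/em/raw", "gt": "/data/crop2/gt/[mito]"}],
}

Each entry then becomes one CellMapDataset in construct, so multiple crops per split are simply multiple rows.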
src/cellmap_data/image.py (2 changes: 1 addition & 1 deletion)
@@ -274,7 +274,7 @@ def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
"""Returns image data centered around the given point, based on the scale and shape of the target output image."""
return self.store

def set_spatial_transforms(self, transforms: dict[str, any]):
def set_spatial_transforms(self, transforms: dict[str, any] | None):
pass

@property
(The sixth changed file's diff did not render on the page; given the new CellMapMultiDataset.weighted_sampler call in dataloader.py, it is presumably src/cellmap_data/multidataset.py.)
