diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index f1b604d..e65c330 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -185,7 +185,7 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]: index_offset = 0 for dataset in self.datasets: indices.append(dataset.get_indices(chunk_size)) - index_offset += len(dataset) + index_offset += len(dataset) - 1 return indices def set_raw_value_transforms(self, transforms: Callable): diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index d08781e..dbba10e 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -1,18 +1,15 @@ -from typing import Callable -from torch.utils.data import Dataset +from typing import Callable, Sequence +from torch.utils.data import Subset +from .dataset import CellMapDataset -class CellMapSubset(Dataset): - def __init__(self, dataset, indices): - super().__init__() - self.dataset = dataset - self.indices = indices +class CellMapSubset(Subset): - def __getitem__(self, idx): - return self.dataset[self.indices[idx]] + dataset: CellMapDataset + indices: Sequence[int] - def __len__(self): - return len(self.indices) + def __init__(self, dataset: CellMapDataset, indices: Sequence[int]) -> None: + super().__init__(dataset, indices) @property def classes(self):