From 1c79c79d488cad4d55d91b9b632acdf8020089cf Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 13 May 2024 14:27:02 -0400 Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=9A=A7=20Subclass=20torch=20datase?= =?UTF-8?q?t=20Subset,=20hunting=20validation=20blocks=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/multidataset.py | 2 +- src/cellmap_data/subdataset.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index f1b604d..e65c330 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -185,7 +185,7 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]: index_offset = 0 for dataset in self.datasets: indices.append(dataset.get_indices(chunk_size)) - index_offset += len(dataset) + index_offset += len(dataset) - 1 return indices def set_raw_value_transforms(self, transforms: Callable): diff --git a/src/cellmap_data/subdataset.py b/src/cellmap_data/subdataset.py index d08781e..dbba10e 100644 --- a/src/cellmap_data/subdataset.py +++ b/src/cellmap_data/subdataset.py @@ -1,18 +1,15 @@ -from typing import Callable -from torch.utils.data import Dataset +from typing import Callable, Sequence +from torch.utils.data import Subset +from .dataset import CellMapDataset -class CellMapSubset(Dataset): - def __init__(self, dataset, indices): - super().__init__() - self.dataset = dataset - self.indices = indices +class CellMapSubset(Subset): - def __getitem__(self, idx): - return self.dataset[self.indices[idx]] + dataset: CellMapDataset + indices: Sequence[int] - def __len__(self): - return len(self.indices) + def __init__(self, dataset: CellMapDataset, indices: Sequence[int]) -> None: + super().__init__(dataset, indices) @property def classes(self):