Skip to content

Commit

Permalink
fix: 🚧 Subclass torch dataset Subset, hunting validation blocks bug
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 13, 2024
1 parent 036fc49 commit 1c79c79
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]:
index_offset = 0
for dataset in self.datasets:
indices.append(dataset.get_indices(chunk_size))
index_offset += len(dataset)
index_offset += len(dataset) - 1
return indices

def set_raw_value_transforms(self, transforms: Callable):
Expand Down
19 changes: 8 additions & 11 deletions src/cellmap_data/subdataset.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from typing import Callable
from torch.utils.data import Dataset
from typing import Callable, Sequence
from torch.utils.data import Subset
from .dataset import CellMapDataset


class CellMapSubset(Dataset):
def __init__(self, dataset, indices):
super().__init__()
self.dataset = dataset
self.indices = indices
class CellMapSubset(Subset):

def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
dataset: CellMapDataset
indices: Sequence[int]

def __len__(self):
return len(self.indices)
def __init__(self, dataset: CellMapDataset, indices: Sequence[int]) -> None:
super().__init__(dataset, indices)

@property
def classes(self):
Expand Down

0 comments on commit 1c79c79

Please sign in to comment.