Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision datamodules...CityscapesDataModule #956

Draft
wants to merge 37 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
964501b
review cityscapes datamodule
lijm1358 Dec 28, 2022
67debd9
add color target type test
lijm1358 Dec 28, 2022
be4464f
Merge branch 'master' into cityscape_review
Borda May 18, 2023
bd23c27
update mergify team
Borda May 19, 2023
7e93ecd
Merge branch 'master' into cityscape_review
Borda May 19, 2023
6551849
Merge branch 'master' into cityscape_review
Borda May 19, 2023
dd83ea7
Merge branch 'master' into cityscape_review
mergify[bot] May 19, 2023
3a0124f
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
352df8c
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
f5c242b
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
0ee5d5f
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
4ecf896
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
bba1075
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
8817f1d
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
30793e2
Merge branch 'master' into cityscape_review
mergify[bot] May 20, 2023
ac6739c
Merge branch 'master' into cityscape_review
mergify[bot] May 21, 2023
54a57f5
Merge branch 'master' into cityscape_review
mergify[bot] May 21, 2023
60059d8
Merge branch 'master' into cityscape_review
mergify[bot] May 22, 2023
72a9f3b
Merge branch 'master' into cityscape_review
mergify[bot] May 22, 2023
a12efde
Merge branch 'master' into cityscape_review
mergify[bot] May 29, 2023
68d560b
change test file name to id based name
lijm1358 May 29, 2023
1f7169e
Revert "change test file name to id based name"
lijm1358 May 29, 2023
5d99773
parametrize
Borda May 30, 2023
54b168f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2023
ba943da
for loader
Borda May 30, 2023
0bf6db8
Merge branch 'cityscape_review' of https://github.com/lijm1358/lightn…
Borda May 30, 2023
d65f540
loops
Borda May 30, 2023
67b646e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2023
4fb2fbb
Merge branch 'master' into cityscape_review
mergify[bot] May 30, 2023
edc49be
Merge branch 'master' into cityscape_review
Borda May 31, 2023
334099c
Merge branch 'master' into cityscape_review
mergify[bot] May 31, 2023
b4b01b9
Merge branch 'master' into cityscape_review
Borda May 31, 2023
7d3a188
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
36c5777
fix typing, ruff error
lijm1358 Jun 1, 2023
f302db1
Merge branch 'master' into cityscape_review
mergify[bot] Jun 12, 2023
59e1172
Merge branch 'master' into cityscape_review
mergify[bot] Jun 16, 2023
b755d9c
Merge branch 'master' into cityscape_review
mergify[bot] Jun 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions src/pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch.utils.data import DataLoader

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -14,7 +13,6 @@
warn_missing_pkg("torchvision")


@under_review()
class CityscapesDataModule(LightningDataModule):
"""
.. figure:: https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/muenster00-1024x510.png
Expand Down Expand Up @@ -85,7 +83,7 @@ def __init__(
data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
are located
quality_mode: the quality mode to use, either 'fine' or 'coarse'
target_type: targets to use, either 'instance' or 'semantic'
target_type: targets to use, can be 'instance', 'semantic', 'color', or 'polygon'.
num_workers: how many workers to use for loading data
batch_size: number of examples per training/eval step
seed: random seed to be used for train/val/test splits
Expand All @@ -101,8 +99,10 @@ def __init__(
"You want to use CityScapes dataset loaded from `torchvision` which is not installed yet."
)

if target_type not in ["instance", "semantic"]:
raise ValueError(f'Only "semantic" and "instance" target types are supported. Got {target_type}.')
if target_type not in ["instance", "semantic", "color", "polygon"]:
raise ValueError(
f'Only "instance", "semantic", "color", "polygon" target types are supported. Got {target_type}.'
)

self.dims = (3, 1024, 2048)
self.data_dir = data_dir
Expand All @@ -121,10 +121,7 @@ def __init__(

@property
def num_classes(self) -> int:
"""
Return:
30
"""
"""Returns the number of classes."""
return 30

def train_dataloader(self) -> DataLoader:
Expand All @@ -151,6 +148,33 @@ def train_dataloader(self) -> DataLoader:
pin_memory=self.pin_memory,
)

def train_extra_dataloader(self) -> DataLoader:
"""Cityscapes extra train dataset.

Only supported in coarse quality mode.
"""
transforms = self.train_transforms or self._default_transforms()
target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(
self.data_dir,
split="train_extra",
target_type=self.target_type,
mode=self.quality_mode,
transform=transforms,
target_transform=target_transforms,
**self.extra_args,
)

return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
)

def val_dataloader(self) -> DataLoader:
"""Cityscapes val set."""
transforms = self.val_transforms or self._default_transforms()
Expand All @@ -176,7 +200,10 @@ def val_dataloader(self) -> DataLoader:
)

def test_dataloader(self) -> DataLoader:
"""Cityscapes test set."""
"""Cityscapes test set.

Only supported in fine quality mode.
"""
transforms = self.test_transforms or self._default_transforms()
target_transforms = self.target_transforms or self._default_target_transforms()

Expand Down Expand Up @@ -208,5 +235,8 @@ def _default_transforms(self) -> Callable:
]
)

def _default_target_transforms(self) -> Callable:
def _default_target_transforms(self) -> Optional[Callable]:
if self.target_type == "polygon":
return None

return transform_lib.Compose([transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())])
66 changes: 22 additions & 44 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_dev_datasets(datadir):
pass


def _create_synth_Cityscapes_dataset(path_dir):
def _create_synth_Cityscapes_dataset(path_dir, img_size=(2048, 1024)):
"""Create synthetic dataset with random images, just to simulate that the dataset have been already
downloaded."""
non_existing_citites = ["dummy_city_1", "dummy_city_2"]
Expand All @@ -40,32 +40,26 @@ def _create_synth_Cityscapes_dataset(path_dir):
image_name = f"{base_name}_leftImg8bit.png"
instance_target_name = f"{base_name}_gtFine_instanceIds.png"
semantic_target_name = f"{base_name}_gtFine_labelIds.png"
Image.new("RGB", (2048, 1024)).save(images_dir / split / city / image_name)
Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / instance_target_name)
Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / semantic_target_name)
color_target_name = f"{base_name}_gtFine_color.png"
Image.new("RGB", img_size).save(images_dir / split / city / image_name)
Image.new("L", img_size).save(fine_labels_dir / split / city / instance_target_name)
Image.new("L", img_size).save(fine_labels_dir / split / city / semantic_target_name)
Image.new("RGBA", img_size).save(fine_labels_dir / split / city / color_target_name)


def test_cityscapes_datamodule(datadir):
@pytest.mark.parametrize(
("target_type", "target_size"),
[("semantic", (1024, 2048)), ("instance", (1024, 2048)), ("color", (4, 1024, 2048))],
)
def test_cityscapes_datamodule(datadir, catch_warnings, target_type: str, target_size: tuple, batch_size: int = 1):
_create_synth_Cityscapes_dataset(datadir)

batch_size = 1
target_types = ["semantic", "instance"]
for target_type in target_types:
dm = CityscapesDataModule(datadir, num_workers=0, batch_size=batch_size, target_type=target_type)
loader = dm.train_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, 1024, 2048])
dm = CityscapesDataModule(datadir, num_workers=0, batch_size=batch_size, target_type=target_type)

loader = dm.val_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, 1024, 2048])

loader = dm.test_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, 1024, 2048])
for loader in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, *target_size])


@pytest.mark.parametrize(("val_split", "train_len"), [(0.2, 48_000), (5_000, 55_000)])
Expand All @@ -78,17 +72,9 @@ def test_vision_data_module(datadir, val_split, catch_warnings, train_len):
def test_data_modules(datadir, catch_warnings, dm_cls):
"""Test datamodules train, val, and test dataloaders outputs have correct shape."""
dm = _create_dm(dm_cls, datadir)
train_loader = dm.train_dataloader()
img, _ = next(iter(train_loader))
assert img.size() == torch.Size([2, *dm.dims])

val_loader = dm.val_dataloader()
img, _ = next(iter(val_loader))
assert img.size() == torch.Size([2, *dm.dims])

test_loader = dm.test_dataloader()
img, _ = next(iter(test_loader))
assert img.size() == torch.Size([2, *dm.dims])
for loader in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
img, _ = next(iter(loader))
assert img.size() == torch.Size([2, *dm.dims])


def _create_dm(dm_cls, datadir, **kwargs):
Expand All @@ -112,17 +98,9 @@ def test_sr_datamodule(datadir):
def test_emnist_datamodules(datadir, catch_warnings, dm_cls, split):
"""Test BinaryEMNIST and EMNIST datamodules download data and have the correct shape."""
dm = _create_dm(dm_cls, datadir, split=split)
train_loader = dm.train_dataloader()
img, _ = next(iter(train_loader))
assert img.size() == torch.Size([2, *dm.dims])

val_loader = dm.val_dataloader()
img, _ = next(iter(val_loader))
assert img.size() == torch.Size([2, *dm.dims])

test_loader = dm.test_dataloader()
img, _ = next(iter(test_loader))
assert img.size() == torch.Size([2, *dm.dims])
for loader in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
img, _ = next(iter(loader))
assert img.size() == torch.Size([2, *dm.dims])


@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
Expand Down