From 67debd9439ecaa1d4d3f2dc35d2793f4179f6a07 Mon Sep 17 00:00:00 2001
From: lijm1358
Date: Wed, 28 Dec 2022 18:18:22 +0900
Subject: [PATCH] add color target type test

---
 tests/datamodules/test_datamodules.py | 37 +++++++++++++++------------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py
index e827147abe..0ff6688dad 100644
--- a/tests/datamodules/test_datamodules.py
+++ b/tests/datamodules/test_datamodules.py
@@ -41,32 +41,35 @@ def _create_synth_Cityscapes_dataset(path_dir):
             image_name = f"{base_name}_leftImg8bit.png"
             instance_target_name = f"{base_name}_gtFine_instanceIds.png"
             semantic_target_name = f"{base_name}_gtFine_labelIds.png"
+            color_target_name = f"{base_name}_gtFine_color.png"
             Image.new("RGB", (2048, 1024)).save(images_dir / split / city / image_name)
             Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / instance_target_name)
             Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / semantic_target_name)
+            Image.new("RGBA", (2048, 1024)).save(fine_labels_dir / split / city / color_target_name)
 
 
-def test_cityscapes_datamodule(datadir):
+def test_cityscapes_datamodule(datadir, catch_warnings):
     _create_synth_Cityscapes_dataset(datadir)
 
     batch_size = 1
-    target_types = ["semantic", "instance"]
-    for target_type in target_types:
+    target_types = ["semantic", "instance", "color"]
+    target_sizes = [(1024, 2048), (1024, 2048), (4, 1024, 2048)]
+    for target_type, target_size in zip(target_types, target_sizes):
         dm = CityscapesDataModule(datadir, num_workers=0, batch_size=batch_size, target_type=target_type)
-    loader = dm.train_dataloader()
-    img, mask = next(iter(loader))
-    assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
-    assert mask.size() == torch.Size([batch_size, 1024, 2048])
-
-    loader = dm.val_dataloader()
-    img, mask = next(iter(loader))
-    assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
-    assert mask.size() == torch.Size([batch_size, 1024, 2048])
-
-    loader = dm.test_dataloader()
-    img, mask = next(iter(loader))
-    assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
-    assert mask.size() == torch.Size([batch_size, 1024, 2048])
+        loader = dm.train_dataloader()
+        img, mask = next(iter(loader))
+        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
+        assert mask.size() == torch.Size([batch_size, *target_size])
+
+        loader = dm.val_dataloader()
+        img, mask = next(iter(loader))
+        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
+        assert mask.size() == torch.Size([batch_size, *target_size])
+
+        loader = dm.test_dataloader()
+        img, mask = next(iter(loader))
+        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
+        assert mask.size() == torch.Size([batch_size, *target_size])
 
 
 @pytest.mark.parametrize("val_split, train_len", [(0.2, 48_000), (5_000, 55_000)])