diff --git a/src/pl_bolts/models/detection/yolo/yolo_config.py b/src/pl_bolts/models/detection/yolo/yolo_config.py
index fea807b1c0..293a33e3a6 100644
--- a/src/pl_bolts/models/detection/yolo/yolo_config.py
+++ b/src/pl_bolts/models/detection/yolo/yolo_config.py
@@ -6,10 +6,8 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 from pl_bolts.models.detection.yolo import yolo_layers
-from pl_bolts.utils.stability import under_review


-@under_review()
 class YOLOConfiguration:
     """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation.
@@ -149,7 +147,6 @@ def convert(key, value):
     return sections


-@under_review()
 def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]:
     """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer config.
@@ -173,8 +170,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]
     return create_func[config["type"]](config, num_inputs)


-@under_review()
-def _create_convolutional(config, num_inputs):
+def _create_convolutional(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
     module = nn.Sequential()

     batch_normalize = config.get("batch_normalize", False)
@@ -210,15 +206,13 @@ def _create_convolutional(config, num_inputs):
     return module, config["filters"]


-@under_review()
-def _create_maxpool(config, num_inputs):
+def _create_maxpool(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
     padding = (config["size"] - 1) // 2
     module = nn.MaxPool2d(config["size"], config["stride"], padding)
     return module, num_inputs[-1]


-@under_review()
-def _create_route(config, num_inputs):
+def _create_route(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
     num_chunks = config.get("groups", 1)
     chunk_idx = config.get("group_id", 0)
@@ -234,20 +228,17 @@ def _create_route(config, num_inputs):
     return module, num_outputs


-@under_review()
-def _create_shortcut(config, num_inputs):
+def _create_shortcut(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
     module = yolo_layers.ShortcutLayer(config["from"])
     return module, num_inputs[-1]


-@under_review()
-def _create_upsample(config, num_inputs):
+def _create_upsample(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
     module = nn.Upsample(scale_factor=config["stride"], mode="nearest")
     return module, num_inputs[-1]


-@under_review()
-def _create_yolo(config, num_inputs):
+def _create_yolo(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
     # The "anchors" list alternates width and height.
     anchor_dims = config["anchors"]
     anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)]
@@ -264,8 +255,10 @@ def _create_yolo(config, num_inputs):
         overlap_loss_func = yolo_layers.SELoss()
     elif overlap_loss_name == "giou":
         overlap_loss_func = yolo_layers.GIoULoss()
-    else:
+    elif overlap_loss_name == "iou":
         overlap_loss_func = yolo_layers.IoULoss()
+    else:
+        raise ValueError("Unknown overlap loss: " + overlap_loss_name)

     module = yolo_layers.DetectionLayer(
         num_classes=config["classes"],
diff --git a/src/pl_bolts/models/detection/yolo/yolo_layers.py b/src/pl_bolts/models/detection/yolo/yolo_layers.py
index 9e2d9f7475..d2259520e5 100644
--- a/src/pl_bolts/models/detection/yolo/yolo_layers.py
+++ b/src/pl_bolts/models/detection/yolo/yolo_layers.py
@@ -1,11 +1,10 @@
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import Tensor, nn

-from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.stability import under_review
+from pl_bolts.utils import _TORCH_MESHGRID_REQUIRES_INDEXING, _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg

 if _TORCHVISION_AVAILABLE:
@@ -21,7 +20,6 @@
     warn_missing_pkg("torchvision")


-@under_review()
 def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor:
     """Converts box center points and sizes to corner coordinates.
@@ -38,7 +36,6 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor:
     return torch.cat((top_left, bottom_right), -1)


-@under_review()
 def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor:
     """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at
     the same coordinates.
@@ -61,7 +58,6 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor:
     return inter / union


-@under_review()
 class SELoss(nn.MSELoss):
     def __init__(self):
         super().__init__(reduction="none")
@@ -70,13 +66,11 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
         return super().forward(inputs, target).sum(1)


-@under_review()
 class IoULoss(nn.Module):
     def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
         return 1.0 - box_iou(inputs, target).diagonal()


-@under_review()
 class GIoULoss(nn.Module):
     def __init__(self) -> None:
         super().__init__()
@@ -89,7 +83,6 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
         return 1.0 - generalized_box_iou(inputs, target).diagonal()


-@under_review()
 class DetectionLayer(nn.Module):
     """A YOLO detection layer.
@@ -263,7 +256,10 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor:
         x_range = torch.arange(width, device=xy.device)
         y_range = torch.arange(height, device=xy.device)
-        grid_y, grid_x = torch.meshgrid(y_range, x_range)
+        if _TORCH_MESHGRID_REQUIRES_INDEXING:
+            grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij")
+        else:
+            grid_y, grid_x = torch.meshgrid(y_range, x_range)
         offset = torch.stack((grid_x, grid_y), -1)  # [height, width, 2]
         offset = offset.unsqueeze(2)  # [height, width, 1, 2]
@@ -468,15 +464,13 @@ def _calculate_losses(
         return losses, hits


-@under_review()
 class Mish(nn.Module):
     """Mish activation."""

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return x * torch.tanh(nn.functional.softplus(x))


-@under_review()
 class RouteLayer(nn.Module):
     """Route layer concatenates the output (or part of it) from given layers."""
@@ -492,12 +486,11 @@ def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) ->
         self.num_chunks = num_chunks
         self.chunk_idx = chunk_idx

-    def forward(self, x, outputs):
+    def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor:
         chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers]
         return torch.cat(chunks, dim=1)


-@under_review()
 class ShortcutLayer(nn.Module):
     """Shortcut layer adds a residual connection from the source layer."""
@@ -510,5 +503,5 @@ def __init__(self, source_layer: int) -> None:
         super().__init__()
         self.source_layer = source_layer

-    def forward(self, x, outputs):
+    def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor:
         return outputs[-1] + outputs[self.source_layer]
diff --git a/src/pl_bolts/models/detection/yolo/yolo_module.py b/src/pl_bolts/models/detection/yolo/yolo_module.py
index 6a012f1db9..650765421d 100644
--- a/src/pl_bolts/models/detection/yolo/yolo_module.py
+++ b/src/pl_bolts/models/detection/yolo/yolo_module.py
@@ -11,7 +11,6 @@
 from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer
 from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.stability import under_review
 from pl_bolts.utils.warnings import warn_missing_pkg

 if _TORCHVISION_AVAILABLE:
@@ -23,7 +22,6 @@
 log = logging.getLogger(__name__)


-@under_review()
 class YOLO(LightningModule):
     """PyTorch Lightning implementation of YOLOv3 and YOLOv4.
@@ -179,7 +177,7 @@ def forward(
             )
         for layer_idx, layer_hits in enumerate(hits):
             hit_rate = torch.true_divide(layer_hits, total_hits) if total_hits > 0 else 1.0
-            self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False)
+            self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False, batch_size=images.size(0))

         def total_loss(loss_name):
             """Returns the sum of the loss over detection layers."""
@@ -233,8 +231,8 @@ def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], b
         total_loss = torch.stack(tuple(losses.values())).sum()

         for name, value in losses.items():
-            self.log(f"val/{name}_loss", value, sync_dist=True)
-        self.log("val/total_loss", total_loss, sync_dist=True)
+            self.log(f"val/{name}_loss", value, sync_dist=True, batch_size=images.size(0))
+        self.log("val/total_loss", total_loss, sync_dist=True, batch_size=images.size(0))

     def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int):
         """Evaluates a batch of data from the test set.
@@ -455,7 +453,6 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te
         return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels}


-@under_review()
 class Resize:
     """Rescales the image and target to given dimensions.
@@ -486,7 +483,6 @@ def __call__(self, image: Tensor, target: Dict[str, Any]):
         return image, target


-@under_review()
 def run_cli():
     from argparse import ArgumentParser
diff --git a/src/pl_bolts/utils/__init__.py b/src/pl_bolts/utils/__init__.py
index 07cbca7ae0..58fa717f83 100644
--- a/src/pl_bolts/utils/__init__.py
+++ b/src/pl_bolts/utils/__init__.py
@@ -6,7 +6,6 @@
 from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification  # type: ignore

 _NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
-
 _TORCHVISION_AVAILABLE: bool = module_available("torchvision")
 _GYM_AVAILABLE: bool = module_available("gym")
 _SKLEARN_AVAILABLE: bool = module_available("sklearn")
@@ -20,6 +19,7 @@
 _PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5")
 _TORCH_ORT_AVAILABLE = module_available("torch_ort")
 _TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0")
+_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0")
 _SPARSEML_AVAILABLE = module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_SPARSEML
 _JSONARGPARSE_GREATER_THAN_4_16_0 = compare_version("jsonargparse", operator.gt, "4.16.0")
diff --git a/tests/data/yolo_giou.cfg b/tests/data/yolo_giou.cfg
new file mode 100644
index 0000000000..16a96f918d
--- /dev/null
+++ b/tests/data/yolo_giou.cfg
@@ -0,0 +1,81 @@
+[net]
+width=256
+height=256
+channels=3
+
+[convolutional]
+batch_normalize=1
+filters=8
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[route]
+layers=-1
+groups=2
+group_id=1
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=2
+size=1
+stride=1
+pad=1
+activation=mish
+
+[convolutional]
+batch_normalize=1
+filters=4
+size=3
+stride=1
+pad=1
+activation=mish
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=14
+activation=linear
+
+[yolo]
+mask=2,3
+anchors=1,2, 3,4, 5,6, 9,10
+classes=2
+iou_loss=giou
+scale_x_y=1.05
+cls_normalizer=1.0
+iou_normalizer=0.07
+ignore_thresh=0.7
+
+[route]
+layers = -4
+
+[upsample]
+stride=2
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=14
+activation=linear
+
+[yolo]
+mask=0,1
+anchors=1,2, 3,4, 5,6, 9,10
+classes=2
+iou_loss=giou
+scale_x_y=1.05
+cls_normalizer=1.0
+iou_normalizer=0.07
+ignore_thresh=0.7
diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py
index 31ac35377f..b451ddfde7 100644
--- a/tests/models/test_detection.py
+++ b/tests/models/test_detection.py
@@ -1,8 +1,10 @@
+import warnings
 from pathlib import Path

 import pytest
 import torch
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.utils.data import DataLoader

 from pl_bolts.datasets import DummyDetectionDataset
@@ -79,8 +81,9 @@ def test_fasterrcnn_pyt_module_bbone_train(tmpdir):
     trainer.fit(model, train_dl, valid_dl)


-def test_yolo(tmpdir):
-    config_path = Path(TEST_ROOT) / "data" / "yolo.cfg"
+@pytest.mark.parametrize("config", [("yolo"), ("yolo_giou")])
+def test_yolo(config, catch_warnings):
+    config_path = Path(TEST_ROOT) / "data" / f"{config}.cfg"
     config = YOLOConfiguration(config_path)
     model = YOLO(config.get_network())
@@ -88,15 +91,28 @@ def test_yolo(tmpdir):
     model(image)


-def test_yolo_train(tmpdir):
-    config_path = Path(TEST_ROOT) / "data" / "yolo.cfg"
+@pytest.mark.parametrize(
+    "cfg_name",
+    [
+        ("yolo"),
+        ("yolo_giou"),
+    ],
+)
+def test_yolo_train(tmpdir, cfg_name, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    config_path = Path(TEST_ROOT) / "data" / f"{cfg_name}.cfg"
     config = YOLOConfiguration(config_path)
     model = YOLO(config.get_network())

     train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
     valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

-    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, logger=False, max_epochs=10, accelerator="auto")
     trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=valid_dl)
@@ -110,5 +126,11 @@ def test_yolo_train(tmpdir):
         )
     ],
 )
-def test_aligned_iou(dims1, dims2, expected_ious):
-    torch.testing.assert_allclose(_aligned_iou(dims1, dims2), expected_ious)
+def test_aligned_iou(dims1, dims2, expected_ious, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    torch.testing.assert_close(_aligned_iou(dims1, dims2), expected_ious)
diff --git a/tests/models/yolo/__init__.py b/tests/models/yolo/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/yolo/unit/__init__.py b/tests/models/yolo/unit/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/yolo/unit/test_yolo_config.py b/tests/models/yolo/unit/test_yolo_config.py
new file mode 100644
index 0000000000..807e2a84db
--- /dev/null
+++ b/tests/models/yolo/unit/test_yolo_config.py
@@ -0,0 +1,119 @@
+import warnings
+
+import pytest
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+
+from pl_bolts.models.detection.yolo.yolo_config import (
+    _create_convolutional,
+    _create_maxpool,
+    _create_shortcut,
+    _create_upsample,
+)
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        ({"batch_normalize": 1, "filters": 8, "size": 3, "stride": 1, "pad": 1, "activation": "leaky"}),
+        ({"batch_normalize": 0, "filters": 2, "size": 1, "stride": 1, "pad": 1, "activation": "mish"}),
+        ({"batch_normalize": 1, "filters": 6, "size": 3, "stride": 2, "pad": 1, "activation": "logistic"}),
+        ({"batch_normalize": 0, "filters": 4, "size": 3, "stride": 2, "pad": 0, "activation": "linear"}),
+    ],
+)
+def test_create_convolutional(config, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    conv, _ = _create_convolutional(config, [3])
+
+    assert conv.conv.out_channels == config["filters"]
+    assert conv.conv.kernel_size == (config["size"], config["size"])
+    assert conv.conv.stride == (config["stride"], config["stride"])
+
+    activation = config["activation"]
+    pad_size = (config["size"] - 1) // 2 if config["pad"] else 0
+
+    if config["pad"]:
+        assert conv.conv.padding == (pad_size, pad_size)
+
+    if config["batch_normalize"]:
+        assert len(conv) == 3
+
+    if activation != "linear":
+        if activation != "logistic":
+            assert activation == conv[-1].__class__.__name__.lower()[: len(activation)]
+        elif activation == "logistic":
+            assert "sigmoid" == conv[-1].__class__.__name__.lower()
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        (
+            {
+                "size": 2,
+                "stride": 2,
+            }
+        ),
+        (
+            {
+                "size": 6,
+                "stride": 3,
+            }
+        ),
+    ],
+)
+def test_create_maxpool(config, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    pad_size = (config["size"] - 1) // 2
+    maxpool, _ = _create_maxpool(config, [3])
+
+    assert maxpool.kernel_size == config["size"]
+    assert maxpool.stride == config["stride"]
+    assert maxpool.padding == pad_size
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        ({"from": 1, "activation": "linear"}),
+        ({"from": 3, "activation": "linear"}),
+    ],
+)
+def test_create_shortcut(config, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    shortcut, _ = _create_shortcut(config, [3])
+
+    assert shortcut.source_layer == config["from"]
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        ({"stride": 2}),
+        ({"stride": 4}),
+    ],
+)
+def test_create_upsample(config, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    upsample, _ = _create_upsample(config, [3])
+
+    assert upsample.scale_factor == float(config["stride"])
diff --git a/tests/models/yolo/unit/test_yolo_layers.py b/tests/models/yolo/unit/test_yolo_layers.py
new file mode 100644
index 0000000000..02e209d400
--- /dev/null
+++ b/tests/models/yolo/unit/test_yolo_layers.py
@@ -0,0 +1,51 @@
+import warnings
+
+import pytest
+import torch
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+
+from pl_bolts.models.detection.yolo.yolo_layers import GIoULoss, IoULoss, SELoss, _corner_coordinates
+
+
+@pytest.mark.parametrize(
+    "xy, wh, expected",
+    [
+        ([0.0, 0.0], [1.0, 1.0], [-0.5, -0.5, 0.5, 0.5]),
+        ([5.0, 5.0], [2.0, 2.0], [4.0, 4.0, 6.0, 6.0]),
+    ],
+)
+def test_corner_coordinates(xy, wh, expected, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    xy = torch.tensor(xy)
+    wh = torch.tensor(wh)
+    corners = _corner_coordinates(xy, wh)
+    assert torch.allclose(corners, torch.tensor(expected))
+
+
+@pytest.mark.parametrize(
+    "loss_func, bbox1, bbox2, expected",
+    [
+        (GIoULoss, [[0.0, 0.0, 120.0, 200.0]], [[189.0, 93.0, 242.0, 215.0]], 1.4144532680511475),
+        (IoULoss, [[0.0, 0.0, 120.0, 200.0]], [[189.0, 93.0, 242.0, 215.0]], 1.0),
+        (SELoss, [[0.0, 0.0, 120.0, 200.0]], [[189.0, 93.0, 242.0, 215.0]], 59479.0),
+    ],
+)
+def test_loss_functions(loss_func, bbox1, bbox2, expected, catch_warnings):
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
+
+    loss_func = loss_func()
+    tensor1 = torch.tensor(bbox1, dtype=torch.float32)
+    tensor2 = torch.tensor(bbox2, dtype=torch.float32)
+
+    loss = loss_func(tensor1, tensor2)
+    assert loss.item() > 0.0
+    assert loss.item() == expected
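
Side note for reviewers: a minimal sketch of how the new code paths can be exercised locally. This snippet is not part of the diff; it assumes the repository root as the working directory and torchvision installed, and the 256x256 dummy input simply mirrors the [net] block of the new test config.

import torch

from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration
from pl_bolts.models.detection.yolo.yolo_module import YOLO

# Parse the GIoU test config added by this PR and build the network it describes.
config = YOLOConfiguration("tests/data/yolo_giou.cfg")
model = YOLO(config.get_network())

# Forward pass on a dummy batch; 256x256 matches the [net] section of the config.
image = torch.rand(1, 3, 256, 256)
model(image)

# With the new "else" branch in _create_yolo, an unrecognized iou_loss value in a
# [yolo] section now raises ValueError at parse time instead of silently falling
# back to IoU loss.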