From b9b35c4be8fb0457f419a036605863f19047c086 Mon Sep 17 00:00:00 2001
From: Seppo Enarvi
Date: Fri, 10 Sep 2021 09:33:24 +0200
Subject: [PATCH] Add YOLO object detection model (#552)

* Add YOLO object detection model
* Reading Darknet weights also works with truncated files.
* Use torch.min() instead of torch.minimum() to avoid an error with older PyTorch versions.
* Generalized interface for custom losses
* IoU loss functions take image space coordinates as input.
* box_area() implementation copied from torchvision
* IoU losses use torchvision
* IoU losses take the diagonal of the torchvision IoU ops instead of implementing their own elementwise ops.
* YOLO written with all caps in class names
* Generic way to specify optimizer and LR scheduler
* Possible to limit the number of predictions per image
* No need to check for NaN values as Trainer has terminate_on_nan argument.
* YOLO test configuration moved to tests/data/yolo.cfg
* Synchronize validation and test step logging calls
* Log losses to progress bar
* Fixed documentation formatting
* Coordinate predictions are in image scale
* Use default dtype for torch.arange() to fix export to TensorRT
* Network input size can differ from the image size specified in the configuration
* Image size is given to detection layer forward() instead of the constructor to allow variable image sizes.
* Use default data type for torch.arange() to fix export to TensorRT.
* Loss is normalized by batch size only once
* Fixed division by zero when there are no targets in a batch
* Always return all losses to avoid deadlock with DDP when there are no targets
* Hit rates are always logged so don't prefix the names
* The vector of overlap losses was accidentally transformed to a square matrix.
* Some versions of Lightning don't work correctly when logging losses with sync_dist=True.
* Truncate nms() inputs to avoid it crashing when too many boxes are detected * Use sum() instead of count_nonzero() as it's available already before PyTorch 1.7 * Squared error loss takes the sum over the predicted attributes * Swish and logistic activation functions * VOCDetectionDataModule constructor takes batch size and the transforms take image and target * Use true_divide() for integer division * Fixed doc and package build without Torchvision Co-authored-by: Akihiro Nitta --- CHANGELOG.md | 1 + docs/source/index.rst | 1 + docs/source/object_detection.rst | 20 + .../datamodules/vocdetection_datamodule.py | 77 +-- pl_bolts/models/detection/__init__.py | 4 + .../faster_rcnn/faster_rcnn_module.py | 3 +- pl_bolts/models/detection/yolo/__init__.py | 0 pl_bolts/models/detection/yolo/yolo_config.py | 273 ++++++++ pl_bolts/models/detection/yolo/yolo_layers.py | 504 +++++++++++++++ pl_bolts/models/detection/yolo/yolo_module.py | 609 ++++++++++++++++++ tests/data/yolo.cfg | 79 +++ tests/models/test_detection.py | 42 +- 12 files changed, 1568 insertions(+), 45 deletions(-) create mode 100644 docs/source/object_detection.rst create mode 100644 pl_bolts/models/detection/yolo/__init__.py create mode 100644 pl_bolts/models/detection/yolo/yolo_config.py create mode 100644 pl_bolts/models/detection/yolo/yolo_layers.py create mode 100644 pl_bolts/models/detection/yolo/yolo_module.py create mode 100644 tests/data/yolo.cfg diff --git a/CHANGELOG.md b/CHANGELOG.md index a4ffecc4e9..44df6f2152 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552)) ### Changed diff --git a/docs/source/index.rst b/docs/source/index.rst index c0eeea8689..2fff84677c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -71,6 +71,7 @@ Lightning-Bolts documentation autoencoders convolutional + object_detection gans reinforce_learn self_supervised_models diff --git a/docs/source/object_detection.rst b/docs/source/object_detection.rst new file mode 100644 index 0000000000..37ace12a88 --- /dev/null +++ b/docs/source/object_detection.rst @@ -0,0 +1,20 @@ +Object Detection +================ +This package lists contributed object detection models. + +-------------- + + +Faster R-CNN +------------ + +.. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN + :noindex: + +------------- + +YOLO +---- + +.. 
autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO + :noindex: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 8f8edc66aa..15e3fc8a80 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import os +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from pytorch_lightning import LightningDataModule from torch import Tensor -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg @@ -107,10 +108,11 @@ class VOCDetectionDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str] = None, year: str = "2012", num_workers: int = 0, normalize: bool = False, + batch_size: int = 16, shuffle: bool = True, pin_memory: bool = True, drop_last: bool = False, @@ -125,9 +127,10 @@ def __init__( super().__init__(*args, **kwargs) self.year = year - self.data_dir = data_dir + self.data_dir = data_dir if data_dir is not None else os.getcwd() self.num_workers = num_workers self.normalize = normalize + self.batch_size = batch_size self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last @@ -145,60 +148,50 @@ def prepare_data(self) -> None: VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader( - self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None - ) -> DataLoader: + def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader: """VOCDetection train set uses the `train` subset. Args: - batch_size: size of batch - transforms: custom transforms + image_transforms: custom image-only transforms """ - transforms = [_prepare_voc_instance] - image_transforms = image_transforms or self.train_transforms or self._default_transforms() + transforms = [ + _prepare_voc_instance, + self.default_transforms() if self.train_transforms is None else self.train_transforms, + ] transforms = Compose(transforms, image_transforms) + dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms) - loader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - collate_fn=_collate_fn, - ) - return loader + return self._data_loader(dataset, shuffle=self.shuffle) - def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader: + def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader: """VOCDetection val set uses the `val` subset. 
Args: - batch_size: size of batch - transforms: custom transforms + image_transforms: custom image-only transforms """ - transforms = [_prepare_voc_instance] - image_transforms = image_transforms or self.train_transforms or self._default_transforms() + transforms = [ + _prepare_voc_instance, + self.default_transforms() if self.val_transforms is None else self.val_transforms, + ] transforms = Compose(transforms, image_transforms) + dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms) - loader = DataLoader( + return self._data_loader(dataset, shuffle=False) + + def default_transforms(self) -> Callable: + voc_transforms = [transform_lib.ToTensor()] + if self.normalize: + voc_transforms += [transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] + voc_transforms = transform_lib.Compose(voc_transforms) + return lambda image, target: (voc_transforms(image), target) + + def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: + return DataLoader( dataset, - batch_size=batch_size, - shuffle=False, + batch_size=self.batch_size, + shuffle=shuffle, num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, collate_fn=_collate_fn, ) - return loader - - def _default_transforms(self) -> Callable: - if self.normalize: - voc_transforms = transform_lib.Compose( - [ - transform_lib.ToTensor(), - transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - else: - voc_transforms = transform_lib.Compose([transform_lib.ToTensor()]) - return voc_transforms diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 3920378102..2d7a4a2d95 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,7 +1,11 @@ from pl_bolts.models.detection import components from pl_bolts.models.detection.faster_rcnn import FasterRCNN +from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration +from pl_bolts.models.detection.yolo.yolo_module import YOLO __all__ = [ "components", "FasterRCNN", + "YOLOConfiguration", + "YOLO", ] diff --git a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index 2773473ad7..bb1c4c1051 100644 --- a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -147,9 +147,8 @@ def run_cli(): seed_everything(42) parser = ArgumentParser() + parser = VOCDetectionDataModule.add_argparse_args(parser) parser = Trainer.add_argparse_args(parser) - parser.add_argument("--data_dir", type=str, default=".") - parser.add_argument("--batch_size", type=int, default=1) parser = FasterRCNN.add_model_specific_args(parser) args = parser.parse_args() diff --git a/pl_bolts/models/detection/yolo/__init__.py b/pl_bolts/models/detection/yolo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py new file mode 100644 index 0000000000..fea56df7e8 --- /dev/null +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -0,0 +1,273 @@ +import re +from typing import Any, Dict, Iterable, List, Tuple +from warnings import warn + +import torch.nn as nn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from pl_bolts.models.detection.yolo import yolo_layers + + +class YOLOConfiguration: + """This class can be used to parse the 
configuration files of the Darknet YOLOv4 implementation. + + The :func:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration.get_network` method + returns a PyTorch module list that can be used to construct a YOLO model. + """ + + def __init__(self, path: str) -> None: + """Saves the variables from the first configuration section to attributes of this object, and the rest of + the sections to the ``layer_configs`` list. + + Args: + path: Path to a configuration file + """ + with open(path) as config_file: + sections = self._read_file(config_file) + + if len(sections) < 2: + raise MisconfigurationException("The model configuration file should include at least two sections.") + + self.__dict__.update(sections[0]) + self.global_config = sections[0] + self.layer_configs = sections[1:] + + def get_network(self) -> nn.ModuleList: + """Iterates through the layers from the configuration and creates corresponding PyTorch modules. Returns + the network structure that can be used to create a YOLO model. + + Returns: + A :class:`~torch.nn.ModuleList` that defines the YOLO network. + """ + result = nn.ModuleList() + num_inputs = [3] # Number of channels in the input of every layer up to the current layer + for layer_config in self.layer_configs: + config = {**self.global_config, **layer_config} + module, num_outputs = _create_layer(config, num_inputs) + result.append(module) + num_inputs.append(num_outputs) + return result + + def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: + """Reads a YOLOv4 network configuration file and returns a list of configuration sections. + + Args: + config_file: The configuration file to read. + + Returns: + A list of configuration sections. + """ + section_re = re.compile(r"\[([^]]+)\]") + list_variables = ("layers", "anchors", "mask", "scales") + variable_types = { + "activation": str, + "anchors": int, + "angle": float, + "batch": int, + "batch_normalize": bool, + "beta_nms": float, + "burn_in": int, + "channels": int, + "classes": int, + "cls_normalizer": float, + "decay": float, + "exposure": float, + "filters": int, + "from": int, + "groups": int, + "group_id": int, + "height": int, + "hue": float, + "ignore_thresh": float, + "iou_loss": str, + "iou_normalizer": float, + "iou_thresh": float, + "jitter": float, + "layers": int, + "learning_rate": float, + "mask": int, + "max_batches": int, + "max_delta": float, + "momentum": float, + "mosaic": bool, + "new_coords": int, + "nms_kind": str, + "num": int, + "obj_normalizer": float, + "pad": bool, + "policy": str, + "random": bool, + "resize": float, + "saturation": float, + "scales": float, + "scale_x_y": float, + "size": int, + "steps": str, + "stride": int, + "subdivisions": int, + "truth_thresh": float, + "width": int, + } + + section = None + sections = [] + + def convert(key, value): + """Converts a value to the correct type based on key.""" + if key not in variable_types: + warn("Unknown YOLO configuration variable: " + key) + return key, value + if key in list_variables: + value = [variable_types[key](v) for v in value.split(",")] + else: + value = variable_types[key](value) + return key, value + + for line in config_file: + line = line.strip() + if (not line) or (line[0] == "#"): + continue + + section_match = section_re.match(line) + if section_match: + if section is not None: + sections.append(section) + section = {"type": section_match.group(1)} + else: + key, value = line.split("=") + key = key.rstrip() + value = value.lstrip() + key, value = convert(key, value) + 
section[key] = value + if section is not None: + sections.append(section) + + return sections + + +def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: + """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the + layer config. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the + number of channels in its output. + """ + create_func = { + "convolutional": _create_convolutional, + "maxpool": _create_maxpool, + "route": _create_route, + "shortcut": _create_shortcut, + "upsample": _create_upsample, + "yolo": _create_yolo, + } + return create_func[config["type"]](config, num_inputs) + + +def _create_convolutional(config, num_inputs): + module = nn.Sequential() + + batch_normalize = config.get("batch_normalize", False) + padding = (config["size"] - 1) // 2 if config["pad"] else 0 + + conv = nn.Conv2d( + num_inputs[-1], config["filters"], config["size"], config["stride"], padding, bias=not batch_normalize + ) + module.add_module("conv", conv) + + if batch_normalize: + bn = nn.BatchNorm2d(config["filters"]) + module.add_module("bn", bn) + + activation_name = config["activation"] + if activation_name == "leaky": + leakyrelu = nn.LeakyReLU(0.1, inplace=True) + module.add_module("leakyrelu", leakyrelu) + elif activation_name == "mish": + mish = yolo_layers.Mish() + module.add_module("mish", mish) + elif activation_name == "swish": + swish = nn.SiLU(inplace=True) + module.add_module("swish", swish) + elif activation_name == "logistic": + logistic = nn.Sigmoid() + module.add_module("logistic", logistic) + elif activation_name == "linear": + pass + else: + raise ValueError("Unknown activation: " + activation_name) + + return module, config["filters"] + + +def _create_maxpool(config, num_inputs): + padding = (config["size"] - 1) // 2 + module = nn.MaxPool2d(config["size"], config["stride"], padding) + return module, num_inputs[-1] + + +def _create_route(config, num_inputs): + num_chunks = config.get("groups", 1) + chunk_idx = config.get("group_id", 0) + + # 0 is the first layer, -1 is the previous layer + last = len(num_inputs) - 1 + source_layers = [layer if layer >= 0 else last + layer for layer in config["layers"]] + + module = yolo_layers.RouteLayer(source_layers, num_chunks, chunk_idx) + + # The number of outputs of a source layer is the number of inputs of the next layer. + num_outputs = sum(num_inputs[layer + 1] // num_chunks for layer in source_layers) + + return module, num_outputs + + +def _create_shortcut(config, num_inputs): + module = yolo_layers.ShortcutLayer(config["from"]) + return module, num_inputs[-1] + + +def _create_upsample(config, num_inputs): + module = nn.Upsample(scale_factor=config["stride"], mode="nearest") + return module, num_inputs[-1] + + +def _create_yolo(config, num_inputs): + # The "anchors" list alternates width and height. 
+ anchor_dims = config["anchors"] + anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)] + + xy_scale = config.get("scale_x_y", 1.0) + input_is_normalized = config.get("new_coords", 0) > 0 + ignore_threshold = config.get("ignore_thresh", 1.0) + overlap_loss_multiplier = config.get("iou_normalizer", 1.0) + class_loss_multiplier = config.get("cls_normalizer", 1.0) + confidence_loss_multiplier = config.get("obj_normalizer", 1.0) + + overlap_loss_name = config.get("iou_loss", "mse") + if overlap_loss_name == "mse": + overlap_loss_func = yolo_layers.SELoss() + elif overlap_loss_name == "giou": + overlap_loss_func = yolo_layers.GIoULoss() + else: + overlap_loss_func = yolo_layers.IoULoss() + + module = yolo_layers.DetectionLayer( + num_classes=config["classes"], + anchor_dims=anchor_dims, + anchor_ids=config["mask"], + xy_scale=xy_scale, + input_is_normalized=input_is_normalized, + ignore_threshold=ignore_threshold, + overlap_loss_func=overlap_loss_func, + image_space_loss=overlap_loss_name != "mse", + overlap_loss_multiplier=overlap_loss_multiplier, + class_loss_multiplier=class_loss_multiplier, + confidence_loss_multiplier=confidence_loss_multiplier, + ) + + return module, num_inputs[-1] diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py new file mode 100644 index 0000000000..9b1ee891df --- /dev/null +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -0,0 +1,504 @@ +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import Tensor, nn + +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.ops import box_iou + + try: + from torchvision.ops import generalized_box_iou + except ImportError: + _GIOU_AVAILABLE = False + else: + _GIOU_AVAILABLE = True +else: + warn_missing_pkg("torchvision") + + +def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: + """Converts box center points and sizes to corner coordinates. + + Args: + xy: Center coordinates. Tensor of size ``[..., 2]``. + wh: Width and height. Tensor of size ``[..., 2]``. + + Returns: + A matrix of `(x1, y1, x2, y2)` coordinates. + """ + half_wh = wh / 2 + top_left = xy - half_wh + bottom_right = xy + half_wh + return torch.cat((top_left, bottom_right), -1) + + +def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: + """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at + the same coordinates. + + Args: + dims1: Width and height of `N` boxes. Tensor of size ``[N, 2]``. + dims2: Width and height of `M` boxes. Tensor of size ``[M, 2]``. 
+ + Returns: + Tensor of size ``[N, M]`` containing the pairwise IoU values for every element in + ``dims1`` and ``dims2`` + """ + area1 = dims1[:, 0] * dims1[:, 1] # [N] + area2 = dims2[:, 0] * dims2[:, 1] # [M] + + inter_wh = torch.min(dims1[:, None, :], dims2) # [N, M, 2] + inter = inter_wh[:, :, 0] * inter_wh[:, :, 1] # [N, M] + union = area1[:, None] + area2 - inter # [N, M] + + return inter / union + + +class SELoss(nn.MSELoss): + def __init__(self): + super().__init__(reduction="none") + + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return super().forward(inputs, target).sum(1) + + +class IoULoss(nn.Module): + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return 1.0 - box_iou(inputs, target).diagonal() + + +class GIoULoss(nn.Module): + def __init__(self) -> None: + super().__init__() + if not _GIOU_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + "A more recent version of `torchvision` is needed for generalized IoU loss." + ) + + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return 1.0 - generalized_box_iou(inputs, target).diagonal() + + +class DetectionLayer(nn.Module): + """A YOLO detection layer. + + A YOLO model has usually 1 - 3 detection layers at different + resolutions. The loss should be summed from all of them. + """ + + def __init__( + self, + num_classes: int, + anchor_dims: List[Tuple[int, int]], + anchor_ids: List[int], + xy_scale: float = 1.0, + input_is_normalized: bool = False, + ignore_threshold: float = 0.5, + overlap_loss_func: Optional[Callable] = None, + class_loss_func: Optional[Callable] = None, + confidence_loss_func: Optional[Callable] = None, + image_space_loss: bool = False, + overlap_loss_multiplier: float = 1.0, + class_loss_multiplier: float = 1.0, + confidence_loss_multiplier: float = 1.0, + ) -> None: + """ + Args: + num_classes: Number of different classes that this layer predicts. + anchor_dims: A list of all the predefined anchor box dimensions. The list should + contain (width, height) tuples in the network input resolution (relative to the + width and height defined in the configuration file). + anchor_ids: List of indices to ``anchor_dims`` that is used to select the (usually 3) + anchors that this layer uses. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. + Using a value > 1.0 helps to produce coordinate values close to one. + input_is_normalized: The input is normalized by logistic activation in the previous + layer. In this case the detection layer will not take the sigmoid of the coordinate + and probability predictions, and the width and height are scaled up so that the + maximum value is four times the anchor dimension + ignore_threshold: If a predictor is not responsible for predicting any target, but the + corresponding anchor has IoU with some target greater than this threshold, the + predictor will not be taken into account when calculating the confidence loss. + overlap_loss_func: Loss function for bounding box coordinates. Default is the sum of + squared errors. + class_loss_func: Loss function for class probability distribution. Default is the sum + of squared errors. + confidence_loss_func: Loss function for confidence score. Default is the sum of squared + errors. + image_space_loss: If set to ``True``, the overlap loss function will receive the + bounding box `(x1, y1, x2, y2)` coordinates, scaled to the input image size. This is + needed for the IoU losses introduced in YOLOv4. 
Otherwise the loss will be computed + from the x, y, width, and height values, as predicted by the network (i.e. relative + to the anchor box, and width and height are logarithmic). + overlap_loss_multiplier: Multiply the overlap loss by this factor. + class_loss_multiplier: Multiply the classification loss by this factor. + confidence_loss_multiplier: Multiply the confidence loss by this factor. + """ + super().__init__() + + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError("YOLO model uses `torchvision`, which is not installed yet.") + + self.num_classes = num_classes + self.all_anchor_dims = anchor_dims + self.anchor_dims = [anchor_dims[i] for i in anchor_ids] + self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(len(anchor_dims))] + self.xy_scale = xy_scale + self.input_is_normalized = input_is_normalized + self.ignore_threshold = ignore_threshold + + self.overlap_loss_func = overlap_loss_func or SELoss() + self.class_loss_func = class_loss_func or SELoss() + self.confidence_loss_func = confidence_loss_func or nn.MSELoss(reduction="none") + self.image_space_loss = image_space_loss + self.overlap_loss_multiplier = overlap_loss_multiplier + self.class_loss_multiplier = class_loss_multiplier + self.confidence_loss_multiplier = confidence_loss_multiplier + + def forward( + self, x: Tensor, image_size: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Runs a forward pass through this YOLO detection layer. + + Maps cell-local coordinates to global coordinates in the image space, scales the bounding + boxes with the anchors, converts the center coordinates to corner coordinates, and maps + probabilities to the `]0, 1[` range using sigmoid. + + If targets are given, computes also losses from the predictions and the targets. This layer + is responsible only for the targets that best match one of the anchors assigned to this + layer. + + Args: + x: The output from the previous layer. Tensor of size + ``[batch_size, boxes_per_cell * (num_classes + 5), height, width]``. + image_size: Image width and height in a vector (defines the scale of the predicted and + target coordinates). + targets: If set, computes losses from detection layers against these targets. A list of + dictionaries, one for each image. + + Returns: + output (Tensor), losses (Dict[str, Tensor]), hits (int): Layer output tensor, sized + ``[batch_size, num_anchors * height * width, num_classes + 5]``. If training targets + were provided, also returns a dictionary of losses and the number of targets that this + layer was responsible for. + """ + batch_size, num_features, height, width = x.shape + num_attrs = self.num_classes + 5 + boxes_per_cell = num_features // num_attrs + if boxes_per_cell != len(self.anchor_dims): + raise MisconfigurationException( + "The model predicts {} bounding boxes per cell, but {} anchor boxes are defined " + "for this layer.".format(boxes_per_cell, len(self.anchor_dims)) + ) + + # Reshape the output to have the bounding box attributes of each grid cell on its own row. + x = x.permute(0, 2, 3, 1) # [batch_size, height, width, boxes_per_cell * num_attrs] + x = x.view(batch_size, height, width, boxes_per_cell, num_attrs) + + # Take the sigmoid of the bounding box coordinates, confidence score, and class + # probabilities, unless the input is normalized by the previous layer activation. 
+ if self.input_is_normalized: + xy = x[..., :2] + confidence = x[..., 4] + classprob = x[..., 5:] + else: + xy = torch.sigmoid(x[..., :2]) + confidence = torch.sigmoid(x[..., 4]) + classprob = torch.sigmoid(x[..., 5:]) + wh = x[..., 2:4] + + # Eliminate grid sensitivity. The previous layer should output extremely high values for + # the sigmoid to produce x/y coordinates close to one. YOLOv4 solves this by scaling the + # x/y coordinates. + xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) + + image_xy = self._global_xy(xy, image_size) + if self.input_is_normalized: + image_wh = 4 * torch.square(wh) * torch.tensor(self.anchor_dims, dtype=wh.dtype, device=wh.device) + else: + image_wh = torch.exp(wh) * torch.tensor(self.anchor_dims, dtype=wh.dtype, device=wh.device) + boxes = _corner_coordinates(image_xy, image_wh) + output = torch.cat((boxes, confidence.unsqueeze(-1), classprob), -1) + output = output.reshape(batch_size, height * width * boxes_per_cell, num_attrs) + + if targets is None: + return output + + lc_mask = self._low_confidence_mask(boxes, targets) + if not self.image_space_loss: + boxes = torch.cat((xy, wh), -1) + losses, hits = self._calculate_losses(boxes, confidence, classprob, targets, image_size, lc_mask) + return output, losses, hits + + def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: + """Adds offsets to the predicted box center coordinates to obtain global coordinates to the image. + + The predicted coordinates are interpreted as coordinates inside a grid cell whose width and + height is 1. Adding offset to the cell, dividing by the grid size, and multiplying by the + image size, we get global coordinates in the image scale. + + Args: + xy: The predicted center coordinates before scaling. Values from zero to one in a + tensor sized ``[batch_size, height, width, boxes_per_cell, 2]``. + image_size: Width and height in a vector that will be used to scale the coordinates. + + Returns: + Global coordinates scaled to the size of the network input image, in a tensor with the + same shape as the input tensor. + """ + height = xy.shape[1] + width = xy.shape[2] + grid_size = torch.tensor([width, height], device=xy.device) + + x_range = torch.arange(width, device=xy.device) + y_range = torch.arange(height, device=xy.device) + grid_y, grid_x = torch.meshgrid(y_range, x_range) + offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] + offset = offset.unsqueeze(2) # [height, width, 1, 2] + + scale = torch.true_divide(image_size, grid_size) + return (xy + offset) * scale + + def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) -> Tensor: + """Initializes the mask that will be used to select predictors that are not predicting any ground-truth + target. The value will be ``True``, unless the predicted box overlaps any target significantly (IoU greater + than ``self.ignore_threshold``). + + Args: + boxes: The predicted corner coordinates in the image space. Tensor of size + ``[batch_size, height, width, boxes_per_cell, 4]``. + targets: List of dictionaries of ground-truth targets, one dictionary per image. + + Returns: + A boolean tensor shaped ``[batch_size, height, width, boxes_per_cell]`` with ``False`` + where the predicted box overlaps a target significantly and ``True`` elsewhere. 
+ """ + batch_size, height, width, boxes_per_cell, num_coords = boxes.shape + num_preds = height * width * boxes_per_cell + boxes = boxes.view(batch_size, num_preds, num_coords) + + results = torch.ones((batch_size, num_preds), dtype=torch.bool, device=boxes.device) + for image_idx, (image_boxes, image_targets) in enumerate(zip(boxes, targets)): + target_boxes = image_targets["boxes"] + if target_boxes.shape[0] > 0: + ious = box_iou(image_boxes, target_boxes) # [num_preds, num_targets] + best_iou = ious.max(-1).values # [num_preds] + results[image_idx] = best_iou <= self.ignore_threshold + + return results.view((batch_size, height, width, boxes_per_cell)) + + def _calculate_losses( + self, + boxes: Tensor, + confidence: Tensor, + classprob: Tensor, + targets: List[Dict[str, Tensor]], + image_size: Tensor, + lc_mask: Tensor, + ) -> Dict[str, Tensor]: + """From the targets that are in the image space calculates the actual targets for the network predictions, + and returns a dictionary of training losses. + + Args: + boxes: The predicted bounding boxes. A tensor sized + ``[batch_size, height, width, boxes_per_cell, 4]``. + confidence: The confidence predictions, normalized to `[0, 1]`. A tensor sized + ``[batch_size, height, width, boxes_per_cell]``. + classprob: The class probability predictions, normalized to `[0, 1]`. A tensor sized + ``[batch_size, height, width, boxes_per_cell, num_classes]``. + targets: List of dictionaries of target values, one dictionary for each image. + image_size: Width and height in a vector that defines the scale of the target + coordinates. + lc_mask: A boolean mask containing ``True`` where the predicted box does not overlap + any target significantly. + + Returns: + losses (Dict[str, Tensor]), hits (int): A dictionary of training losses and the number + of targets that this layer was responsible for. + """ + batch_size, height, width, boxes_per_cell, _ = boxes.shape + device = boxes.device + assert batch_size == len(targets) + + # A multiplier for scaling image coordinates to feature map coordinates + grid_size = torch.tensor([width, height], device=device) + image_to_grid = torch.true_divide(grid_size, image_size) + + anchor_wh = torch.tensor(self.all_anchor_dims, dtype=boxes.dtype, device=device) + anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) + + # List of predicted and target values for the predictors that are responsible for + # predicting a target. + target_xy = [] + target_wh = [] + target_label = [] + size_compensation = [] + pred_boxes = [] + pred_classprob = [] + pred_confidence = [] + hits = 0 + + for image_idx, image_targets in enumerate(targets): + target_boxes = image_targets["boxes"] + if target_boxes.shape[0] < 1: + continue + + # Bounding box corner coordinates are converted to center coordinates, width, and + # height. + wh = target_boxes[:, 2:4] - target_boxes[:, 0:2] + xy = target_boxes[:, 0:2] + (wh / 2) + + # The center coordinates are converted to the feature map dimensions so that the whole + # number tells the cell index and the fractional part tells the location inside the cell. + grid_xy = xy * image_to_grid + cell_i = grid_xy[:, 0].to(torch.int64).clamp(0, width - 1) + cell_j = grid_xy[:, 1].to(torch.int64).clamp(0, height - 1) + + # We want to know which anchor box overlaps a ground truth box more than any other + # anchor box. We know that the anchor box is located in the same grid cell as the + # ground truth box. 
For each prior shape (width, height), we calculate the IoU with + # all ground truth boxes, assuming the boxes are at the same location. Then for each + # target, we select the prior shape that gives the highest IoU. + ious = _aligned_iou(wh, anchor_wh) + best_anchors = ious.max(1).indices + + # ``anchor_map`` maps the anchor indices to the predictors in this layer, or to -1 if + # it's not an anchor of this layer. We ignore the predictions if the best anchor is in + # another layer. + predictors = anchor_map[best_anchors] + selected = predictors >= 0 + cell_i = cell_i[selected] + cell_j = cell_j[selected] + predictors = predictors[selected] + wh = wh[selected] + # sum() is equivalent to count_nonzero() and available before PyTorch 1.7. + hits += selected.sum() + + # The "low-confidence" mask is used to select predictors that are not responsible for + # predicting any object, for calculating the part of the confidence loss with zero as + # the target confidence. + lc_mask[image_idx, cell_j, cell_i, predictors] = False + + # IoU losses are calculated from the image space coordinates. The squared-error loss is + # calculated from the raw predicted values. + if self.image_space_loss: + xy = xy[selected] + target_xy.append(xy) + target_wh.append(wh) + else: + grid_xy = grid_xy[selected] + best_anchors = best_anchors[selected] + relative_xy = grid_xy - grid_xy.floor() + if self.input_is_normalized: + relative_wh = torch.sqrt(wh / (4 * anchor_wh[best_anchors] + 1e-16)) + else: + relative_wh = torch.log(wh / anchor_wh[best_anchors] + 1e-16) + target_xy.append(relative_xy) + target_wh.append(relative_wh) + + # Size compensation factor for bounding box overlap loss is calculated from unit width + # and height. + unit_wh = wh / image_size + size_compensation.append(2 - (unit_wh[:, 0] * unit_wh[:, 1])) + + # The data may contain a different number of classes than this detection layer. In case + # a label is greater than the number of classes that this layer predicts, it will be + # mapped to the last class. 
+ labels = image_targets["labels"] + labels = labels[selected] + labels = torch.min(labels, torch.tensor(self.num_classes - 1, device=device)) + target_label.append(labels) + + pred_boxes.append(boxes[image_idx, cell_j, cell_i, predictors]) + pred_classprob.append(classprob[image_idx, cell_j, cell_i, predictors]) + pred_confidence.append(confidence[image_idx, cell_j, cell_i, predictors]) + + losses = dict() + + if pred_boxes and target_xy and target_wh: + size_compensation = torch.cat(size_compensation) + pred_boxes = torch.cat(pred_boxes) + if self.image_space_loss: + target_boxes = _corner_coordinates(torch.cat(target_xy), torch.cat(target_wh)) + else: + target_boxes = torch.cat((torch.cat(target_xy), torch.cat(target_wh)), -1) + overlap_loss = self.overlap_loss_func(pred_boxes, target_boxes) + overlap_loss = overlap_loss * size_compensation + overlap_loss = overlap_loss.sum() / batch_size + losses["overlap"] = overlap_loss * self.overlap_loss_multiplier + else: + losses["overlap"] = torch.tensor(0.0, device=device) + + if pred_classprob and target_label: + pred_classprob = torch.cat(pred_classprob) + target_label = torch.cat(target_label) + target_classprob = torch.nn.functional.one_hot(target_label, self.num_classes) + target_classprob = target_classprob.to(dtype=pred_classprob.dtype) + class_loss = self.class_loss_func(pred_classprob, target_classprob) + class_loss = class_loss.sum() / batch_size + losses["class"] = class_loss * self.class_loss_multiplier + else: + losses["class"] = torch.tensor(0.0, device=device) + + pred_low_confidence = confidence[lc_mask] + target_low_confidence = torch.zeros_like(pred_low_confidence) + if pred_confidence: + pred_high_confidence = torch.cat(pred_confidence) + target_high_confidence = torch.ones_like(pred_high_confidence) + pred_confidence = torch.cat((pred_low_confidence, pred_high_confidence)) + target_confidence = torch.cat((target_low_confidence, target_high_confidence)) + else: + pred_confidence = pred_low_confidence + target_confidence = target_low_confidence + confidence_loss = self.confidence_loss_func(pred_confidence, target_confidence) + confidence_loss = confidence_loss.sum() / batch_size + losses["confidence"] = confidence_loss * self.confidence_loss_multiplier + + return losses, hits + + +class Mish(nn.Module): + """Mish activation.""" + + def forward(self, x): + return x * torch.tanh(nn.functional.softplus(x)) + + +class RouteLayer(nn.Module): + """Route layer concatenates the output (or part of it) from given layers.""" + + def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None: + """ + Args: + source_layers: Indices of the layers whose output will be concatenated. + num_chunks: Layer outputs will be split into this number of chunks. + chunk_idx: Only the chunks with this index will be concatenated. + """ + super().__init__() + self.source_layers = source_layers + self.num_chunks = num_chunks + self.chunk_idx = chunk_idx + + def forward(self, x, outputs): + chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers] + return torch.cat(chunks, dim=1) + + +class ShortcutLayer(nn.Module): + """Shortcut layer adds a residual connection from the source layer.""" + + def __init__(self, source_layer: int) -> None: + """ + Args: + source_layer: Index of the layer whose output will be added to the output of the + previous layer. 
+ """ + super().__init__() + self.source_layer = source_layer + + def forward(self, x, outputs): + return outputs[-1] + outputs[self.source_layer] diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py new file mode 100644 index 0000000000..ebb494f5ef --- /dev/null +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -0,0 +1,609 @@ +import logging +from typing import Any, Dict, List, Optional, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_info +from torch import Tensor, optim + +from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.ops import nms + from torchvision.transforms import functional as F +else: + warn_missing_pkg("torchvision") + +log = logging.getLogger(__name__) + + +class YOLO(LightningModule): + """PyTorch Lightning implementation of YOLOv3 and YOLOv4. + + *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `_ + + *YOLOv4 paper*: `Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao `_ + + *Implementation*: `Seppo Enarvi `_ + + The network architecture can be read from a Darknet configuration file using the + :class:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration` class, or created by + some other means, and provided as a list of PyTorch modules. + + The input from the data loader is expected to be a list of images. Each image is a tensor with + shape ``[channels, height, width]``. The images from a single batch will be stacked into a + single tensor, so the sizes have to match. Different batches can have different image sizes, as + long as the size is divisible by the ratio in which the network downsamples the input. + + During training, the model expects both the input tensors and a list of targets. *Each target is + a dictionary containing*: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + :func:`~pl_bolts.models.detection.yolo.yolo_module.YOLO.forward` method returns all + predictions from all detection layers in all images in one tensor with shape + ``[images, predictors, classes + 5]``. The coordinates are scaled to the input image size. + During training it also returns a dictionary containing the classification, box overlap, and + confidence losses. + + During inference, the model requires only the input tensors. + :func:`~pl_bolts.models.detection.yolo.yolo_module.YOLO.infer` method filters and processes the + predictions. *The processed output includes the following tensors*: + + - boxes (``FloatTensor[N, 4]``): predicted bounding box `(x1, y1, x2, y2)` coordinates in image space + - scores (``FloatTensor[N]``): detection confidences + - labels (``Int64Tensor[N]``): the predicted labels for each image + + Weights can be loaded from a Darknet model file using ``load_darknet_weights()``. + + CLI command:: + + # PascalVOC + wget https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg + python yolo_module.py --config yolov4-tiny-3l.cfg --data_dir . 
--gpus 8 --batch_size 8 + """ + + def __init__( + self, + network: nn.ModuleList, + optimizer: Type[optim.Optimizer] = optim.SGD, + optimizer_params: Dict[str, Any] = {"lr": 0.001, "momentum": 0.9, "weight_decay": 0.0005}, + lr_scheduler: Type[optim.lr_scheduler._LRScheduler] = LinearWarmupCosineAnnealingLR, + lr_scheduler_params: Dict[str, Any] = {"warmup_epochs": 1, "max_epochs": 300, "warmup_start_lr": 0.0}, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + max_predictions_per_image: int = -1, + ) -> None: + """ + Args: + network: A list of network modules. This can be obtained from a Darknet configuration + using the :func:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration.get_network` + method. + optimizer: Which optimizer class to use for training. + optimizer_params: Parameters to pass to the optimizer constructor. + lr_scheduler: Which learning rate scheduler class to use for training. + lr_scheduler_params: Parameters to pass to the learning rate scheduler constructor. + confidence_threshold: Postprocessing will remove bounding boxes whose + confidence score is not higher than this threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher + confidence box is higher than this threshold, if the predicted categories are equal. + max_predictions_per_image: If non-negative, keep at most this number of + highest-confidence predictions per image. + """ + super().__init__() + + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError("YOLO model uses `torchvision`, which is not installed yet.") # pragma: no-cover + + self.network = network + self.optimizer_class = optimizer + self.optimizer_params = optimizer_params + self.lr_scheduler_class = lr_scheduler + self.lr_scheduler_params = lr_scheduler_params + self.confidence_threshold = confidence_threshold + self.nms_threshold = nms_threshold + self.max_predictions_per_image = max_predictions_per_image + + def forward( + self, images: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets + are provided, computes the losses from the detection layers. + + Detections are concatenated from the detection layers. Each image will produce + `N * num_anchors * grid_height * grid_width` detections, where `N` depends on the number of + detection layers. For one detection layer `N = 1`, and each detection layer increases it by + a number that depends on the size of the feature map on that layer. For example, if the + feature map is twice as wide and high as the grid, the layer will add four times more + features. + + Args: + images: Images to be processed. Tensor of size + ``[batch_size, num_channels, height, width]``. + targets: If set, computes losses from detection layers against these targets. A list of + dictionaries, one for each image. + + Returns: + detections (:class:`~torch.Tensor`), losses (Dict[str, :class:`~torch.Tensor`]): + Detections, and if targets were provided, a dictionary of losses. Detections are shaped + ``[batch_size, num_predictors, num_classes + 5]``, where ``num_predictors`` is the + total number of cells in all detection layers times the number of boxes predicted by + one cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to + the input image size. 
+ """ + outputs = [] # Outputs from all layers + detections = [] # Outputs from detection layers + losses = [] # Losses from detection layers + hits = [] # Number of targets each detection layer was responsible for + + image_height = images.shape[2] + image_width = images.shape[3] + image_size = torch.tensor([image_width, image_height], device=images.device) + + x = images + for module in self.network: + if isinstance(module, (RouteLayer, ShortcutLayer)): + x = module(x, outputs) + elif isinstance(module, DetectionLayer): + if targets is None: + x = module(x, image_size) + detections.append(x) + else: + x, layer_losses, layer_hits = module(x, image_size, targets) + detections.append(x) + losses.append(layer_losses) + hits.append(layer_hits) + else: + x = module(x) + + outputs.append(x) + + detections = torch.cat(detections, 1) + if targets is None: + return detections + + total_hits = sum(hits) + num_targets = sum(len(image_targets["boxes"]) for image_targets in targets) + if total_hits != num_targets: + log.warning( + f"{num_targets} training targets were matched a total of {total_hits} times by detection layers. " + "Anchors may have been configured incorrectly." + ) + for layer_idx, layer_hits in enumerate(hits): + hit_rate = torch.true_divide(layer_hits, total_hits) if total_hits > 0 else 1.0 + self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False) + + def total_loss(loss_name): + """Returns the sum of the loss over detection layers.""" + loss_tuple = tuple(layer_losses[loss_name] for layer_losses in losses) + return torch.stack(loss_tuple).sum() + + losses = {loss_name: total_loss(loss_name) for loss_name in losses[0].keys()} + return detections, losses + + def configure_optimizers(self) -> Tuple[List, List]: + """Constructs the optimizer and learning rate scheduler.""" + optimizer = self.optimizer_class(self.parameters(), **self.optimizer_params) + lr_scheduler = self.lr_scheduler_class(optimizer, **self.lr_scheduler_params) + return [optimizer], [lr_scheduler] + + def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: + """Computes the training loss. + + Args: + batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. + Targets is a list of dictionaries that contain ground-truth boxes, labels, etc. + batch_idx: The index of this batch. + + Returns: + A dictionary that includes the training loss in 'loss'. + """ + images, targets = self._validate_batch(batch) + _, losses = self(images, targets) + total_loss = torch.stack(tuple(losses.values())).sum() + + # sync_dist=True is broken in some versions of Lightning and may cause the sum of the loss + # across GPUs to be returned. + for name, value in losses.items(): + self.log(f"train/{name}_loss", value, prog_bar=True, sync_dist=False) + self.log("train/total_loss", total_loss, sync_dist=False) + + return {"loss": total_loss} + + def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int): + """Evaluates a batch of data from the validation set. + + Args: + batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. + Targets is a list of dictionaries that contain ground-truth boxes, labels, etc. 
+ batch_idx: The index of this batch + """ + images, targets = self._validate_batch(batch) + detections, losses = self(images, targets) + detections = self._split_detections(detections) + detections = self._filter_detections(detections) + total_loss = torch.stack(tuple(losses.values())).sum() + + for name, value in losses.items(): + self.log(f"val/{name}_loss", value, sync_dist=True) + self.log("val/total_loss", total_loss, sync_dist=True) + + def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int): + """Evaluates a batch of data from the test set. + + Args: + batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. + Targets is a list of dictionaries that contain ground-truth boxes, labels, etc. + batch_idx: The index of this batch. + """ + images, targets = self._validate_batch(batch) + detections, losses = self(images, targets) + detections = self._split_detections(detections) + detections = self._filter_detections(detections) + total_loss = torch.stack(tuple(losses.values())).sum() + + for name, value in losses.items(): + self.log(f"test/{name}_loss", value, sync_dist=True) + self.log("test/total_loss", total_loss, sync_dist=True) + + def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class + labels. + + Args: + image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. + + Returns: + boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), labels (:class:`~torch.Tensor`): + A matrix of detected bounding box `(x1, y1, x2, y2)` coordinates, a vector of + confidences for the bounding box detections, and a vector of predicted class labels. + """ + if not isinstance(image, torch.Tensor): + image = F.to_tensor(image) + + self.eval() + detections = self(image.unsqueeze(0)) + detections = self._split_detections(detections) + detections = self._filter_detections(detections) + boxes = detections["boxes"][0] + scores = detections["scores"][0] + labels = detections["labels"][0] + return boxes, scores, labels + + def load_darknet_weights(self, weight_file): + """Loads weights to layer modules from a pretrained Darknet model. + + One may want to continue training from the pretrained weights, on a dataset with a + different number of object categories. The number of kernels in the convolutional layers + just before each detection layer depends on the number of output classes. The Darknet + solution is to truncate the weight file and stop reading weights at the first incompatible + layer. For this reason the function silently leaves the rest of the layers unchanged, when + the weight file ends. + + Args: + weight_file: A file object containing model weights in the Darknet binary format. + """ + version = np.fromfile(weight_file, count=3, dtype=np.int32) + images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) + rank_zero_info( + f"Loading weights from Darknet model version {version[0]}.{version[1]}.{version[2]} " + f"that has been trained on {images_seen[0]} images." + ) + + def read(tensor): + """Reads the contents of ``tensor`` from the current position of ``weight_file``. + + If there's no more data in ``weight_file``, returns without error. 
+ """ + x = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) + if x.shape[0] == 0: + return + x = torch.from_numpy(x).view_as(tensor) + with torch.no_grad(): + tensor.copy_(x) + + for module in self.network: + # Weights are loaded only to convolutional layers + if not isinstance(module, nn.Sequential): + continue + + conv = module[0] + assert isinstance(conv, nn.Conv2d) + + # Convolution may be followed by batch normalization, in which case we read the batch + # normalization parameters and not the convolution bias. + if len(module) > 1 and isinstance(module[1], nn.BatchNorm2d): + bn = module[1] + read(bn.bias) + read(bn.weight) + read(bn.running_mean) + read(bn.running_var) + else: + read(conv.bias) + + read(conv.weight) + + def _validate_batch( + self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] + ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: + """Reads a batch of data, validates the format, and stacks the images into a single tensor. + + Args: + batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. + + Returns: + The input batch with images stacked into a single tensor. + """ + images, targets = batch + + if len(images) != len(targets): + raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") + + for image in images: + if not isinstance(image, Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image)}.") + + for target in targets: + boxes = target["boxes"] + if not isinstance(boxes, Tensor): + raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.") + if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): + raise ValueError(f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}.") + labels = target["labels"] + if not isinstance(labels, Tensor): + raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels)}.") + if len(labels.shape) != 1: + raise ValueError(f"Expected target labels to be tensors of shape [N], got {list(labels.shape)}.") + + images = torch.stack(images) + return images, targets + + def _split_detections(self, detections: Tensor) -> Dict[str, Tensor]: + """Splits the detection tensor returned by a forward pass into a dictionary. + + The fields of the dictionary are as follows: + - boxes (``Tensor[batch_size, N, 4]``): detected bounding box `(x1, y1, x2, y2)` coordinates + - scores (``Tensor[batch_size, N]``): detection confidences + - classprobs (``Tensor[batch_size, N]``): probabilities of the best classes + - labels (``Int64Tensor[batch_size, N]``): the predicted labels for each image + + Args: + detections: A tensor of detected bounding boxes and their attributes. + + Returns: + A dictionary of detection results. + """ + boxes = detections[..., :4] + scores = detections[..., 4] + classprobs = detections[..., 5:] + classprobs, labels = torch.max(classprobs, -1) + return {"boxes": boxes, "scores": scores, "classprobs": classprobs, "labels": labels} + + def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Tensor]]: + """Filters detections based on confidence threshold. Then for every class performs non-maximum suppression + (NMS). NMS iterates the bounding boxes that predict this class in descending order of confidence score, and + removes lower scoring boxes that have an IoU greater than the NMS threshold with a higher scoring box. + Finally the detections are sorted by descending confidence and possible truncated to the maximum number of + predictions. 
+ + Args: + detections: All detections. A dictionary of tensors, each containing the predictions + from all images. + + Returns: + Filtered detections. A dictionary of lists, each containing a tensor per image. + """ + boxes = detections["boxes"] + scores = detections["scores"] + classprobs = detections["classprobs"] + labels = detections["labels"] + + out_boxes = [] + out_scores = [] + out_classprobs = [] + out_labels = [] + + for img_boxes, img_scores, img_classprobs, img_labels in zip(boxes, scores, classprobs, labels): + # Select detections with high confidence score. + selected = img_scores > self.confidence_threshold + img_boxes = img_boxes[selected] + img_scores = img_scores[selected] + img_classprobs = img_classprobs[selected] + img_labels = img_labels[selected] + + img_out_boxes = boxes.new_zeros((0, 4)) + img_out_scores = scores.new_zeros(0) + img_out_classprobs = classprobs.new_zeros(0) + img_out_labels = labels.new_zeros(0) + + # Iterate through the unique object classes detected in the image and perform non-maximum + # suppression for the objects of the class in question. + for cls_label in labels.unique(): + selected = img_labels == cls_label + cls_boxes = img_boxes[selected] + cls_scores = img_scores[selected] + cls_classprobs = img_classprobs[selected] + cls_labels = img_labels[selected] + + # NMS will crash if there are too many boxes. + cls_boxes = cls_boxes[:100000] + cls_scores = cls_scores[:100000] + selected = nms(cls_boxes, cls_scores, self.nms_threshold) + + img_out_boxes = torch.cat((img_out_boxes, cls_boxes[selected])) + img_out_scores = torch.cat((img_out_scores, cls_scores[selected])) + img_out_classprobs = torch.cat((img_out_classprobs, cls_classprobs[selected])) + img_out_labels = torch.cat((img_out_labels, cls_labels[selected])) + + # Sort by descending confidence and limit the maximum number of predictions. + indices = torch.argsort(img_out_scores, descending=True) + if self.max_predictions_per_image >= 0: + indices = indices[: self.max_predictions_per_image] + out_boxes.append(img_out_boxes[indices]) + out_scores.append(img_out_scores[indices]) + out_classprobs.append(img_out_classprobs[indices]) + out_labels.append(img_out_labels[indices]) + + return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels} + + +class Resize: + """Rescales the image and target to given dimensions. + + Args: + output_size (tuple or int): Desired output size. If tuple (height, width), the output is + matched to ``output_size``. If int, the smaller of the image edges is matched to + ``output_size``, keeping the aspect ratio the same. + """ + + def __init__(self, output_size: tuple) -> None: + self.output_size = output_size + + def __call__(self, image: Tensor, target: Dict[str, Any]): + """ + Args: + tensor: Tensor image to be resized. + target: Dictionary of detection targets. + + Returns: + Resized Tensor image. 
+ """ + height, width = image.shape[-2:] + original_size = torch.tensor([height, width]) + scale_y, scale_x = torch.tensor(self.output_size) / original_size + scale = torch.tensor([scale_x, scale_y, scale_x, scale_y], device=target["boxes"].device) + image = F.resize(image, self.output_size) + target["boxes"] = target["boxes"] * scale + return image, target + + +def run_cli(): + from argparse import ArgumentParser + + from pytorch_lightning import Trainer, seed_everything + + from pl_bolts.datamodules import VOCDetectionDataModule + from pl_bolts.datamodules.vocdetection_datamodule import Compose + from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration + + seed_everything(42) + + parser = ArgumentParser() + parser.add_argument( + "--config", + type=str, + metavar="PATH", + required=True, + help="read model configuration from PATH", + ) + parser.add_argument( + "--darknet-weights", + type=str, + metavar="PATH", + help="read the initial model weights from PATH in Darknet format", + ) + parser.add_argument( + "--lr", + type=float, + metavar="LR", + default=0.0013, + help="learning rate after the warmup period", + ) + parser.add_argument( + "--momentum", + type=float, + metavar="GAMMA", + default=0.9, + help="if nonzero, the optimizer uses momentum with factor GAMMA", + ) + parser.add_argument( + "--weight-decay", + type=float, + metavar="LAMBDA", + default=0.0005, + help="if nonzero, the optimizer uses weight decay (L2 penalty) with factor LAMBDA", + ) + parser.add_argument( + "--warmup-epochs", + type=int, + metavar="N", + default=1, + help="learning rate warmup period is N epochs", + ) + parser.add_argument( + "--max-epochs", + type=int, + metavar="N", + default=300, + help="train at most N epochs", + ) + parser.add_argument( + "--initial-lr", + type=float, + metavar="LR", + default=0.0, + help="learning rate before the warmup period", + ) + parser.add_argument( + "--confidence-threshold", + type=float, + metavar="THRESHOLD", + default=0.001, + help="keep predictions only if the confidence is above THRESHOLD", + ) + parser.add_argument( + "--nms-threshold", + type=float, + metavar="THRESHOLD", + default=0.45, + help="non-maximum suppression removes predicted boxes that have IoU greater than " + "THRESHOLD with a higher scoring box", + ) + parser.add_argument( + "--max-predictions-per-image", + type=int, + metavar="N", + default=100, + help="keep at most N best predictions", + ) + + parser = VOCDetectionDataModule.add_argparse_args(parser) + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args() + + config = YOLOConfiguration(args.config) + + transforms = [lambda image, target: (F.to_tensor(image), target), Resize((config.height, config.width))] + transforms = Compose(transforms) + datamodule = VOCDetectionDataModule.from_argparse_args(args, train_transforms=transforms, val_transforms=transforms) + + optimizer_params = {"lr": args.lr, "momentum": args.momentum, "weight_decay": args.weight_decay} + lr_scheduler_params = { + "warmup_epochs": args.warmup_epochs, + "max_epochs": args.max_epochs, + "warmup_start_lr": args.initial_lr, + } + model = YOLO( + network=config.get_network(), + optimizer_params=optimizer_params, + lr_scheduler_params=lr_scheduler_params, + confidence_threshold=args.confidence_threshold, + nms_threshold=args.nms_threshold, + max_predictions_per_image=args.max_predictions_per_image, + ) + if args.darknet_weights is not None: + with open(args.darknet_weights) as weight_file: + model.load_darknet_weights(weight_file) + + trainer = 
Trainer.from_argparse_args(args) + trainer.fit(model, datamodule=datamodule) + + +if __name__ == "__main__": + run_cli() diff --git a/tests/data/yolo.cfg b/tests/data/yolo.cfg new file mode 100644 index 0000000000..d7596ef4b5 --- /dev/null +++ b/tests/data/yolo.cfg @@ -0,0 +1,79 @@ +[net] +width=256 +height=256 +channels=3 + +[convolutional] +batch_normalize=1 +filters=8 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=2 +size=1 +stride=1 +pad=1 +activation=mish + +[convolutional] +batch_normalize=1 +filters=4 +size=3 +stride=1 +pad=1 +activation=mish + +[shortcut] +from=-3 +activation=linear + +[convolutional] +size=1 +stride=1 +pad=1 +filters=14 +activation=linear + +[yolo] +mask=2,3 +anchors=1,2, 3,4, 5,6, 9,10 +classes=2 +scale_x_y=1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +ignore_thresh=0.7 + +[route] +layers = -4 + +[upsample] +stride=2 + +[convolutional] +size=1 +stride=1 +pad=1 +filters=14 +activation=linear + +[yolo] +mask=0,1 +anchors=1,2, 3,4, 5,6, 9,10 +classes=2 +scale_x_y=1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +ignore_thresh=0.7 diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 900a2daa13..82ca40c4d8 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -1,9 +1,14 @@ +from pathlib import Path + +import pytest import torch from pytorch_lightning import Trainer from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDetectionDataset -from pl_bolts.models.detection import FasterRCNN +from pl_bolts.models.detection import YOLO, FasterRCNN, YOLOConfiguration +from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou +from tests import TEST_ROOT def _collate_fn(batch): @@ -35,3 +40,38 @@ def test_fasterrcnn_bbone_train(tmpdir): trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, train_dl, valid_dl) + + +def test_yolo(tmpdir): + config_path = Path(TEST_ROOT) / "data" / "yolo.cfg" + config = YOLOConfiguration(config_path) + model = YOLO(config.get_network()) + + image = torch.rand(1, 3, 256, 256) + model(image) + + +def test_yolo_train(tmpdir): + config_path = Path(TEST_ROOT) / "data" / "yolo.cfg" + config = YOLOConfiguration(config_path) + model = YOLO(config.get_network()) + + train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + + trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl) + + +@pytest.mark.parametrize( + "dims1, dims2, expected_ious", + [ + ( + torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), + torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]), + ) + ], +) +def test_aligned_iou(dims1, dims2, expected_ious): + torch.testing.assert_allclose(_aligned_iou(dims1, dims2), expected_ious)
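
The expected values in ``test_aligned_iou`` can be reproduced by hand: ``_aligned_iou`` appears to treat each ``(width, height)`` row as a box anchored at the same corner, so the intersection is ``min(w1, w2) * min(h1, h2)`` and the union is ``area1 + area2 - intersection`` (for example, ``(10, 1)`` vs. ``(1, 10)`` gives ``1 / (10 + 10 - 1) = 1/19``). The sketch below only illustrates that arithmetic under this assumption; ``aligned_iou_reference`` is a hypothetical name, not part of the patch::

    import torch

    def aligned_iou_reference(dims1: torch.Tensor, dims2: torch.Tensor) -> torch.Tensor:
        """IoU of boxes that share the same top-left corner, given only their (width, height) pairs."""
        area1 = dims1[:, 0] * dims1[:, 1]  # [N]
        area2 = dims2[:, 0] * dims2[:, 1]  # [M]
        # Broadcast to [N, M, 2], take the elementwise minimum of widths and heights, and multiply.
        inter = torch.min(dims1[:, None, :], dims2[None, :, :]).prod(-1)  # [N, M]
        return inter / (area1[:, None] + area2[None, :] - inter)

    dims1 = torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]])
    dims2 = torch.tensor([[1.0, 10.0], [2.0, 20.0]])
    print(aligned_iou_reference(dims1, dims2))
    # tensor([[0.1000, 0.0250],
    #         [0.0526, 0.0417],
    #         [0.0100, 0.0196]])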