From 461de9624302bf078448050e76f48df8ed89e963 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 2 Feb 2021 21:03:11 +0200 Subject: [PATCH 01/61] Add YOLO object detection model --- docs/source/index.rst | 1 + docs/source/object_detection.rst | 20 + .../datamodules/vocdetection_datamodule.py | 22 +- pl_bolts/models/detection/__init__.py | 3 + pl_bolts/models/detection/yolo/__init__.py | 4 + pl_bolts/models/detection/yolo/yolo_config.py | 122 ++++ pl_bolts/models/detection/yolo/yolo_layers.py | 450 +++++++++++++ pl_bolts/models/detection/yolo/yolo_module.py | 589 ++++++++++++++++++ tests/models/test_detection.py | 128 +++- 9 files changed, 1332 insertions(+), 7 deletions(-) create mode 100644 docs/source/object_detection.rst create mode 100644 pl_bolts/models/detection/yolo/__init__.py create mode 100644 pl_bolts/models/detection/yolo/yolo_config.py create mode 100644 pl_bolts/models/detection/yolo/yolo_layers.py create mode 100644 pl_bolts/models/detection/yolo/yolo_module.py diff --git a/docs/source/index.rst b/docs/source/index.rst index f87dca6161..17852f1693 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -70,6 +70,7 @@ PyTorch-Lightning-Bolts documentation autoencoders convolutional + object_detection gans reinforce_learn self_supervised_models diff --git a/docs/source/object_detection.rst b/docs/source/object_detection.rst new file mode 100644 index 0000000000..37ace12a88 --- /dev/null +++ b/docs/source/object_detection.rst @@ -0,0 +1,20 @@ +Object Detection +================ +This package lists contributed object detection models. + +-------------- + + +Faster R-CNN +------------ + +.. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN + :noindex: + +------------- + +YOLO +---- + +.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO + :noindex: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 97b63cc86e..8557741f11 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -152,16 +152,20 @@ def prepare_data(self) -> None: VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) def train_dataloader( - self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None + self, + batch_size: int = 1, + transforms: Optional[List[Callable]] = [], + image_transforms: Optional[Callable] = None ) -> DataLoader: """ VOCDetection train set uses the `train` subset Args: batch_size: size of batch - transforms: custom transforms + transforms: custom transforms for image and target + image_transforms: custom image-only transforms """ - transforms = [_prepare_voc_instance] + transforms = [_prepare_voc_instance] + transforms image_transforms = image_transforms or self.train_transforms or self._default_transforms() transforms = Compose(transforms, image_transforms) dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms) @@ -176,15 +180,21 @@ def train_dataloader( ) return loader - def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader: + def val_dataloader( + self, + batch_size: int = 1, + transforms: Optional[List[Callable]] = [], + image_transforms: Optional[Callable] = None + ) -> DataLoader: """ VOCDetection val set uses the `val` subset Args: batch_size: size of batch - transforms: custom transforms + transforms: custom transforms for image and target + image_transforms: custom image-only transforms """ - transforms = [_prepare_voc_instance] + transforms = [_prepare_voc_instance] + transforms image_transforms = image_transforms or self.train_transforms or self._default_transforms() transforms = Compose(transforms, image_transforms) dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 4c09eac3d4..e7a864f736 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,7 +1,10 @@ from pl_bolts.models.detection import components # noqa: F401 from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 +from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration # noqa: F401 __all__ = [ "components", "FasterRCNN", + "Yolo", + "YoloConfiguration" ] diff --git a/pl_bolts/models/detection/yolo/__init__.py b/pl_bolts/models/detection/yolo/__init__.py new file mode 100644 index 0000000000..a2785f5882 --- /dev/null +++ b/pl_bolts/models/detection/yolo/__init__.py @@ -0,0 +1,4 @@ +from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration +from pl_bolts.models.detection.yolo.yolo_module import Yolo + +__all__ = ["YoloConfiguration", "Yolo"] diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py new file mode 100644 index 0000000000..cd8e99b602 --- /dev/null +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -0,0 +1,122 @@ +import re +from warnings import warn + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class YoloConfiguration: + def __init__(self, path: str): + """ + Parser for YOLOv4 network configuration files. + + Saves the variables from the first configuration section to attributes of this object, and + the rest of the sections to the `modules` list. + + Args: + path (str): configuration file to read + """ + with open(path, 'r') as config_file: + sections = self._read_file(config_file) + + if len(sections) < 2: + raise MisconfigurationException( + "The model configuration file should include at least two sections.") + + self.__dict__.update(sections[0]) + self.modules = sections[1:] + + def _read_file(self, config_file): + """ + Reads a YOLOv4 network configuration file and returns a list of configuration sections. + + Args: + config_file (iterable over lines): The configuration file to read. + + Returns: + sections (list): A list of configuration sections. + """ + section_re = re.compile(r'\[([^]]+)\]') + list_variables = ('layers', 'anchors', 'mask', 'scales') + variable_types = { + 'activation': str, + 'anchors': int, + 'angle': float, + 'batch': int, + 'batch_normalize': bool, + 'beta_nms': float, + 'burn_in': int, + 'channels': int, + 'classes': int, + 'cls_normalizer': float, + 'decay': float, + 'exposure': float, + 'filters': int, + 'from': int, + 'groups': int, + 'group_id': int, + 'height': int, + 'hue': float, + 'ignore_thresh': float, + 'iou_loss': str, + 'iou_normalizer': float, + 'iou_thresh': float, + 'jitter': float, + 'layers': int, + 'learning_rate': float, + 'mask': int, + 'max_batches': int, + 'max_delta': float, + 'momentum': float, + 'mosaic': bool, + 'nms_kind': str, + 'num': int, + 'obj_normalizer': float, + 'pad': bool, + 'policy': str, + 'random': bool, + 'resize': float, + 'saturation': float, + 'scales': float, + 'scale_x_y': float, + 'size': int, + 'steps': str, + 'stride': int, + 'subdivisions': int, + 'truth_thresh': float, + 'width': int + } + + section = None + sections = [] + + def convert(key, value): + """Converts a value to the correct type based on key.""" + if not key in variable_types: + warn('Unknown YOLO configuration variable: ' + key) + return key, value + if key in list_variables: + value = [variable_types[key](v) for v in value.split(',')] + else: + value = variable_types[key](value) + return key, value + + for line in config_file: + line = line.strip() + if (not line) or (line[0] == '#'): + continue + + section_match = section_re.match(line) + if section_match: + if section is not None: + sections.append(section) + section = {'type': section_match.group(1)} + else: + key, value = line.split('=') + key = key.rstrip() + value = value.lstrip() + key, value = convert(key, value) + section[key] = value + if section is not None: + sections.append(section) + + return sections diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py new file mode 100644 index 0000000000..32165b53d2 --- /dev/null +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -0,0 +1,450 @@ +from typing import List, Tuple + +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn + +from pl_bolts.utils.warnings import warn_missing_pkg + +try: + from torchvision.ops import box_iou +except ModuleNotFoundError: + warn_missing_pkg('torchvision') # pragma: no-cover + _TORCHVISION_AVAILABLE = False +else: + _TORCHVISION_AVAILABLE = True + + +def _aligned_iou(dims1, dims2): + """ + Calculates a matrix of intersections over union from box dimensions, assuming that the boxes + are located at the same coordinates. + + Arguments: + dims1 (Tensor[N, 2]): width and height of N boxes + dims2 (Tensor[M, 2]): width and height of M boxes + + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in + `dims1` and `dims2` + """ + area1 = dims1[:, 0] * dims1[:, 1] # [N] + area2 = dims2[:, 0] * dims2[:, 1] # [M] + + inter_wh = torch.min(dims1[:, None, :], dims2) # [N, M, 2] + inter = inter_wh[:, :, 0] * inter_wh[:, :, 1] # [N, M] + union = area1[:, None] + area2 - inter # [N, M] + + return inter / union + + +class DetectionLayer(nn.Module): + """ + A YOLO detection layer. A YOLO model has usually 1 - 3 detection layers at different + resolutions. The loss should be summed from all of them. + """ + + def __init__(self, + num_classes: int, + image_width: int, + image_height: int, + anchor_dims: List[Tuple[int, int]], + anchor_ids: List[int], + xy_scale: float = 1.0, + ignore_threshold: float = 0.5, + coord_loss_multiplier: float = 1.0, + class_loss_multiplier: float = 1.0, + confidence_loss_multiplier: float = 1.0): + """ + Constructs a YOLO detection layer. + + Args: + num_classes (int): Number of different classes that this layer predicts. + image_width (int): Image width (defines the scale of the anchor box and target bounding + box dimensions). + image_height (int): Image height (defines the scale of the anchor box and target + bounding box dimensions). + anchor_dims (List[Tuple[int, int]]): A list of all the predefined anchor box + dimensions. The list should contain (width, height) tuples in the network input + resolution (relative to the width and height defined in the configuration file). + anchor_ids (List[int]): List of indices to `anchor_dims` that is used to select the + (usually 3) anchors that this layer uses. + xy_scale (float): Eliminate "grid sensitivity" by scaling the box coordinates by this + factor. Using a value > 1.0 helps to produce coordinate values close to one. + ignore_threshold (float): If a predictor is not responsible for predicting any target, + but the corresponding anchor has IoU with some target greater than this threshold, + the predictor will not be taken into account when calculating the confidence loss. + coord_loss_multiplier (float): Multiply the coordinate/size loss by this factor. + class_loss_multiplier (float): Multiply the classification loss by this factor. + confidence_loss_multiplier (float): Multiply the confidence loss by this factor. + """ + super().__init__() + + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'YOLO model uses `torchvision`, which is not installed yet.' + ) + + self.num_classes = num_classes + self.image_width = image_width + self.image_height = image_height + self.anchor_dims = anchor_dims + self.anchor_ids = anchor_ids + self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(9)] + self.xy_scale = xy_scale + self.ignore_threshold = ignore_threshold + self.coord_loss_multiplier = coord_loss_multiplier + self.class_loss_multiplier = class_loss_multiplier + self.confidence_loss_multiplier = confidence_loss_multiplier + self.se_loss = nn.MSELoss(reduction='none') + + def forward(self, x, targets=None): + """ + Runs a forward pass through this YOLO detection layer. + + Maps cell-local coordinates to global coordinates in the [0, 1] range, scales the bounding + boxes with the anchors, converts the center coordinates to corner coordinates, and maps + probabilities to ]0, 1[ range using sigmoid. + + Args: + x (Tensor): The output from the previous layer. Tensor of size + `[batch_size, boxes_per_cell * (num_classes + 5), height, width]`. + targets (List[Dict[str, Tensor]]): If set, computes losses from detection layers + against these targets. A list of dictionaries, one for each image. + + Returns: + result (Tuple[Tensor, Dict[str, Tensor]]): Layer output, and if training targets were + provided, a dictionary of losses. Layer output is sized + `[batch_size, num_anchors * height * width, num_classes + 5]`. + """ + batch_size, num_features, height, width = x.shape + num_attrs = self.num_classes + 5 + boxes_per_cell = num_features // num_attrs + if boxes_per_cell != len(self.anchor_ids): + raise MisconfigurationException( + "The model predicts {} bounding boxes per cell, but {} anchor boxes are defined " + "for this layer.".format(boxes_per_cell, len(self.anchor_ids))) + + # Reshape the output to have the bounding box attributes of each grid cell on its own row. + x = x.permute(0, 2, 3, 1) # [batch_size, height, width, boxes_per_cell * num_attrs] + x = x.view(batch_size, height, width, boxes_per_cell, num_attrs) + + # Take the sigmoid of the bounding box coordinates, confidence score, and class + # probabilities. + xy = torch.sigmoid(x[..., :2]) + wh = x[..., 2:4] + confidence = torch.sigmoid(x[..., 4]) + classprob = torch.sigmoid(x[..., 5:]) + + # Eliminate grid sensitivity. The previous layer should output extremely high values for + # the sigmoid to produce x/y coordinates close to one. YOLOv4 solves this by scaling the + # x/y coordinates. + xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) + + if not torch.isfinite(x).all(): + raise ValueError('Detection layer output contains nan or inf values.') + + image_xy = self._global_xy(xy) + image_wh = self._scale_wh(wh) + corners = self._corner_coordinates(image_xy, image_wh) + output = torch.cat((corners, confidence.unsqueeze(-1), classprob), -1) + output = output.reshape(batch_size, height * width * boxes_per_cell, num_attrs) + + if targets is None: + return output + else: + np_mask = self._no_prediction_mask(corners, targets) + losses = self._calculate_losses(xy, wh, confidence, classprob, targets, np_mask) + return output, losses + + def _global_xy(self, xy): + """ + Adds offsets to the predicted box center coordinates to obtain global coordinates to the + image. + + The predicted coordinates are interpreted as coordinates inside a grid cell whose width and + height is 1. Adding offset to the cell and dividing by the grid size, we get global + coordinates in the [0, 1] range. + + Args: + xy (Tensor): The predicted center coordinates before scaling. Values from zero to one + in a tensor sized `[batch_size, height, width, boxes_per_cell, 2]`. + + Returns: + result (Tensor): Global coordinates from zero to one, in a tensor with the same shape + as the input tensor. + """ + height = xy.shape[1] + width = xy.shape[2] + grid_size = torch.tensor([width, height], device=xy.device) + + x_range = torch.arange(width, dtype=xy.dtype, device=xy.device) + y_range = torch.arange(height, dtype=xy.dtype, device=xy.device) + grid_y, grid_x = torch.meshgrid(y_range, x_range) + offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] + offset = offset.unsqueeze(2) # [height, width, 1, 2] + + return (xy + offset) / grid_size + + def _scale_wh(self, wh): + """ + Scales the box size predictions by the prior dimensions from the anchors. + + Args: + wh (Tensor): The unnormalized width and height predictions. Tensor of size + `[..., boxes_per_cell, 2]`. + + Returns: + result (Tensor): A tensor with the same shape as the input tensor, but scaled sizes + normalized to the [0, 1] range. + """ + image_size = torch.tensor([self.image_width, self.image_height], device=wh.device) + anchor_wh = [self.anchor_dims[i] for i in self.anchor_ids] + anchor_wh = torch.tensor(anchor_wh, dtype=wh.dtype, device=wh.device) + return torch.exp(wh) * anchor_wh / image_size + + def _corner_coordinates(self, xy, wh): + """ + Converts box center points and sizes to corner coordinates. + + Args: + xy (Tensor): Center coordinates. Tensor of size `[..., 2]`. + wh (Tensor): Width and height. Tensor of size `[..., 2]`. + + Returns: + corners (Tensor): A matrix of (x1, y1, x2, y2) coordinates. + """ + half_wh = wh / 2 + top_left = xy - half_wh + bottom_right = xy + half_wh + return torch.cat((top_left, bottom_right), -1) + + def _no_prediction_mask(self, preds, targets): + """ + Initializes the mask that will be used to select predictors that are not responsible for + predicting any target. The value will be `True`, unless the predicted box overlaps any + target significantly (IoU greater than `self.ignore_threshold`). + + Args: + preds (Tensor): The predicted corner coordinates, normalized to the [0, 1] range. + Tensor of size `[batch_size, height, width, boxes_per_cell, 4]`. + targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one + dictionary for each image. + + Returns: + results (Tensor): A boolean tensor shaped `[batch_size, height, width, boxes_per_cell]` + with `False` where the predicted box overlaps a target significantly and `True` + elsewhere. + """ + shape = preds.shape + preds = preds.view(shape[0], -1, shape[-1]) + + scale = torch.tensor([self.image_width, + self.image_height, + self.image_width, + self.image_height], + device=preds.device) + preds = preds * scale + + results = torch.ones(preds.shape[:-1], dtype=torch.bool, device=preds.device) + for image_idx, (image_preds, image_targets) in enumerate(zip(preds, targets)): + target_boxes = image_targets['boxes'] + if target_boxes.shape[0] > 0: + ious = box_iou(image_preds, target_boxes) + best_ious = ious.max(-1).values + results[image_idx] = best_ious <= self.ignore_threshold + results = results.view(shape[:-1]) + return results + + def _calculate_losses(self, xy, wh, confidence, classprob, targets, np_mask): + """ + From the targets that are in the image space calculates the actual targets for the network + predictions, and returns a dictionary of training losses. + + Args: + xy (Tensor): The predicted center coordinates before scaling. Values from zero to one + in a tensor sized `[batch_size, height, width, boxes_per_cell, 2]`. + wh (Tensor): The unnormalized width and height predictions. Tensor of size + `[batch_size, height, width, boxes_per_cell, 2]`. + confidence (Tensor): The confidence predictions, normalized to [0, 1]. A tensor sized + `[batch_size, height, width, boxes_per_cell]`. + classprob (Tensor): The class probability predictions, normalized to [0, 1]. A tensor + sized `[batch_size, height, width, boxes_per_cell, num_classes]`. + targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one + dictionary for each image. + np_mask: A boolean mask containing `True` where the predicted box does not overlap any + target significantly. + + Returns: + predicted (Dict[str, Tensor]): A dictionary of training losses. + """ + batch_size, height, width, boxes_per_cell, _ = xy.shape + device = xy.device + assert batch_size == len(targets) + + # Divisor for converting targets from image coordinates to feature map coordinates + image_to_feature_map = torch.tensor([self.image_width / width, + self.image_height / height], + device=device) + # Divisor for converting targets from image coordinates to [0, 1] range + image_to_unit = torch.tensor([self.image_width, self.image_height], + device=device) + + anchor_wh = torch.tensor(self.anchor_dims, dtype=wh.dtype, device=device) + anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) + + # List of predicted and target values for the predictors that are responsible for + # predicting a target. + target_xy = [] + target_wh = [] + target_label = [] + size_compensation = [] + pred_xy = [] + pred_wh = [] + pred_classprob = [] + pred_confidence = [] + + for image_idx, image_targets in enumerate(targets): + boxes = image_targets['boxes'] + if boxes.shape[0] < 1: + continue + + # Bounding box corner coordinates are converted to center coordinates, width, and + # height. + box_wh = boxes[:, 2:4] - boxes[:, 0:2] + box_xy = boxes[:, 0:2] + (box_wh / 2) + + # The center coordinates are converted to the feature map dimensions so that the whole + # number tells the cell index and the fractional part tells the location inside the cell. + box_xy = box_xy / image_to_feature_map + cell_i = box_xy[:, 0].to(torch.int64).clamp(0, width - 1) + cell_j = box_xy[:, 1].to(torch.int64).clamp(0, height - 1) + + # We want to know which anchor box overlaps a ground truth box more than any other + # anchor box. We know that the anchor box is located in the same grid cell as the + # ground truth box. For each prior shape (width, height), we calculate the IoU with + # all ground truth boxes, assuming the boxes are at the same location. Then for each + # target, we select the prior shape that gives the highest IoU. + ious = _aligned_iou(box_wh, anchor_wh) + best_anchors = ious.max(1).indices + + # `anchor_map` maps the anchor indices to the predictors in this layer, or to -1 if + # it's not an anchor of this layer. We ignore the predictions if the best anchor is in + # another layer. + predictors = anchor_map[best_anchors] + selected = predictors >= 0 + box_xy = box_xy[selected] + box_wh = box_wh[selected] + cell_i = cell_i[selected] + cell_j = cell_j[selected] + predictors = predictors[selected] + best_anchors = best_anchors[selected] + + # The "no-prediction" mask is used to select predictors that are not responsible for + # predicting any object for calculating the confidence loss. + np_mask[image_idx, cell_j, cell_i, predictors] = False + + # Bounding box targets + relative_xy = box_xy - box_xy.floor() + relative_wh = torch.log(box_wh / anchor_wh[best_anchors] + 1e-16) + target_xy.append(relative_xy) + target_wh.append(relative_wh) + + # Size compensation factor for bounding box loss + unit_wh = box_wh / image_to_unit + size_compensation.append(2 - (unit_wh[:, 0] * unit_wh[:, 1])) + + # The data may contain a different number of classes than this detection layer. In case + # a label is greater than the number of classes that this layer predicts, it will be + # mapped to the last class. + labels = image_targets['labels'] + labels = labels[selected] + labels = torch.minimum(labels, torch.tensor(self.num_classes - 1, device=device)) + target_label.append(labels) + + pred_xy.append(xy[image_idx, cell_j, cell_i, predictors]) + pred_wh.append(wh[image_idx, cell_j, cell_i, predictors]) + pred_classprob.append(classprob[image_idx, cell_j, cell_i, predictors]) + pred_confidence.append(confidence[image_idx, cell_j, cell_i, predictors]) + + losses = dict() + + if pred_xy and pred_wh and target_xy and target_wh: + size_compensation = torch.cat(size_compensation).unsqueeze(1) + pred_xy = torch.cat(pred_xy) + target_xy = torch.cat(target_xy) + location_loss = self.se_loss(pred_xy, target_xy) + location_loss = location_loss * size_compensation + location_loss = location_loss.sum() / batch_size + losses['location'] = location_loss * self.coord_loss_multiplier + + pred_wh = torch.cat(pred_wh) + target_wh = torch.cat(target_wh) + size_loss = self.se_loss(pred_wh, target_wh) + size_loss = size_loss * size_compensation + size_loss = size_loss.sum() / batch_size + losses['size'] = size_loss * self.coord_loss_multiplier + + class_loss = None + if pred_classprob and target_label: + pred_classprob = torch.cat(pred_classprob) + target_label = torch.cat(target_label) + target_classprob = torch.nn.functional.one_hot(target_label, self.num_classes) + target_classprob = target_classprob.to(dtype=pred_classprob.dtype) + class_loss = self.se_loss(pred_classprob, target_classprob) + class_loss = class_loss.sum() / batch_size + losses['class'] = class_loss * self.class_loss_multiplier + + np_confidence = confidence[np_mask] + np_target_confidence = torch.zeros_like(np_confidence) + np_confidence_loss = self.se_loss(np_confidence, np_target_confidence) + np_confidence_loss = np_confidence_loss.sum() / batch_size + losses['np_confidence'] = np_confidence_loss * self.confidence_loss_multiplier + + if pred_confidence: + p_confidence = torch.cat(pred_confidence) + p_target_confidence = torch.ones_like(p_confidence) + p_confidence_loss = self.se_loss(p_confidence, p_target_confidence) + p_confidence_loss = p_confidence_loss.sum() / batch_size + losses['p_confidence'] = p_confidence_loss * self.confidence_loss_multiplier + + return losses + + +class Mish(nn.Module): + """Mish activation.""" + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(nn.functional.softplus(x)) + + +class RouteLayer(nn.Module): + """Route layer concatenates the output (or part of it) from given layers.""" + + def __init__(self, layers, groups, group_id): + super().__init__() + self.layers = layers + if groups > 0: + self.groups = groups + self.group_id = group_id + else: + self.groups = 1 + self.group_id = 0 + + def forward(self, x, outputs): + chunks = [torch.chunk(outputs[l], self.groups, dim=1)[self.group_id] + for l in self.layers] + return torch.cat(chunks, dim=1) + + +class ShortcutLayer(nn.Module): + """Shortcut layer adds a residual connection from the source layer.""" + + def __init__(self, source_layer): + super().__init__() + self.source_layer = source_layer + + def forward(self, x, outputs): + return outputs[-1] + outputs[self.source_layer] diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py new file mode 100644 index 0000000000..00c90a80fa --- /dev/null +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -0,0 +1,589 @@ +import inspect +from argparse import ArgumentParser, Namespace +from pathlib import Path +from typing import List, Type, Union + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +from pytorch_lightning.utilities import argparse_utils +from torch import optim + +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pl_bolts.utils.warnings import warn_missing_pkg + +try: + import torchvision.transforms as T + from torchvision.ops import nms + from torchvision.transforms import functional as F +except ModuleNotFoundError: + warn_missing_pkg('torchvision') # pragma: no-cover + _TORCHVISION_AVAILABLE = False +else: + _TORCHVISION_AVAILABLE = True + +from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration +from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, Mish, RouteLayer, ShortcutLayer + + +class Yolo(pl.LightningModule): + def __init__( + self, + configuration: YoloConfiguration, + optimizer: str = 'sgd', + momentum: float = 0.9, + weight_decay: float = 0.0005, + learning_rate: float = 0.0013, + warmup_epochs: int = 1, + warmup_start_lr: float = 0.0001, + annealing_epochs: int = 271, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45): + """ + Constructs a YOLO model. + + Args: + config (YoloConfiguration): The model configuration file. + momentum (float): Momentum factor for SGD with momentum. + weight_decay (float): Weight decay (L2 penalty). + learning_rate (float): Learning rate after the warmup period. + warmup_epochs (int): Length of the learning rate warmup period in the beginning of + training. During this number of epochs, the learning rate will be raised from + `warmup_start_lr` to `learning_rate`. + warmup_start_lr (int): Learning rate in the beginning of the warmup period. + annealing_epochs (int): Length of the learning rate annealing period, during which the + learning rate will go to zero. + confidence_threshold (float): Postprocessing will remove bounding boxes whose + confidence score is not higher than this threshold. + nms_threshold (float): Non-maximum suppression will remove bounding boxes whose IoU + with the next best bounding box in that class is higher than this threshold. + """ + super().__init__() + + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'YOLO model uses `torchvision`, which is not installed yet.' + ) + + self.config = configuration + self.optimizer = optimizer + self.momentum = momentum + self.weight_decay = weight_decay + self.learning_rate = learning_rate + self.warmup_epochs = warmup_epochs + self.warmup_start_lr = warmup_start_lr + self.annealing_epochs = annealing_epochs + self.confidence_threshold = confidence_threshold + self.nms_threshold = nms_threshold + + self._create_modules() + + def forward(self, images, targets=None): + # type: (List[Tensor]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + """ + Runs a forward pass through the network (all layers listed in `self._module_list`), and if + training targets are provided, computes the losses from the detection layers. + + Detections are concatenated from the detection layers. Each image will produce + `N * num_anchors * grid_height * grid_width` detections, where `N` depends on the number of + detection layers. For one detection layer `N = 1`, and each detection layer increases it by + a number that depends on the size of the feature map on that layer. For example, if the + feature map is twice as wide and high as the grid, the layer will add four times more + features. + + Args: + images (Tensor): Images to be processed. Tensor of size + `[batch_size, num_channels, height, width]`. + targets (List[Dict[str, Tensor]]): If set, computes losses from detection layers + against these targets. A list of dictionaries, one for each image. + + Returns: + boxes (Tensor), confidences (Tensor), classprobs (Tensor), losses (Dict[str, Tensor]): + Detections, and if targets were provided, a dictionary of losses. The first + dimension of the detections is the index of the image in the batch and the second + dimension is the detection within the image. `boxes` contains the predicted + (x1, y1, x2, y2) coordinates, normalized to [0, 1]. + """ + outputs = [] # Outputs from all layers + detections = [] # Outputs from detection layers + losses = [] # Losses from detection layers + + x = images + for module in self._module_list: + if isinstance(module, RouteLayer) or isinstance(module, ShortcutLayer): + x = module(x, outputs) + elif isinstance(module, DetectionLayer): + if targets is None: + x = module(x) + detections.append(x) + else: + x, layer_losses = module(x, targets) + detections.append(x) + losses.append(layer_losses) + else: + x = module(x) + + outputs.append(x) + + def mean_loss(loss_name): + loss_tuple = tuple(layer_losses[loss_name] for layer_losses in losses) + return torch.stack(loss_tuple).sum() / images.shape[0] + + detections = torch.cat(detections, 1) + boxes = detections[..., :4] + confidences = detections[..., 4] + classprobs = detections[..., 5:] + + if targets is not None: + losses = {loss_name: mean_loss(loss_name) for loss_name in losses[0].keys()} + return boxes, confidences, classprobs, losses + else: + return boxes, confidences, classprobs + + def configure_optimizers(self): + """Constructs the optimizer and learning rate scheduler.""" + if self.optimizer == 'sgd': + optimizer = optim.SGD( + self.parameters(), + lr=self.learning_rate, + momentum=self.momentum, + weight_decay=self.weight_decay) + elif self.optimizer == 'adam': + optimizer = optim.Adam( + self.parameters(), + lr=self.learning_rate + ) + lr_scheduler = LinearWarmupCosineAnnealingLR( + optimizer, + warmup_epochs=self.warmup_epochs, + max_epochs=self.annealing_epochs, + warmup_start_lr=self.warmup_start_lr) + return [optimizer], [lr_scheduler] + + def training_step(self, batch, batch_idx): + # type: (Tuple[List[Tensor], List[Dict[str, Tensor]]]) -> Dict[str, Tensor] + """ + Computes the training loss. + + Args: + batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): + A tuple of images and targets. Images is a list of 3-dimensional tensors. Targets + is a list of dictionaries that contain ground-truth boxes, labels, etc. + batch_idx (int): The index of this batch. + + Returns: + A dictionary that includes the training loss in 'loss'. + """ + images, targets = self._validate_batch(batch) + _, _, _, losses = self(images, targets) + total_loss = torch.stack(tuple(losses.values())).sum() + + for name, value in losses.items(): + self.log('train/{}_loss'.format(name), value) + self.log('train/total_loss', total_loss) + + return {'loss': total_loss} + + def validation_step(self, batch, batch_idx): + # type: (Tuple[List[Tensor], List[Dict[str, Tensor]]], int) -> Dict[str, Tensor] + """ + Evaluates a batch of data from the validation set. + + Args: + batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): + The batch of data read by the :class:`~torch.utils.data.DataLoader` + batch_idx (int): The index of this batch + """ + images, targets = self._validate_batch(batch) + boxes, confidences, classprobs, losses = self(images, targets) + classprobs, labels = torch.max(classprobs, -1) + boxes, confidences, classprobs, labels = self._filter_detections( + boxes, confidences, classprobs, labels) + total_loss = torch.stack(tuple(losses.values())).sum() + + for name, value in losses.items(): + self.log('val/{}_loss'.format(name), value) + self.log('val/total_loss', total_loss) + + def test_step(self, batch, batch_idx): + # type: (Tuple[List[Tensor], List[Dict[str, Tensor]]], int) -> Dict[str, Tensor] + """ + Evaluates a batch of data from the test set. + + Args: + batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): + The batch of data read by the :class:`~torch.utils.data.DataLoader` + batch_idx (int): The index of this batch. + """ + images, targets = self._validate_batch(batch) + boxes, confidences, classprobs, losses = self(images, targets) + classprobs, labels = torch.max(classprobs, -1) + boxes, confidences, classprobs, labels = self._filter_detections( + boxes, confidences, classprobs, labels) + total_loss = torch.stack(tuple(losses.values())).sum() + + for name, value in losses.items(): + self.log('test/{}_loss'.format(name), value) + self.log('test/total_loss', total_loss) + + def infer(self, image): + # type: (ndarray) -> Dict[str, Tensor] + """ + Resizes given image to the network input size and feeds it to the network. Returns the + detected bounding boxes, confidences, and class labels. + + Args: + image (Tensor): + An input image, a tensor of uint8 values sized `[channels, height, width]`. + + Returns: + boxes (Tensor): A matrix of detected bounding box (x1, y1, x2, y2) coordinates. + confidences (Tensor): A vector of confidences for the bounding box detections. + labels (Tensor): A vector of predicted class labels. + """ + network_input = image.float().div(255.0) + network_input = network_input.unsqueeze(0) + self.eval() + boxes, confidences, classprobs = self(network_input) + classprobs, labels = torch.max(classprobs, -1) + boxes, confidences, classprobs, labels = self._filter_detections( + boxes, confidences, classprobs, labels) + assert len(boxes) == 1 + boxes = boxes[0] + confidences = confidences[0] + labels = labels[0] + + height = image.shape[1] + width = image.shape[2] + scale = torch.tensor([width, height, width, height], device=boxes.device) + boxes = boxes * scale + boxes = torch.round(boxes).int() + return boxes, confidences, labels + + def load_darknet_weights(self, weight_file): + """ + Loads weights to layer modules from a pretrained Darknet model. + """ + version = np.fromfile(weight_file, count=3, dtype=np.int32) + images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) + print('Loading weights from Darknet model version {}.{}.{} that has been trained on {} ' + 'images.'.format(version[0], version[1], version[2], images_seen[0])) + + def read(tensor): + x = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) + x = torch.from_numpy(x).view_as(tensor) + with torch.no_grad(): + tensor.copy_(x) + + for module in self._module_list: + # Weights are loaded only to convolutional layers + if not isinstance(module, nn.Sequential): + continue + + conv = module[0] + assert isinstance(conv, nn.Conv2d) + + if len(module) > 1: + bn = module[1] + assert isinstance(bn, nn.BatchNorm2d) + + read(bn.bias) + read(bn.weight) + read(bn.running_mean) + read(bn.running_var) + else: + read(conv.bias) + + read(conv.weight) + + @classmethod + def get_deprecated_arg_names(cls) -> List: + """Returns a list with deprecated constructor arguments.""" + depr_arg_names = [] + for name, val in cls.__dict__.items(): + if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): + depr_arg_names.extend(val) + return depr_arg_names + + def _create_modules(self): + """ + Creates a list of network modules based on parsed configuration file. + """ + self._module_list = nn.ModuleList() + num_outputs = 3 # Number of channels in the previous layer output + layer_outputs = [] # Number of channels in the output of every layer + + # Iterate through the modules from the configuration and generate required components. + for index, config in enumerate(self.config.modules): + if config['type'] == 'convolutional': + module = nn.Sequential() + + batch_normalize = config.get('batch_normalize', False) + padding = (config['size'] - 1) // 2 if config['pad'] else 0 + + conv = nn.Conv2d( + num_outputs, + config['filters'], + config['size'], + config['stride'], + padding, + bias=not batch_normalize) + module.add_module("conv_{0}".format(index), conv) + num_outputs = config['filters'] + + if batch_normalize: + bn = nn.BatchNorm2d(config['filters']) + module.add_module("batch_norm_{0}".format(index), bn) + + if config['activation'] == 'leaky': + leakyrelu = nn.LeakyReLU(0.1, inplace=True) + module.add_module('leakyrelu_{0}'.format(index), leakyrelu) + elif config['activation'] == 'mish': + mish = Mish() + module.add_module("mish_{0}".format(index), mish) + + elif config['type'] == 'upsample': + module = nn.Upsample(scale_factor=config["stride"], mode='nearest') + + elif config['type'] == 'route': + groups = config.get('groups', 0) + group_id = config.get('group_id', 0) + layers = [layer if layer >= 0 else index + layer for layer in config['layers']] + module = RouteLayer(layers, groups, group_id) + + num_outputs = 0 + for layer in layers: + if groups > 0: + num_outputs += layer_outputs[layer] // groups + else: + num_outputs += layer_outputs[layer] + + elif config['type'] == 'shortcut': + module = ShortcutLayer(config['from']) + + elif config['type'] == 'yolo': + # The "anchors" list alternates width and height. + anchor_dims = config['anchors'] + anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) + for i in range(0, len(anchor_dims), 2)] + + xy_scale = config.get('scale_x_y', 1.0) + ignore_threshold = config.get('ignore_thresh', 1.0) + coord_loss_multiplier = config.get('iou_normalizer', 1.0) + class_loss_multiplier = config.get('cls_normalizer', 1.0) + confidence_loss_multiplier = config.get('obj_normalizer', 1.0) + + module = DetectionLayer( + num_classes=config['classes'], + image_width=self.config.width, + image_height=self.config.height, + anchor_dims=anchor_dims, + anchor_ids=config['mask'], + xy_scale=xy_scale, + ignore_threshold=ignore_threshold, + coord_loss_multiplier=coord_loss_multiplier, + class_loss_multiplier=class_loss_multiplier, + confidence_loss_multiplier=confidence_loss_multiplier) + + elif config['type'] == 'maxpool': + padding = (config['size'] - 1) // 2 + module = nn.MaxPool2d(config['size'], config['stride'], padding) + + self._module_list.append(module) + layer_outputs.append(num_outputs) + + def _validate_batch(self, batch): + """ + Reads a batch of data and validates the format. + + Args: + batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): + The batch of data read by the :class:`~torch.utils.data.DataLoader` + """ + images, targets = batch + + if len(images) != len(targets): + raise ValueError("Got {} images, but targets for {} images." + .format(len(images), len(targets))) + + for image in images: + if not isinstance(image, torch.Tensor): + raise ValueError("Expected image to be of type Tensor, got {}." + .format(type(image))) + expected_shape = torch.Size((self.config.channels, + self.config.height, + self.config.width)) + if image.shape != expected_shape: + raise ValueError("Expected images to be tensors of shape {}, got {}." + .format(list(expected_shape), list(image.shape))) + + for target in targets: + boxes = target['boxes'] + if not isinstance(boxes, torch.Tensor): + raise ValueError("Expected target boxes to be of type Tensor, got {}." + .format(type(boxes))) + if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): + raise ValueError("Expected target boxes to be tensors of shape [N, 4], got {}." + .format(list(boxes.shape))) + labels = target['labels'] + if not isinstance(labels, torch.Tensor): + raise ValueError("Expected target labels to be of type Tensor, got {}." + .format(type(labels))) + if len(labels.shape) != 1: + raise ValueError("Expected target labels to be tensors of shape [N], got {}." + .format(list(labels.shape))) + + images = torch.stack(images) + return images, targets + + def _filter_detections(self, boxes, confidences, classprobs, labels): + """ + Filters detections based on confidence threshold. Then for every class performs non-maximum + suppression (NMS). NMS iterates the bounding boxes that predict this class in descending + order of confidence score, and removes the bounding box, if its IoU with the next one is + higher than the NMS threshold. + + Args: + boxes (Tensor): + Detected bounding box (x1, y1, x2, y2) coordinates in a tensor sized + `[batch_size, N, 4]`. + confidences (Tensor): + Detection confidences in a tensor sized `[batch_size, N]`. + classprobs (Tensor): + Probabilities of the best classes in a tensor sized `[batch_size, N]`. + labels (Tensor): + Indices of the best classes in a tensor sized `[batch_size, N]`. + + Returns: + boxes (List[Tensor]): + List of bounding box (x1, y1, x2, y2) coordinates, one tensor for each image. + confidences (List[Tensor]): + List of detection confidences, one tensor for each image. + classprobs (List[Tensor]): + List of tensors, one for each image, that contain the probabilities of the best + class of each prediction. + labels (List[Tensor]): + List of predicted class labels, one for each image. + """ + out_boxes = [] + out_confidences = [] + out_classprobs = [] + out_labels = [] + + for img_boxes, img_confidences, img_classprobs, img_labels in zip(boxes, confidences, classprobs, labels): + # Select detections with high confidence score. + selected = img_confidences > self.confidence_threshold + img_boxes = img_boxes[selected] + img_confidences = img_confidences[selected] + img_classprobs = img_classprobs[selected] + img_labels = img_labels[selected] + + img_out_boxes = boxes.new_zeros((0, 4)) + img_out_confidences = confidences.new_zeros(0) + img_out_classprobs = classprobs.new_zeros(0) + img_out_labels = labels.new_zeros(0) + + # Iterate through the unique object classes detected in the image and perform non-maximum + # suppression for the objects of the class in question. + for cls_label in labels.unique(): + selected = img_labels == cls_label + cls_boxes = img_boxes[selected] + cls_confidences = img_confidences[selected] + cls_classprobs = img_classprobs[selected] + cls_labels = img_labels[selected] + + selected = nms(cls_boxes, cls_confidences, self.nms_threshold) + img_out_boxes = torch.cat((img_out_boxes, cls_boxes[selected])) + img_out_confidences = torch.cat((img_out_confidences, cls_confidences[selected])) + img_out_classprobs = torch.cat((img_out_classprobs, cls_classprobs[selected])) + img_out_labels = torch.cat((img_out_labels, cls_labels[selected])) + + out_boxes.append(img_out_boxes) + out_confidences.append(img_out_confidences) + out_classprobs.append(img_out_classprobs) + out_labels.append(img_out_labels) + + return out_boxes, out_confidences, out_classprobs, out_labels + + def _scale_boxes(self, images, boxes): + """Scales the box coordinates to image dimensions.""" + result = [] + + for image, img_boxes in zip(images, boxes): + height = image.shape[1] + width = image.shape[2] + scale = torch.tensor([width, height, width, height], device=img_boxes.device) + img_boxes = img_boxes * scale + result.append(img_boxes) + + return result + + +class Resize: + """Rescales the image and target to given dimensions. + + Args: + output_size (tuple or int): Desired output size. If tuple (height, width), the output is + matched to `output_size`. If int, the smaller of the image edges is matched to + `output_size`, keeping the aspect ratio the same. + """ + + def __init__(self, output_size: tuple): + self.output_size = output_size + + def __call__(self, image, target): + width, height = image.size + original_size = torch.tensor([height, width]) + resize_ratio = torch.tensor(self.output_size) / original_size + image = F.resize(image, self.output_size) + scale = torch.tensor([resize_ratio[1], # y + resize_ratio[0], # x + resize_ratio[1], # y + resize_ratio[0]], # x + device=target['boxes'].device) + target['boxes'] = target['boxes'] * scale + return image, target + + +def run_cli(): + from pytorch_lightning.utilities import argparse_utils + + from pl_bolts.datamodules import VOCDetectionDataModule + from pl_bolts.datamodules.vocdetection_datamodule import Compose + + pl.seed_everything(42) + + parser = ArgumentParser() + parser.add_argument('--config', type=str, help='model configuration file', required=True) + parser.add_argument('--darknet-weights', type=str, help='initialize the model weights from this Darknet model file') + parser.add_argument('--batch-size', type=int, help='number of images in one batch', default=16) + parser = VOCDetectionDataModule.add_argparse_args(parser) + parser = argparse_utils.add_argparse_args(Yolo, parser) + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + config = YoloConfiguration(args.config) + + transforms = [Resize((config.height, config.width))] + image_transforms = T.ToTensor() + datamodule = VOCDetectionDataModule.from_argparse_args(args) + datamodule.prepare_data() + + params = vars(args) + valid_kwargs = inspect.signature(Yolo.__init__).parameters + kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) + model = Yolo(configuration=config, **kwargs) + if args.darknet_weights is not None: + with open(args.darknet_weights, 'r') as weight_file: + model.load_darknet_weights(weight_file) + + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit( + model, + datamodule.train_dataloader(args.batch_size, transforms, image_transforms), + datamodule.val_dataloader(args.batch_size, transforms, image_transforms)) + + +if __name__ == "__main__": + run_cli() diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 73736cad6e..137c7fec49 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -1,9 +1,13 @@ +from pathlib import Path + +import pytest import pytorch_lightning as pl import torch from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDetectionDataset -from pl_bolts.models.detection import FasterRCNN +from pl_bolts.models.detection import FasterRCNN, Yolo, YoloConfiguration +from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou def _collate_fn(batch): @@ -34,3 +38,125 @@ def test_fasterrcnn_bbone_train(tmpdir): trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, train_dl, valid_dl) + + +def _create_yolo_config_file(config_path): + config_file = open(config_path, 'w') + config_file.write('''[net] +width=256 +height=256 +channels=3 + +[convolutional] +batch_normalize=1 +filters=8 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=2 +size=1 +stride=1 +pad=1 +activation=mish + +[convolutional] +batch_normalize=1 +filters=4 +size=3 +stride=1 +pad=1 +activation=mish + +[shortcut] +from=-3 +activation=linear + +[convolutional] +size=1 +stride=1 +pad=1 +filters=14 +activation=linear + +[yolo] +mask=2,3 +anchors=1,2, 3,4, 5,6, 9,10 +classes=2 +scale_x_y=1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +ignore_thresh=0.7 + +[route] +layers = -4 + +[upsample] +stride=2 + +[convolutional] +size=1 +stride=1 +pad=1 +filters=14 +activation=linear + +[yolo] +mask=0,1 +anchors=1,2, 3,4, 5,6, 9,10 +classes=2 +scale_x_y=1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +ignore_thresh=0.7''') + config_file.close() + + +def test_yolo(tmpdir): + config_path = Path(tmpdir) / 'yolo.cfg' + _create_yolo_config_file(config_path) + config = YoloConfiguration(config_path) + model = Yolo(config) + + image = torch.rand(1, 3, 256, 256) + model(image) + + +def test_yolo_train(tmpdir): + config_path = Path(tmpdir) / 'yolo.cfg' + _create_yolo_config_file(config_path) + config = YoloConfiguration(config_path) + model = Yolo(config) + + train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + + trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl) + + +@pytest.mark.parametrize( + "dims1, dims2, expected_ious", + [(torch.tensor([[1.0, 1.0], + [10.0, 1.0], + [100.0, 10.0]]), + torch.tensor([[1.0, 10.0], + [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], + [1.0 / 19.0, 2.0 / 48.0], + [10.0 / 1000.0, 20.0 / 1020.0]]))] +) +def test_aligned_iou(dims1, dims2, expected_ious): + torch.testing.assert_allclose(_aligned_iou(dims1, dims2), expected_ious) From 2b9b073b831f3dee77dafb6666df3b2c3bf4fb33 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 3 Feb 2021 10:05:15 +0200 Subject: [PATCH 02/61] Readability improvements --- .../datamodules/vocdetection_datamodule.py | 4 +- pl_bolts/models/detection/yolo/yolo_config.py | 2 +- pl_bolts/models/detection/yolo/yolo_layers.py | 30 +++++++++----- pl_bolts/models/detection/yolo/yolo_module.py | 41 ++++++++----------- 4 files changed, 41 insertions(+), 36 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 8557741f11..863156e086 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -154,7 +154,7 @@ def prepare_data(self) -> None: def train_dataloader( self, batch_size: int = 1, - transforms: Optional[List[Callable]] = [], + transforms: List[Callable] = [], image_transforms: Optional[Callable] = None ) -> DataLoader: """ @@ -183,7 +183,7 @@ def train_dataloader( def val_dataloader( self, batch_size: int = 1, - transforms: Optional[List[Callable]] = [], + transforms: List[Callable] = [], image_transforms: Optional[Callable] = None ) -> DataLoader: """ diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index cd8e99b602..bb945e75e5 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -91,7 +91,7 @@ def _read_file(self, config_file): def convert(key, value): """Converts a value to the correct type based on key.""" - if not key in variable_types: + if key not in variable_types: warn('Unknown YOLO configuration variable: ' + key) return key, value if key in list_variables: diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 32165b53d2..da1a66c7d7 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -423,19 +423,23 @@ def forward(self, x): class RouteLayer(nn.Module): """Route layer concatenates the output (or part of it) from given layers.""" - def __init__(self, layers, groups, group_id): + def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int): + """ + Creates a YOLO route layer. + + Args: + source_layers (List[int]): Indices of the layers whose output will be concatenated. + num_chunks (int): Layer outputs will be split into this number of chunks. + chunk_idx (int): Only the chunks with this index will be concatenated. + """ super().__init__() - self.layers = layers - if groups > 0: - self.groups = groups - self.group_id = group_id - else: - self.groups = 1 - self.group_id = 0 + self.source_layers = source_layers + self.num_chunks = num_chunks + self.chunk_idx = chunk_idx def forward(self, x, outputs): - chunks = [torch.chunk(outputs[l], self.groups, dim=1)[self.group_id] - for l in self.layers] + chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] + for layer in self.source_layers] return torch.cat(chunks, dim=1) @@ -443,6 +447,12 @@ class ShortcutLayer(nn.Module): """Shortcut layer adds a residual connection from the source layer.""" def __init__(self, source_layer): + """ + Constructs a YOLO shortcut layer. + + Args: + num_classes (int): Number of different classes that this layer predicts. + """ super().__init__() self.source_layer = source_layer diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 00c90a80fa..a03dd6c08a 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -28,18 +28,17 @@ class Yolo(pl.LightningModule): - def __init__( - self, - configuration: YoloConfiguration, - optimizer: str = 'sgd', - momentum: float = 0.9, - weight_decay: float = 0.0005, - learning_rate: float = 0.0013, - warmup_epochs: int = 1, - warmup_start_lr: float = 0.0001, - annealing_epochs: int = 271, - confidence_threshold: float = 0.2, - nms_threshold: float = 0.45): + def __init__(self, + configuration: YoloConfiguration, + optimizer: str = 'sgd', + momentum: float = 0.9, + weight_decay: float = 0.0005, + learning_rate: float = 0.0013, + warmup_epochs: int = 1, + warmup_start_lr: float = 0.0001, + annealing_epochs: int = 271, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45): """ Constructs a YOLO model. @@ -347,17 +346,13 @@ def _create_modules(self): module = nn.Upsample(scale_factor=config["stride"], mode='nearest') elif config['type'] == 'route': - groups = config.get('groups', 0) - group_id = config.get('group_id', 0) - layers = [layer if layer >= 0 else index + layer for layer in config['layers']] - module = RouteLayer(layers, groups, group_id) - - num_outputs = 0 - for layer in layers: - if groups > 0: - num_outputs += layer_outputs[layer] // groups - else: - num_outputs += layer_outputs[layer] + num_chunks = config.get('groups', 1) + chunk_idx = config.get('group_id', 0) + source_layers = [layer if layer >= 0 else index + layer + for layer in config['layers']] + module = RouteLayer(source_layers, num_chunks, chunk_idx) + num_outputs = sum(layer_outputs[layer] // num_chunks + for layer in source_layers) elif config['type'] == 'shortcut': module = ShortcutLayer(config['from']) From cc425406825bb97df9b72a3fa1833d9c84318a84 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 3 Feb 2021 14:11:20 +0200 Subject: [PATCH 03/61] Documentation improvements --- docs/source/object_detection.rst | 2 +- pl_bolts/models/detection/__init__.py | 6 +- pl_bolts/models/detection/yolo/yolo_config.py | 6 +- pl_bolts/models/detection/yolo/yolo_layers.py | 71 +++---- pl_bolts/models/detection/yolo/yolo_module.py | 191 ++++++++++-------- 5 files changed, 147 insertions(+), 129 deletions(-) diff --git a/docs/source/object_detection.rst b/docs/source/object_detection.rst index 37ace12a88..cdff88f9e0 100644 --- a/docs/source/object_detection.rst +++ b/docs/source/object_detection.rst @@ -16,5 +16,5 @@ Faster R-CNN YOLO ---- -.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO +.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.Yolo :noindex: diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index e7a864f736..367ae444c1 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,10 +1,10 @@ from pl_bolts.models.detection import components # noqa: F401 from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration # noqa: F401 +from pl_bolts.models.detection.yolo import YoloConfiguration, Yolo # noqa: F401 __all__ = [ "components", "FasterRCNN", - "Yolo", - "YoloConfiguration" + "YoloConfiguration", + "Yolo" ] diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index bb945e75e5..804f486305 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -5,15 +5,15 @@ class YoloConfiguration: + """Parser for YOLOv4 network configuration files.""" + def __init__(self, path: str): """ - Parser for YOLOv4 network configuration files. - Saves the variables from the first configuration section to attributes of this object, and the rest of the sections to the `modules` list. Args: - path (str): configuration file to read + path: configuration file to read """ with open(path, 'r') as config_file: sections = self._read_file(config_file) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index da1a66c7d7..930f31cea3 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -1,8 +1,8 @@ -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import nn +from torch import nn, Tensor from pl_bolts.utils.warnings import warn_missing_pkg @@ -56,27 +56,25 @@ def __init__(self, class_loss_multiplier: float = 1.0, confidence_loss_multiplier: float = 1.0): """ - Constructs a YOLO detection layer. - Args: - num_classes (int): Number of different classes that this layer predicts. - image_width (int): Image width (defines the scale of the anchor box and target bounding + num_classes: Number of different classes that this layer predicts. + image_width: Image width (defines the scale of the anchor box and target bounding box dimensions). - image_height (int): Image height (defines the scale of the anchor box and target + image_height: Image height (defines the scale of the anchor box and target bounding box dimensions). - anchor_dims (List[Tuple[int, int]]): A list of all the predefined anchor box - dimensions. The list should contain (width, height) tuples in the network input - resolution (relative to the width and height defined in the configuration file). - anchor_ids (List[int]): List of indices to `anchor_dims` that is used to select the - (usually 3) anchors that this layer uses. - xy_scale (float): Eliminate "grid sensitivity" by scaling the box coordinates by this - factor. Using a value > 1.0 helps to produce coordinate values close to one. - ignore_threshold (float): If a predictor is not responsible for predicting any target, - but the corresponding anchor has IoU with some target greater than this threshold, - the predictor will not be taken into account when calculating the confidence loss. - coord_loss_multiplier (float): Multiply the coordinate/size loss by this factor. - class_loss_multiplier (float): Multiply the classification loss by this factor. - confidence_loss_multiplier (float): Multiply the confidence loss by this factor. + anchor_dims: A list of all the predefined anchor box dimensions. The list should + contain (width, height) tuples in the network input resolution (relative to the + width and height defined in the configuration file). + anchor_ids: List of indices to `anchor_dims` that is used to select the (usually 3) + anchors that this layer uses. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. + Using a value > 1.0 helps to produce coordinate values close to one. + ignore_threshold: If a predictor is not responsible for predicting any target, but the + corresponding anchor has IoU with some target greater than this threshold, the + predictor will not be taken into account when calculating the confidence loss. + coord_loss_multiplier: Multiply the coordinate/size loss by this factor. + class_loss_multiplier: Multiply the classification loss by this factor. + confidence_loss_multiplier: Multiply the confidence loss by this factor. """ super().__init__() @@ -98,7 +96,10 @@ def __init__(self, self.confidence_loss_multiplier = confidence_loss_multiplier self.se_loss = nn.MSELoss(reduction='none') - def forward(self, x, targets=None): + def forward(self, + x: Tensor, + targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Tensor, Dict[str, Tensor]]: """ Runs a forward pass through this YOLO detection layer. @@ -107,15 +108,14 @@ def forward(self, x, targets=None): probabilities to ]0, 1[ range using sigmoid. Args: - x (Tensor): The output from the previous layer. Tensor of size + x : The output from the previous layer. Tensor of size `[batch_size, boxes_per_cell * (num_classes + 5), height, width]`. - targets (List[Dict[str, Tensor]]): If set, computes losses from detection layers - against these targets. A list of dictionaries, one for each image. + targets: If set, computes losses from detection layers against these targets. A list of + dictionaries, one for each image. Returns: - result (Tuple[Tensor, Dict[str, Tensor]]): Layer output, and if training targets were - provided, a dictionary of losses. Layer output is sized - `[batch_size, num_anchors * height * width, num_classes + 5]`. + result: Layer output, and if training targets were provided, a dictionary of losses. + Layer output is sized `[batch_size, num_anchors * height * width, num_classes + 5]`. """ batch_size, num_features, height, width = x.shape num_attrs = self.num_classes + 5 @@ -413,8 +413,6 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, np_mask): class Mish(nn.Module): """Mish activation.""" - def __init__(self): - super().__init__() def forward(self, x): return x * torch.tanh(nn.functional.softplus(x)) @@ -425,12 +423,10 @@ class RouteLayer(nn.Module): def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int): """ - Creates a YOLO route layer. - Args: - source_layers (List[int]): Indices of the layers whose output will be concatenated. - num_chunks (int): Layer outputs will be split into this number of chunks. - chunk_idx (int): Only the chunks with this index will be concatenated. + source_layers: Indices of the layers whose output will be concatenated. + num_chunks: Layer outputs will be split into this number of chunks. + chunk_idx: Only the chunks with this index will be concatenated. """ super().__init__() self.source_layers = source_layers @@ -446,12 +442,11 @@ def forward(self, x, outputs): class ShortcutLayer(nn.Module): """Shortcut layer adds a residual connection from the source layer.""" - def __init__(self, source_layer): + def __init__(self, source_layer: int): """ - Constructs a YOLO shortcut layer. - Args: - num_classes (int): Number of different classes that this layer predicts. + source_layer: Index of the layer whose output will be added to the output of the + previous layer. """ super().__init__() self.source_layer = source_layer diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index a03dd6c08a..533be66fff 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -1,17 +1,19 @@ import inspect from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import List, Type, Union +from typing import Dict, List, Tuple, Type, Union import numpy as np import pytorch_lightning as pl import torch import torch.nn as nn from pytorch_lightning.utilities import argparse_utils -from torch import optim +from torch import optim, Tensor from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.utils.warnings import warn_missing_pkg +from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration +from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, Mish, RouteLayer, ShortcutLayer try: import torchvision.transforms as T @@ -23,11 +25,35 @@ else: _TORCHVISION_AVAILABLE = True -from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration -from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, Mish, RouteLayer, ShortcutLayer - class Yolo(pl.LightningModule): + """ + PyTorch Lightning implementation of `YOLOv3 `_ with some + improvements from `YOLOv4 `_. + + YOLOv3 paper authors: Joseph Redmon and Ali Farhadi + + YOLOv4 paper authors: Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao + + Model implemented by: + - `Seppo Enarvi `_ + + The network architecture is read from a configuration file in the same format as in the Darknet + implementation. Supports loading weights from a Darknet model file too, if you don't want to + start training from a randomly initialized model. During training, the model expects both the + images (list of tensors), as well as targets (list of dictionaries). + + The target dictionaries should contain: + - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. + - labels (`LongTensor[N]`): the class label for each ground truh box + + CLI command:: + + # PascalVOC + wget https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny.cfg + python yolo_module.py --config yolov4-tiny.cfg --data_dir . --gpus 8 --batch-size 8 + """ + def __init__(self, configuration: YoloConfiguration, optimizer: str = 'sgd', @@ -40,22 +66,21 @@ def __init__(self, confidence_threshold: float = 0.2, nms_threshold: float = 0.45): """ - Constructs a YOLO model. - Args: - config (YoloConfiguration): The model configuration file. - momentum (float): Momentum factor for SGD with momentum. - weight_decay (float): Weight decay (L2 penalty). - learning_rate (float): Learning rate after the warmup period. - warmup_epochs (int): Length of the learning rate warmup period in the beginning of + configuration: The model configuration. + optimizer: Which optimizer to use for training; either 'sgd' or 'adam'. + momentum: Momentum factor for SGD with momentum. + weight_decay: Weight decay (L2 penalty). + learning_rate: Learning rate after the warmup period. + warmup_epochs: Length of the learning rate warmup period in the beginning of training. During this number of epochs, the learning rate will be raised from `warmup_start_lr` to `learning_rate`. - warmup_start_lr (int): Learning rate in the beginning of the warmup period. - annealing_epochs (int): Length of the learning rate annealing period, during which the + warmup_start_lr: Learning rate in the beginning of the warmup period. + annealing_epochs: Length of the learning rate annealing period, during which the learning rate will go to zero. - confidence_threshold (float): Postprocessing will remove bounding boxes whose + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this threshold. - nms_threshold (float): Non-maximum suppression will remove bounding boxes whose IoU + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with the next best bounding box in that class is higher than this threshold. """ super().__init__() @@ -78,8 +103,10 @@ def __init__(self, self._create_modules() - def forward(self, images, targets=None): - # type: (List[Tensor]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + def forward(self, + images: Tensor, + targets: List[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: """ Runs a forward pass through the network (all layers listed in `self._module_list`), and if training targets are provided, computes the losses from the detection layers. @@ -92,13 +119,13 @@ def forward(self, images, targets=None): features. Args: - images (Tensor): Images to be processed. Tensor of size + images: Images to be processed. Tensor of size `[batch_size, num_channels, height, width]`. - targets (List[Dict[str, Tensor]]): If set, computes losses from detection layers - against these targets. A list of dictionaries, one for each image. + targets: If set, computes losses from detection layers against these targets. A list of + dictionaries, one for each image. Returns: - boxes (Tensor), confidences (Tensor), classprobs (Tensor), losses (Dict[str, Tensor]): + boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), classprobs (:class:`~torch.Tensor`), losses (Dict[str, :class:`~torch.Tensor`]): Detections, and if targets were provided, a dictionary of losses. The first dimension of the detections is the index of the image in the batch and the second dimension is the detection within the image. `boxes` contains the predicted @@ -140,7 +167,7 @@ def mean_loss(loss_name): else: return boxes, confidences, classprobs - def configure_optimizers(self): + def configure_optimizers(self) -> Tuple[List, List]: """Constructs the optimizer and learning rate scheduler.""" if self.optimizer == 'sgd': optimizer = optim.SGD( @@ -160,16 +187,17 @@ def configure_optimizers(self): warmup_start_lr=self.warmup_start_lr) return [optimizer], [lr_scheduler] - def training_step(self, batch, batch_idx): - # type: (Tuple[List[Tensor], List[Dict[str, Tensor]]]) -> Dict[str, Tensor] + def training_step(self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch_idx: int + ) -> Dict[str, Tensor]: """ Computes the training loss. Args: - batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): - A tuple of images and targets. Images is a list of 3-dimensional tensors. Targets - is a list of dictionaries that contain ground-truth boxes, labels, etc. - batch_idx (int): The index of this batch. + batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. + Targets is a list of dictionaries that contain ground-truth boxes, labels, etc. + batch_idx: The index of this batch. Returns: A dictionary that includes the training loss in 'loss'. @@ -184,15 +212,17 @@ def training_step(self, batch, batch_idx): return {'loss': total_loss} - def validation_step(self, batch, batch_idx): - # type: (Tuple[List[Tensor], List[Dict[str, Tensor]]], int) -> Dict[str, Tensor] + def validation_step(self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch_idx: int + ) -> Dict[str, Tensor]: """ Evaluates a batch of data from the validation set. Args: - batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): - The batch of data read by the :class:`~torch.utils.data.DataLoader` - batch_idx (int): The index of this batch + batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. + Targets is a list of dictionaries that contain ground-truth boxes, labels, etc. + batch_idx: The index of this batch """ images, targets = self._validate_batch(batch) boxes, confidences, classprobs, losses = self(images, targets) @@ -205,15 +235,17 @@ def validation_step(self, batch, batch_idx): self.log('val/{}_loss'.format(name), value) self.log('val/total_loss', total_loss) - def test_step(self, batch, batch_idx): - # type: (Tuple[List[Tensor], List[Dict[str, Tensor]]], int) -> Dict[str, Tensor] + def test_step(self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch_idx: int + ) -> Dict[str, Tensor]: """ Evaluates a batch of data from the test set. Args: - batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): - The batch of data read by the :class:`~torch.utils.data.DataLoader` - batch_idx (int): The index of this batch. + batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. + Targets is a list of dictionaries that contain ground-truth boxes, labels, etc. + batch_idx: The index of this batch. """ images, targets = self._validate_batch(batch) boxes, confidences, classprobs, losses = self(images, targets) @@ -226,20 +258,19 @@ def test_step(self, batch, batch_idx): self.log('test/{}_loss'.format(name), value) self.log('test/total_loss', total_loss) - def infer(self, image): - # type: (ndarray) -> Dict[str, Tensor] + def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ Resizes given image to the network input size and feeds it to the network. Returns the detected bounding boxes, confidences, and class labels. Args: - image (Tensor): - An input image, a tensor of uint8 values sized `[channels, height, width]`. + image: An input image, a tensor of uint8 values sized `[channels, height, width]`. Returns: - boxes (Tensor): A matrix of detected bounding box (x1, y1, x2, y2) coordinates. - confidences (Tensor): A vector of confidences for the bounding box detections. - labels (Tensor): A vector of predicted class labels. + boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), labels (:class:`~torch.Tensor`): + A matrix of detected bounding box (x1, y1, x2, y2) coordinates, a vector of + confidences for the bounding box detections, and a vector of predicted class + labels. """ network_input = image.float().div(255.0) network_input = network_input.unsqueeze(0) @@ -263,6 +294,9 @@ def infer(self, image): def load_darknet_weights(self, weight_file): """ Loads weights to layer modules from a pretrained Darknet model. + + Args: + weight_file: A file object containing model weights in the Darknet binary format. """ version = np.fromfile(weight_file, count=3, dtype=np.int32) images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) @@ -388,13 +422,19 @@ def _create_modules(self): self._module_list.append(module) layer_outputs.append(num_outputs) - def _validate_batch(self, batch): + def _validate_batch(self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] + ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: """ - Reads a batch of data and validates the format. + Reads a batch of data, validates the format, and stacks the images into a single tensor. Args: - batch (Tuple[List[Tensor], List[Dict[str, Tensor]]]): - The batch of data read by the :class:`~torch.utils.data.DataLoader` + batch (Tuple[List[:class:`~torch.Tensor`], List[Dict[str, :class:`~torch.Tensor`]]]): + The batch of data read by the :class:`~torch.utils.data.DataLoader`. + + Returns: + batch (Tuple[:class:`~torch.Tensor`, List[Dict[str, :class:`~torch.Tensor`]]]): + The input batch with images stacked into a single tensor. """ images, targets = batch @@ -403,7 +443,7 @@ def _validate_batch(self, batch): .format(len(images), len(targets))) for image in images: - if not isinstance(image, torch.Tensor): + if not isinstance(image, Tensor): raise ValueError("Expected image to be of type Tensor, got {}." .format(type(image))) expected_shape = torch.Size((self.config.channels, @@ -415,14 +455,14 @@ def _validate_batch(self, batch): for target in targets: boxes = target['boxes'] - if not isinstance(boxes, torch.Tensor): + if not isinstance(boxes, Tensor): raise ValueError("Expected target boxes to be of type Tensor, got {}." .format(type(boxes))) if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): raise ValueError("Expected target boxes to be tensors of shape [N, 4], got {}." .format(list(boxes.shape))) labels = target['labels'] - if not isinstance(labels, torch.Tensor): + if not isinstance(labels, Tensor): raise ValueError("Expected target labels to be of type Tensor, got {}." .format(type(labels))) if len(labels.shape) != 1: @@ -432,7 +472,12 @@ def _validate_batch(self, batch): images = torch.stack(images) return images, targets - def _filter_detections(self, boxes, confidences, classprobs, labels): + def _filter_detections(self, + boxes: Tensor, + confidences: Tensor, + classprobs: Tensor, + labels: Tensor + ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: """ Filters detections based on confidence threshold. Then for every class performs non-maximum suppression (NMS). NMS iterates the bounding boxes that predict this class in descending @@ -440,26 +485,17 @@ def _filter_detections(self, boxes, confidences, classprobs, labels): higher than the NMS threshold. Args: - boxes (Tensor): - Detected bounding box (x1, y1, x2, y2) coordinates in a tensor sized + boxes: Detected bounding box (x1, y1, x2, y2) coordinates in a tensor sized `[batch_size, N, 4]`. - confidences (Tensor): - Detection confidences in a tensor sized `[batch_size, N]`. - classprobs (Tensor): - Probabilities of the best classes in a tensor sized `[batch_size, N]`. - labels (Tensor): - Indices of the best classes in a tensor sized `[batch_size, N]`. + confidences: Detection confidences in a tensor sized `[batch_size, N]`. + classprobs: Probabilities of the best classes in a tensor sized `[batch_size, N]`. + labels: Indices of the best classes in a tensor sized `[batch_size, N]`. Returns: - boxes (List[Tensor]): - List of bounding box (x1, y1, x2, y2) coordinates, one tensor for each image. - confidences (List[Tensor]): - List of detection confidences, one tensor for each image. - classprobs (List[Tensor]): - List of tensors, one for each image, that contain the probabilities of the best - class of each prediction. - labels (List[Tensor]): - List of predicted class labels, one for each image. + boxes (List[:class:`~torch.Tensor`]), confidences (List[:class:`~torch.Tensor`]), classprobs (List[:class:`~torch.Tensor`]), labels (List[:class:`~torch.Tensor`]): + Four lists, each containing one tensor per image - bounding box (x1, y1, x2, y2) + coordinates, detection confidences, probabilities of the best class of each + prediction, and the predicted class labels. """ out_boxes = [] out_confidences = [] @@ -501,19 +537,6 @@ class of each prediction. return out_boxes, out_confidences, out_classprobs, out_labels - def _scale_boxes(self, images, boxes): - """Scales the box coordinates to image dimensions.""" - result = [] - - for image, img_boxes in zip(images, boxes): - height = image.shape[1] - width = image.shape[2] - scale = torch.tensor([width, height, width, height], device=img_boxes.device) - img_boxes = img_boxes * scale - result.append(img_boxes) - - return result - class Resize: """Rescales the image and target to given dimensions. From 876da0d4d39c448d597a7a1facaef8467d6e9ede Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 3 Feb 2021 16:24:55 +0200 Subject: [PATCH 04/61] Fixed style issues. --- pl_bolts/models/detection/yolo/yolo_layers.py | 7 +-- pl_bolts/models/detection/yolo/yolo_module.py | 48 +++++++++++-------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 930f31cea3..3f79a93b87 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -96,9 +96,10 @@ def __init__(self, self.confidence_loss_multiplier = confidence_loss_multiplier self.se_loss = nn.MSELoss(reduction='none') - def forward(self, - x: Tensor, - targets: Optional[List[Dict[str, Tensor]]] = None + def forward( + self, + x: Tensor, + targets: Optional[List[Dict[str, Tensor]]] = None ) -> Tuple[Tensor, Dict[str, Tensor]]: """ Runs a forward pass through this YOLO detection layer. diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 533be66fff..acded9fb17 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -103,9 +103,10 @@ def __init__(self, self._create_modules() - def forward(self, - images: Tensor, - targets: List[Dict[str, Tensor]] = None + def forward( + self, + images: Tensor, + targets: List[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: """ Runs a forward pass through the network (all layers listed in `self._module_list`), and if @@ -125,7 +126,7 @@ def forward(self, dictionaries, one for each image. Returns: - boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), classprobs (:class:`~torch.Tensor`), losses (Dict[str, :class:`~torch.Tensor`]): + boxes (Tensor), confidences (Tensor), classprobs (Tensor), losses (Dict[str, Tensor]): Detections, and if targets were provided, a dictionary of losses. The first dimension of the detections is the index of the image in the batch and the second dimension is the detection within the image. `boxes` contains the predicted @@ -187,9 +188,10 @@ def configure_optimizers(self) -> Tuple[List, List]: warmup_start_lr=self.warmup_start_lr) return [optimizer], [lr_scheduler] - def training_step(self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], - batch_idx: int + def training_step( + self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch_idx: int ) -> Dict[str, Tensor]: """ Computes the training loss. @@ -212,9 +214,10 @@ def training_step(self, return {'loss': total_loss} - def validation_step(self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], - batch_idx: int + def validation_step( + self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch_idx: int ) -> Dict[str, Tensor]: """ Evaluates a batch of data from the validation set. @@ -235,9 +238,10 @@ def validation_step(self, self.log('val/{}_loss'.format(name), value) self.log('val/total_loss', total_loss) - def test_step(self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], - batch_idx: int + def test_step( + self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch_idx: int ) -> Dict[str, Tensor]: """ Evaluates a batch of data from the test set. @@ -422,8 +426,9 @@ def _create_modules(self): self._module_list.append(module) layer_outputs.append(num_outputs) - def _validate_batch(self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] + def _validate_batch( + self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: """ Reads a batch of data, validates the format, and stacks the images into a single tensor. @@ -472,11 +477,12 @@ def _validate_batch(self, images = torch.stack(images) return images, targets - def _filter_detections(self, - boxes: Tensor, - confidences: Tensor, - classprobs: Tensor, - labels: Tensor + def _filter_detections( + self, + boxes: Tensor, + confidences: Tensor, + classprobs: Tensor, + labels: Tensor ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: """ Filters detections based on confidence threshold. Then for every class performs non-maximum @@ -492,7 +498,7 @@ def _filter_detections(self, labels: Indices of the best classes in a tensor sized `[batch_size, N]`. Returns: - boxes (List[:class:`~torch.Tensor`]), confidences (List[:class:`~torch.Tensor`]), classprobs (List[:class:`~torch.Tensor`]), labels (List[:class:`~torch.Tensor`]): + boxes (List[Tensor]), confidences (List[Tensor]), classprobs (List[Tensor]), labels (List[Tensor]): Four lists, each containing one tensor per image - bounding box (x1, y1, x2, y2) coordinates, detection confidences, probabilities of the best class of each prediction, and the predicted class labels. From f99930e99dc5acc933771095fc0d2408df4118c9 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 5 Feb 2021 12:46:04 +0200 Subject: [PATCH 05/61] Refactoring --- pl_bolts/models/detection/yolo/yolo_layers.py | 129 ++++++++++-------- pl_bolts/models/detection/yolo/yolo_module.py | 10 +- 2 files changed, 73 insertions(+), 66 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 3f79a93b87..8e829610a7 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -52,16 +52,19 @@ def __init__(self, anchor_ids: List[int], xy_scale: float = 1.0, ignore_threshold: float = 0.5, - coord_loss_multiplier: float = 1.0, + overlap_loss_func: Callable = None, + class_loss_func: Callable = None, + confidence_loss_func: Callable = None, + overlap_loss_multiplier: float = 1.0, class_loss_multiplier: float = 1.0, confidence_loss_multiplier: float = 1.0): """ Args: num_classes: Number of different classes that this layer predicts. - image_width: Image width (defines the scale of the anchor box and target bounding - box dimensions). - image_height: Image height (defines the scale of the anchor box and target - bounding box dimensions). + image_width: Image width (defines the scale of the anchor box and target bounding box + dimensions). + image_height: Image height (defines the scale of the anchor box and target bounding box + dimensions). anchor_dims: A list of all the predefined anchor box dimensions. The list should contain (width, height) tuples in the network input resolution (relative to the width and height defined in the configuration file). @@ -72,6 +75,12 @@ def __init__(self, ignore_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. + overlap_loss_func: Loss function for (x, y, w, h) coordinates. Default is the sum of + squared errors. + class_loss_func: Loss function for class probability distribution. Default is the sum + of squared errors. + confidence_loss_func: Loss function for confidence score. Default is the sum of squared + errors. coord_loss_multiplier: Multiply the coordinate/size loss by this factor. class_loss_multiplier: Multiply the classification loss by this factor. confidence_loss_multiplier: Multiply the confidence loss by this factor. @@ -91,10 +100,15 @@ def __init__(self, self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(9)] self.xy_scale = xy_scale self.ignore_threshold = ignore_threshold - self.coord_loss_multiplier = coord_loss_multiplier + + self.overlap_loss_multiplier = overlap_loss_multiplier self.class_loss_multiplier = class_loss_multiplier self.confidence_loss_multiplier = confidence_loss_multiplier - self.se_loss = nn.MSELoss(reduction='none') + + se_loss = nn.MSELoss(reduction='none') + self.overlap_loss_func = overlap_loss_func or se_loss + self.class_loss_func = class_loss_func or se_loss + self.confidence_loss_func = confidence_loss_func or se_loss def forward( self, @@ -109,7 +123,7 @@ def forward( probabilities to ]0, 1[ range using sigmoid. Args: - x : The output from the previous layer. Tensor of size + x: The output from the previous layer. Tensor of size `[batch_size, boxes_per_cell * (num_classes + 5), height, width]`. targets: If set, computes losses from detection layers against these targets. A list of dictionaries, one for each image. @@ -154,8 +168,8 @@ def forward( if targets is None: return output else: - np_mask = self._no_prediction_mask(corners, targets) - losses = self._calculate_losses(xy, wh, confidence, classprob, targets, np_mask) + lc_mask = self._low_confidence_mask(corners, targets) + losses = self._calculate_losses(xy, wh, confidence, classprob, targets, lc_mask) return output, losses def _global_xy(self, xy): @@ -220,14 +234,14 @@ def _corner_coordinates(self, xy, wh): bottom_right = xy + half_wh return torch.cat((top_left, bottom_right), -1) - def _no_prediction_mask(self, preds, targets): + def _low_confidence_mask(self, boxes, targets): """ - Initializes the mask that will be used to select predictors that are not responsible for - predicting any target. The value will be `True`, unless the predicted box overlaps any - target significantly (IoU greater than `self.ignore_threshold`). + Initializes the mask that will be used to select predictors that are not predicting any + ground-truth target. The value will be `True`, unless the predicted box overlaps any target + significantly (IoU greater than `self.ignore_threshold`). Args: - preds (Tensor): The predicted corner coordinates, normalized to the [0, 1] range. + boxes (Tensor): The predicted corner coordinates, normalized to the [0, 1] range. Tensor of size `[batch_size, height, width, boxes_per_cell, 4]`. targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one dictionary for each image. @@ -237,27 +251,28 @@ def _no_prediction_mask(self, preds, targets): with `False` where the predicted box overlaps a target significantly and `True` elsewhere. """ - shape = preds.shape - preds = preds.view(shape[0], -1, shape[-1]) + batch_size, height, width, boxes_per_cell, num_coords = boxes.shape + num_preds = height * width * boxes_per_cell + boxes = boxes.view(batch_size, num_preds, num_coords) scale = torch.tensor([self.image_width, self.image_height, self.image_width, self.image_height], - device=preds.device) - preds = preds * scale + device=boxes.device) + boxes = boxes * scale - results = torch.ones(preds.shape[:-1], dtype=torch.bool, device=preds.device) - for image_idx, (image_preds, image_targets) in enumerate(zip(preds, targets)): + results = torch.ones((batch_size, num_preds), dtype=torch.bool, device=boxes.device) + for image_idx, (image_boxes, image_targets) in enumerate(zip(boxes, targets)): target_boxes = image_targets['boxes'] if target_boxes.shape[0] > 0: - ious = box_iou(image_preds, target_boxes) - best_ious = ious.max(-1).values - results[image_idx] = best_ious <= self.ignore_threshold - results = results.view(shape[:-1]) - return results + ious = box_iou(image_boxes, target_boxes) # [num_preds, num_targets] + best_iou = ious.max(-1).values # [num_preds] + results[image_idx] = best_iou <= self.ignore_threshold + + return results.view((batch_size, height, width, boxes_per_cell)) - def _calculate_losses(self, xy, wh, confidence, classprob, targets, np_mask): + def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): """ From the targets that are in the image space calculates the actual targets for the network predictions, and returns a dictionary of training losses. @@ -273,8 +288,8 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, np_mask): sized `[batch_size, height, width, boxes_per_cell, num_classes]`. targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one dictionary for each image. - np_mask: A boolean mask containing `True` where the predicted box does not overlap any - target significantly. + lc_mask (Tensor): A boolean mask containing `True` where the predicted box does not + overlap any target significantly. Returns: predicted (Dict[str, Tensor]): A dictionary of training losses. @@ -341,9 +356,10 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, np_mask): predictors = predictors[selected] best_anchors = best_anchors[selected] - # The "no-prediction" mask is used to select predictors that are not responsible for - # predicting any object for calculating the confidence loss. - np_mask[image_idx, cell_j, cell_i, predictors] = False + # The "low-confidence" mask is used to select predictors that are not responsible for + # predicting any object, for calculating the part of the confidence loss with zero as + # the target confidence. + lc_mask[image_idx, cell_j, cell_i, predictors] = False # Bounding box targets relative_xy = box_xy - box_xy.floor() @@ -372,42 +388,35 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, np_mask): if pred_xy and pred_wh and target_xy and target_wh: size_compensation = torch.cat(size_compensation).unsqueeze(1) - pred_xy = torch.cat(pred_xy) - target_xy = torch.cat(target_xy) - location_loss = self.se_loss(pred_xy, target_xy) - location_loss = location_loss * size_compensation - location_loss = location_loss.sum() / batch_size - losses['location'] = location_loss * self.coord_loss_multiplier - - pred_wh = torch.cat(pred_wh) - target_wh = torch.cat(target_wh) - size_loss = self.se_loss(pred_wh, target_wh) - size_loss = size_loss * size_compensation - size_loss = size_loss.sum() / batch_size - losses['size'] = size_loss * self.coord_loss_multiplier - - class_loss = None + pred_xywh = torch.cat((torch.cat(pred_xy), torch.cat(pred_wh)), -1) + target_xywh = torch.cat((torch.cat(target_xy), torch.cat(target_wh)), -1) + overlap_loss = self.overlap_loss_func(pred_xywh, target_xywh) + overlap_loss = overlap_loss * size_compensation + overlap_loss = overlap_loss.sum() / batch_size + losses['overlap'] = overlap_loss * self.overlap_loss_multiplier + if pred_classprob and target_label: pred_classprob = torch.cat(pred_classprob) target_label = torch.cat(target_label) target_classprob = torch.nn.functional.one_hot(target_label, self.num_classes) target_classprob = target_classprob.to(dtype=pred_classprob.dtype) - class_loss = self.se_loss(pred_classprob, target_classprob) + class_loss = self.class_loss_func(pred_classprob, target_classprob) class_loss = class_loss.sum() / batch_size losses['class'] = class_loss * self.class_loss_multiplier - np_confidence = confidence[np_mask] - np_target_confidence = torch.zeros_like(np_confidence) - np_confidence_loss = self.se_loss(np_confidence, np_target_confidence) - np_confidence_loss = np_confidence_loss.sum() / batch_size - losses['np_confidence'] = np_confidence_loss * self.confidence_loss_multiplier - + pred_low_confidence = confidence[lc_mask] + target_low_confidence = torch.zeros_like(pred_low_confidence) if pred_confidence: - p_confidence = torch.cat(pred_confidence) - p_target_confidence = torch.ones_like(p_confidence) - p_confidence_loss = self.se_loss(p_confidence, p_target_confidence) - p_confidence_loss = p_confidence_loss.sum() / batch_size - losses['p_confidence'] = p_confidence_loss * self.confidence_loss_multiplier + pred_high_confidence = torch.cat(pred_confidence) + target_high_confidence = torch.ones_like(pred_high_confidence) + pred_confidence = torch.cat((pred_low_confidence, pred_high_confidence)) + target_confidence = torch.cat((target_low_confidence, target_high_confidence)) + else: + pred_confidence = pred_low_confidence + target_confidence = target_low_confidence + confidence_loss = self.confidence_loss_func(pred_confidence, target_confidence) + confidence_loss = confidence_loss.sum() / batch_size + losses['confidence'] = confidence_loss * self.confidence_loss_multiplier return losses diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index acded9fb17..0924125260 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -403,7 +403,7 @@ def _create_modules(self): xy_scale = config.get('scale_x_y', 1.0) ignore_threshold = config.get('ignore_thresh', 1.0) - coord_loss_multiplier = config.get('iou_normalizer', 1.0) + overlap_loss_multiplier = config.get('iou_normalizer', 1.0) class_loss_multiplier = config.get('cls_normalizer', 1.0) confidence_loss_multiplier = config.get('obj_normalizer', 1.0) @@ -415,7 +415,7 @@ def _create_modules(self): anchor_ids=config['mask'], xy_scale=xy_scale, ignore_threshold=ignore_threshold, - coord_loss_multiplier=coord_loss_multiplier, + overlap_loss_multiplier=overlap_loss_multiplier, class_loss_multiplier=class_loss_multiplier, confidence_loss_multiplier=confidence_loss_multiplier) @@ -434,12 +434,10 @@ def _validate_batch( Reads a batch of data, validates the format, and stacks the images into a single tensor. Args: - batch (Tuple[List[:class:`~torch.Tensor`], List[Dict[str, :class:`~torch.Tensor`]]]): - The batch of data read by the :class:`~torch.utils.data.DataLoader`. + batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. Returns: - batch (Tuple[:class:`~torch.Tensor`, List[Dict[str, :class:`~torch.Tensor`]]]): - The input batch with images stacked into a single tensor. + batch: The input batch with images stacked into a single tensor. """ images, targets = batch From 4415d419e726592c24320574fdb762b90a50669e Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 8 Feb 2021 16:07:28 +0200 Subject: [PATCH 06/61] Refactoring --- pl_bolts/models/detection/yolo/yolo_config.py | 143 +++++++++++++++++- pl_bolts/models/detection/yolo/yolo_layers.py | 10 +- pl_bolts/models/detection/yolo/yolo_module.py | 118 ++------------- 3 files changed, 158 insertions(+), 113 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 804f486305..af50c62245 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -2,18 +2,25 @@ from warnings import warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +import torch.nn as nn + +from pl_bolts.models.detection.yolo.yolo_layers import * class YoloConfiguration: - """Parser for YOLOv4 network configuration files.""" + """ + This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. + The `get_network()` method returns a PyTorch module list that can be used to construct a YOLO + model. + """ def __init__(self, path: str): """ Saves the variables from the first configuration section to attributes of this object, and - the rest of the sections to the `modules` list. + the rest of the sections to the `layer_configs` list. Args: - path: configuration file to read + path: Path to a configuration file """ with open(path, 'r') as config_file: sections = self._read_file(config_file) @@ -23,7 +30,25 @@ def __init__(self, path: str): "The model configuration file should include at least two sections.") self.__dict__.update(sections[0]) - self.modules = sections[1:] + self.global_config = sections[0] + self.layer_configs = sections[1:] + + def get_network(self) -> nn.ModuleList: + """ + Iterates through the layers from the configuration and creates corresponding PyTorch + modules. Returns the network structure that can be used to create a YOLO model. + + Returns: + modules: A `nn.ModuleList` that defines the YOLO network. + """ + result = nn.ModuleList() + num_inputs = [3] # Number of channels in the input of every layer up to the current layer + for layer_config in self.layer_configs: + config = {**self.global_config, **layer_config} + module, num_outputs = _create_layer(config, num_inputs) + result.append(module) + num_inputs.append(num_outputs) + return result def _read_file(self, config_file): """ @@ -33,7 +58,7 @@ def _read_file(self, config_file): config_file (iterable over lines): The configuration file to read. Returns: - sections (list): A list of configuration sections. + sections (List[dict]): A list of configuration sections. """ section_re = re.compile(r'\[([^]]+)\]') list_variables = ('layers', 'anchors', 'mask', 'scales') @@ -120,3 +145,111 @@ def convert(key, value): sections.append(section) return sections + + +def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: + """ + Calls one of the `_create_(config, num_inputs)` functions to create a PyTorch + module from the layer config. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the + number of channels in its output. + """ + create_func = { + 'convolutional': _create_convolutional, + 'maxpool': _create_maxpool, + 'route': _create_route, + 'shortcut': _create_shortcut, + 'upsample': _create_upsample, + 'yolo': _create_yolo + } + return create_func[config['type']](config, num_inputs) + +def _create_convolutional(config, num_inputs): + module = nn.Sequential() + + batch_normalize = config.get('batch_normalize', False) + padding = (config['size'] - 1) // 2 if config['pad'] else 0 + + conv = nn.Conv2d( + num_inputs[-1], + config['filters'], + config['size'], + config['stride'], + padding, + bias=not batch_normalize) + module.add_module('conv', conv) + + if batch_normalize: + bn = nn.BatchNorm2d(config['filters']) + module.add_module('bn', bn) + + if config['activation'] == 'leaky': + leakyrelu = nn.LeakyReLU(0.1, inplace=True) + module.add_module('leakyrelu', leakyrelu) + elif config['activation'] == 'mish': + mish = Mish() + module.add_module('mish', mish) + + return module, config['filters'] + +def _create_maxpool(config, num_inputs): + padding = (config['size'] - 1) // 2 + module = nn.MaxPool2d(config['size'], config['stride'], padding) + return module, num_inputs[-1] + +def _create_route(config, num_inputs): + num_chunks = config.get('groups', 1) + chunk_idx = config.get('group_id', 0) + + # 0 is the first layer, -1 is the previous layer + last = len(num_inputs) - 1 + source_layers = [layer if layer >= 0 else last + layer + for layer in config['layers']] + + module = RouteLayer(source_layers, num_chunks, chunk_idx) + + # The number of outputs of a source layer is the number of inputs of the next layer. + num_outputs = sum(num_inputs[layer + 1] // num_chunks + for layer in source_layers) + + return module, num_outputs + +def _create_shortcut(config, num_inputs): + module = ShortcutLayer(config['from']) + return module, num_inputs[-1] + +def _create_upsample(config, num_inputs): + module = nn.Upsample(scale_factor=config["stride"], mode='nearest') + return module, num_inputs[-1] + +def _create_yolo(config, num_inputs): + # The "anchors" list alternates width and height. + anchor_dims = config['anchors'] + anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) + for i in range(0, len(anchor_dims), 2)] + + xy_scale = config.get('scale_x_y', 1.0) + ignore_threshold = config.get('ignore_thresh', 1.0) + overlap_loss_multiplier = config.get('iou_normalizer', 1.0) + class_loss_multiplier = config.get('cls_normalizer', 1.0) + confidence_loss_multiplier = config.get('obj_normalizer', 1.0) + + module = DetectionLayer( + num_classes=config['classes'], + image_width=config['width'], + image_height=config['height'], + anchor_dims=anchor_dims, + anchor_ids=config['mask'], + xy_scale=xy_scale, + ignore_threshold=ignore_threshold, + overlap_loss_multiplier=overlap_loss_multiplier, + class_loss_multiplier=class_loss_multiplier, + confidence_loss_multiplier=confidence_loss_multiplier) + + return module, num_inputs[-1] diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 8e829610a7..7c823f4f31 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -20,7 +20,7 @@ def _aligned_iou(dims1, dims2): Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at the same coordinates. - Arguments: + Args: dims1 (Tensor[N, 2]): width and height of N boxes dims2 (Tensor[M, 2]): width and height of M boxes @@ -167,10 +167,10 @@ def forward( if targets is None: return output - else: - lc_mask = self._low_confidence_mask(corners, targets) - losses = self._calculate_losses(xy, wh, confidence, classprob, targets, lc_mask) - return output, losses + + lc_mask = self._low_confidence_mask(corners, targets) + losses = self._calculate_losses(xy, wh, confidence, classprob, targets, lc_mask) + return output, losses def _global_xy(self, xy): """ diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 0924125260..f9308e46f8 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -13,7 +13,7 @@ from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.utils.warnings import warn_missing_pkg from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration -from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, Mish, RouteLayer, ShortcutLayer +from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer try: import torchvision.transforms as T @@ -38,10 +38,12 @@ class Yolo(pl.LightningModule): Model implemented by: - `Seppo Enarvi `_ - The network architecture is read from a configuration file in the same format as in the Darknet - implementation. Supports loading weights from a Darknet model file too, if you don't want to - start training from a randomly initialized model. During training, the model expects both the - images (list of tensors), as well as targets (list of dictionaries). + The network architecture can be read from a Darknet configuration file using the + :class:`~pl_bolts.models.detection.yolo.yolo_config.YoloConfiguration` class, or created by + some other means, and provided as a list of PyTorch modules. Supports loading weights from a + Darknet model file too, if you don't want to start training from a randomly initialized model. + During training, the model expects both the images (list of tensors), as well as targets (list + of dictionaries). The target dictionaries should contain: - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. @@ -55,7 +57,7 @@ class Yolo(pl.LightningModule): """ def __init__(self, - configuration: YoloConfiguration, + network: nn.ModuleList, optimizer: str = 'sgd', momentum: float = 0.9, weight_decay: float = 0.0005, @@ -67,7 +69,8 @@ def __init__(self, nms_threshold: float = 0.45): """ Args: - configuration: The model configuration. + network: A list of network modules. This can be obtained from a Darknet configuration + using the `YoloConfiguration.get_network()` method. optimizer: Which optimizer to use for training; either 'sgd' or 'adam'. momentum: Momentum factor for SGD with momentum. weight_decay: Weight decay (L2 penalty). @@ -90,7 +93,7 @@ def __init__(self, 'YOLO model uses `torchvision`, which is not installed yet.' ) - self.config = configuration + self.network = network self.optimizer = optimizer self.momentum = momentum self.weight_decay = weight_decay @@ -101,15 +104,13 @@ def __init__(self, self.confidence_threshold = confidence_threshold self.nms_threshold = nms_threshold - self._create_modules() - def forward( self, images: Tensor, targets: List[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: """ - Runs a forward pass through the network (all layers listed in `self._module_list`), and if + Runs a forward pass through the network (all layers listed in `self.network`), and if training targets are provided, computes the losses from the detection layers. Detections are concatenated from the detection layers. Each image will produce @@ -137,7 +138,7 @@ def forward( losses = [] # Losses from detection layers x = images - for module in self._module_list: + for module in self.network: if isinstance(module, RouteLayer) or isinstance(module, ShortcutLayer): x = module(x, outputs) elif isinstance(module, DetectionLayer): @@ -313,7 +314,7 @@ def read(tensor): with torch.no_grad(): tensor.copy_(x) - for module in self._module_list: + for module in self.network: # Weights are loaded only to convolutional layers if not isinstance(module, nn.Sequential): continue @@ -343,89 +344,6 @@ def get_deprecated_arg_names(cls) -> List: depr_arg_names.extend(val) return depr_arg_names - def _create_modules(self): - """ - Creates a list of network modules based on parsed configuration file. - """ - self._module_list = nn.ModuleList() - num_outputs = 3 # Number of channels in the previous layer output - layer_outputs = [] # Number of channels in the output of every layer - - # Iterate through the modules from the configuration and generate required components. - for index, config in enumerate(self.config.modules): - if config['type'] == 'convolutional': - module = nn.Sequential() - - batch_normalize = config.get('batch_normalize', False) - padding = (config['size'] - 1) // 2 if config['pad'] else 0 - - conv = nn.Conv2d( - num_outputs, - config['filters'], - config['size'], - config['stride'], - padding, - bias=not batch_normalize) - module.add_module("conv_{0}".format(index), conv) - num_outputs = config['filters'] - - if batch_normalize: - bn = nn.BatchNorm2d(config['filters']) - module.add_module("batch_norm_{0}".format(index), bn) - - if config['activation'] == 'leaky': - leakyrelu = nn.LeakyReLU(0.1, inplace=True) - module.add_module('leakyrelu_{0}'.format(index), leakyrelu) - elif config['activation'] == 'mish': - mish = Mish() - module.add_module("mish_{0}".format(index), mish) - - elif config['type'] == 'upsample': - module = nn.Upsample(scale_factor=config["stride"], mode='nearest') - - elif config['type'] == 'route': - num_chunks = config.get('groups', 1) - chunk_idx = config.get('group_id', 0) - source_layers = [layer if layer >= 0 else index + layer - for layer in config['layers']] - module = RouteLayer(source_layers, num_chunks, chunk_idx) - num_outputs = sum(layer_outputs[layer] // num_chunks - for layer in source_layers) - - elif config['type'] == 'shortcut': - module = ShortcutLayer(config['from']) - - elif config['type'] == 'yolo': - # The "anchors" list alternates width and height. - anchor_dims = config['anchors'] - anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) - for i in range(0, len(anchor_dims), 2)] - - xy_scale = config.get('scale_x_y', 1.0) - ignore_threshold = config.get('ignore_thresh', 1.0) - overlap_loss_multiplier = config.get('iou_normalizer', 1.0) - class_loss_multiplier = config.get('cls_normalizer', 1.0) - confidence_loss_multiplier = config.get('obj_normalizer', 1.0) - - module = DetectionLayer( - num_classes=config['classes'], - image_width=self.config.width, - image_height=self.config.height, - anchor_dims=anchor_dims, - anchor_ids=config['mask'], - xy_scale=xy_scale, - ignore_threshold=ignore_threshold, - overlap_loss_multiplier=overlap_loss_multiplier, - class_loss_multiplier=class_loss_multiplier, - confidence_loss_multiplier=confidence_loss_multiplier) - - elif config['type'] == 'maxpool': - padding = (config['size'] - 1) // 2 - module = nn.MaxPool2d(config['size'], config['stride'], padding) - - self._module_list.append(module) - layer_outputs.append(num_outputs) - def _validate_batch( self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] @@ -449,12 +367,6 @@ def _validate_batch( if not isinstance(image, Tensor): raise ValueError("Expected image to be of type Tensor, got {}." .format(type(image))) - expected_shape = torch.Size((self.config.channels, - self.config.height, - self.config.width)) - if image.shape != expected_shape: - raise ValueError("Expected images to be tensors of shape {}, got {}." - .format(list(expected_shape), list(image.shape))) for target in targets: boxes = target['boxes'] @@ -595,7 +507,7 @@ def run_cli(): params = vars(args) valid_kwargs = inspect.signature(Yolo.__init__).parameters kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) - model = Yolo(configuration=config, **kwargs) + model = Yolo(network=config.get_network(), **kwargs) if args.darknet_weights is not None: with open(args.darknet_weights, 'r') as weight_file: model.load_darknet_weights(weight_file) From 7356fe06984f81ca931d262d8ffc3f5cc9e68f35 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 8 Feb 2021 16:22:00 +0200 Subject: [PATCH 07/61] Refactoring --- pl_bolts/models/detection/__init__.py | 2 +- pl_bolts/models/detection/yolo/yolo_config.py | 2 +- pl_bolts/models/detection/yolo/yolo_module.py | 15 +++++++-------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 367ae444c1..2ce43d0711 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,6 +1,6 @@ from pl_bolts.models.detection import components # noqa: F401 from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -from pl_bolts.models.detection.yolo import YoloConfiguration, Yolo # noqa: F401 +from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration # noqa: F401 __all__ = [ "components", diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index af50c62245..27269448eb 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -1,8 +1,8 @@ import re from warnings import warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException import torch.nn as nn +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pl_bolts.models.detection.yolo.yolo_layers import * diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index f9308e46f8..1e15bb7233 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -7,13 +7,12 @@ import pytorch_lightning as pl import torch import torch.nn as nn -from pytorch_lightning.utilities import argparse_utils from torch import optim, Tensor -from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR -from pl_bolts.utils.warnings import warn_missing_pkg from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pl_bolts.utils.warnings import warn_missing_pkg try: import torchvision.transforms as T @@ -139,7 +138,7 @@ def forward( x = images for module in self.network: - if isinstance(module, RouteLayer) or isinstance(module, ShortcutLayer): + if isinstance(module, (RouteLayer, ShortcutLayer)): x = module(x, outputs) elif isinstance(module, DetectionLayer): if targets is None: @@ -163,12 +162,12 @@ def mean_loss(loss_name): confidences = detections[..., 4] classprobs = detections[..., 5:] - if targets is not None: - losses = {loss_name: mean_loss(loss_name) for loss_name in losses[0].keys()} - return boxes, confidences, classprobs, losses - else: + if targets is None: return boxes, confidences, classprobs + losses = {loss_name: mean_loss(loss_name) for loss_name in losses[0].keys()} + return boxes, confidences, classprobs, losses + def configure_optimizers(self) -> Tuple[List, List]: """Constructs the optimizer and learning rate scheduler.""" if self.optimizer == 'sgd': From 39eb80d4009cf1e35d368e156a3ef0aa2371b620 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 8 Feb 2021 18:21:03 +0200 Subject: [PATCH 08/61] Fixed YOLO test. --- tests/models/test_detection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 137c7fec49..f85f31402a 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -128,7 +128,7 @@ def test_yolo(tmpdir): config_path = Path(tmpdir) / 'yolo.cfg' _create_yolo_config_file(config_path) config = YoloConfiguration(config_path) - model = Yolo(config) + model = Yolo(config.get_network()) image = torch.rand(1, 3, 256, 256) model(image) @@ -138,7 +138,7 @@ def test_yolo_train(tmpdir): config_path = Path(tmpdir) / 'yolo.cfg' _create_yolo_config_file(config_path) config = YoloConfiguration(config_path) - model = Yolo(config) + model = Yolo(config.get_network()) train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) From 291f4be2e733e3ace66f383ea5172f1bd726766c Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 12:16:24 +0200 Subject: [PATCH 09/61] Fixedd style issues --- .../datamodules/vocdetection_datamodule.py | 2 +- pl_bolts/models/detection/yolo/yolo_config.py | 21 ++++++++++++------- pl_bolts/models/detection/yolo/yolo_module.py | 7 ++----- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 863156e086..b34ba48da3 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from pytorch_lightning import LightningDataModule diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 27269448eb..4128653b33 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -1,10 +1,11 @@ import re +from typing import List, Tuple from warnings import warn import torch.nn as nn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pl_bolts.models.detection.yolo.yolo_layers import * +import pl_bolts.models.detection.yolo.yolo_layers as yolo class YoloConfiguration: @@ -170,6 +171,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: } return create_func[config['type']](config, num_inputs) + def _create_convolutional(config, num_inputs): module = nn.Sequential() @@ -193,16 +195,18 @@ def _create_convolutional(config, num_inputs): leakyrelu = nn.LeakyReLU(0.1, inplace=True) module.add_module('leakyrelu', leakyrelu) elif config['activation'] == 'mish': - mish = Mish() + mish = yolo.Mish() module.add_module('mish', mish) return module, config['filters'] + def _create_maxpool(config, num_inputs): padding = (config['size'] - 1) // 2 module = nn.MaxPool2d(config['size'], config['stride'], padding) return module, num_inputs[-1] + def _create_route(config, num_inputs): num_chunks = config.get('groups', 1) chunk_idx = config.get('group_id', 0) @@ -212,27 +216,30 @@ def _create_route(config, num_inputs): source_layers = [layer if layer >= 0 else last + layer for layer in config['layers']] - module = RouteLayer(source_layers, num_chunks, chunk_idx) + module = yolo.RouteLayer(source_layers, num_chunks, chunk_idx) # The number of outputs of a source layer is the number of inputs of the next layer. num_outputs = sum(num_inputs[layer + 1] // num_chunks - for layer in source_layers) + for layer in source_layers) return module, num_outputs + def _create_shortcut(config, num_inputs): - module = ShortcutLayer(config['from']) + module = yolo.ShortcutLayer(config['from']) return module, num_inputs[-1] + def _create_upsample(config, num_inputs): module = nn.Upsample(scale_factor=config["stride"], mode='nearest') return module, num_inputs[-1] + def _create_yolo(config, num_inputs): # The "anchors" list alternates width and height. anchor_dims = config['anchors'] anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) - for i in range(0, len(anchor_dims), 2)] + for i in range(0, len(anchor_dims), 2)] xy_scale = config.get('scale_x_y', 1.0) ignore_threshold = config.get('ignore_thresh', 1.0) @@ -240,7 +247,7 @@ def _create_yolo(config, num_inputs): class_loss_multiplier = config.get('cls_normalizer', 1.0) confidence_loss_multiplier = config.get('obj_normalizer', 1.0) - module = DetectionLayer( + module = yolo.DetectionLayer( num_classes=config['classes'], image_width=config['width'], image_height=config['height'], diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 1e15bb7233..ce1b9e387f 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -1,7 +1,5 @@ import inspect -from argparse import ArgumentParser, Namespace -from pathlib import Path -from typing import Dict, List, Tuple, Type, Union +from typing import Dict, List, Tuple import numpy as np import pytorch_lightning as pl @@ -480,10 +478,9 @@ def __call__(self, image, target): def run_cli(): + from argparse import ArgumentParser from pytorch_lightning.utilities import argparse_utils - from pl_bolts.datamodules import VOCDetectionDataModule - from pl_bolts.datamodules.vocdetection_datamodule import Compose pl.seed_everything(42) From 8db794740c927db6d56b915d8429934473b32d59 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 12:17:48 +0200 Subject: [PATCH 10/61] Comply to isort rules. --- pl_bolts/models/detection/yolo/yolo_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index ce1b9e387f..348fb718db 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -479,7 +479,9 @@ def __call__(self, image, target): def run_cli(): from argparse import ArgumentParser + from pytorch_lightning.utilities import argparse_utils + from pl_bolts.datamodules import VOCDetectionDataModule pl.seed_everything(42) From 2831755eae1e9e7d06ec1d260a7b859182f2a9a6 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 13:53:00 +0200 Subject: [PATCH 11/61] Reading Darknet weights works also with truncated files. --- pl_bolts/models/detection/yolo/yolo_module.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 348fb718db..019e130934 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -297,6 +297,13 @@ def load_darknet_weights(self, weight_file): """ Loads weights to layer modules from a pretrained Darknet model. + One may want to continue training from the pretrained weights, on a dataset with a + different number of object categories. The number of kernels in the convolutional layers + just before each detection layer depends on the number of output classes. The Darknet + solution is to truncate the weight file and stop reading weights at the first incompatible + layer. For this reason the function silently leaves the rest of the layers unchanged, when + the weight file ends. + Args: weight_file: A file object containing model weights in the Darknet binary format. """ @@ -306,7 +313,13 @@ def load_darknet_weights(self, weight_file): 'images.'.format(version[0], version[1], version[2], images_seen[0])) def read(tensor): + """ + Reads the contents of `tensor` from the current position of `weight_file`. + If there's no more data in `weight_file`, returns without error. + """ x = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) + if x.shape[0] == 0: + return x = torch.from_numpy(x).view_as(tensor) with torch.no_grad(): tensor.copy_(x) From eb26eba71c7408c6cb7e9092beb589042432ef17 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 15:01:37 +0200 Subject: [PATCH 12/61] Fixed code formatting. --- pl_bolts/models/detection/__init__.py | 10 +- pl_bolts/models/detection/yolo/yolo_config.py | 23 +-- pl_bolts/models/detection/yolo/yolo_layers.py | 54 +++---- pl_bolts/models/detection/yolo/yolo_module.py | 142 ++++++++---------- tests/models/test_detection.py | 19 +-- 5 files changed, 103 insertions(+), 145 deletions(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 2ce43d0711..17862f170e 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,10 +1,6 @@ from pl_bolts.models.detection import components # noqa: F401 from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration # noqa: F401 +from pl_bolts.models.detection.yolo import Yolo # noqa: F401 +from pl_bolts.models.detection.yolo import YoloConfiguration # noqa: F401 -__all__ = [ - "components", - "FasterRCNN", - "YoloConfiguration", - "Yolo" -] +__all__ = ["components", "FasterRCNN", "YoloConfiguration", "Yolo"] diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 4128653b33..10226815dd 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -27,8 +27,7 @@ def __init__(self, path: str): sections = self._read_file(config_file) if len(sections) < 2: - raise MisconfigurationException( - "The model configuration file should include at least two sections.") + raise MisconfigurationException("The model configuration file should include at least two sections.") self.__dict__.update(sections[0]) self.global_config = sections[0] @@ -179,12 +178,8 @@ def _create_convolutional(config, num_inputs): padding = (config['size'] - 1) // 2 if config['pad'] else 0 conv = nn.Conv2d( - num_inputs[-1], - config['filters'], - config['size'], - config['stride'], - padding, - bias=not batch_normalize) + num_inputs[-1], config['filters'], config['size'], config['stride'], padding, bias=not batch_normalize + ) module.add_module('conv', conv) if batch_normalize: @@ -213,14 +208,12 @@ def _create_route(config, num_inputs): # 0 is the first layer, -1 is the previous layer last = len(num_inputs) - 1 - source_layers = [layer if layer >= 0 else last + layer - for layer in config['layers']] + source_layers = [layer if layer >= 0 else last + layer for layer in config['layers']] module = yolo.RouteLayer(source_layers, num_chunks, chunk_idx) # The number of outputs of a source layer is the number of inputs of the next layer. - num_outputs = sum(num_inputs[layer + 1] // num_chunks - for layer in source_layers) + num_outputs = sum(num_inputs[layer + 1] // num_chunks for layer in source_layers) return module, num_outputs @@ -238,8 +231,7 @@ def _create_upsample(config, num_inputs): def _create_yolo(config, num_inputs): # The "anchors" list alternates width and height. anchor_dims = config['anchors'] - anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) - for i in range(0, len(anchor_dims), 2)] + anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)] xy_scale = config.get('scale_x_y', 1.0) ignore_threshold = config.get('ignore_thresh', 1.0) @@ -257,6 +249,7 @@ def _create_yolo(config, num_inputs): ignore_threshold=ignore_threshold, overlap_loss_multiplier=overlap_loss_multiplier, class_loss_multiplier=class_loss_multiplier, - confidence_loss_multiplier=confidence_loss_multiplier) + confidence_loss_multiplier=confidence_loss_multiplier + ) return module, num_inputs[-1] diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 7c823f4f31..e629cfe578 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -44,20 +44,22 @@ class DetectionLayer(nn.Module): resolutions. The loss should be summed from all of them. """ - def __init__(self, - num_classes: int, - image_width: int, - image_height: int, - anchor_dims: List[Tuple[int, int]], - anchor_ids: List[int], - xy_scale: float = 1.0, - ignore_threshold: float = 0.5, - overlap_loss_func: Callable = None, - class_loss_func: Callable = None, - confidence_loss_func: Callable = None, - overlap_loss_multiplier: float = 1.0, - class_loss_multiplier: float = 1.0, - confidence_loss_multiplier: float = 1.0): + def __init__( + self, + num_classes: int, + image_width: int, + image_height: int, + anchor_dims: List[Tuple[int, int]], + anchor_ids: List[int], + xy_scale: float = 1.0, + ignore_threshold: float = 0.5, + overlap_loss_func: Callable = None, + class_loss_func: Callable = None, + confidence_loss_func: Callable = None, + overlap_loss_multiplier: float = 1.0, + class_loss_multiplier: float = 1.0, + confidence_loss_multiplier: float = 1.0 + ): """ Args: num_classes: Number of different classes that this layer predicts. @@ -110,11 +112,7 @@ def __init__(self, self.class_loss_func = class_loss_func or se_loss self.confidence_loss_func = confidence_loss_func or se_loss - def forward( - self, - x: Tensor, - targets: Optional[List[Dict[str, Tensor]]] = None - ) -> Tuple[Tensor, Dict[str, Tensor]]: + def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Tensor, Dict[str, Tensor]]: """ Runs a forward pass through this YOLO detection layer. @@ -138,7 +136,8 @@ def forward( if boxes_per_cell != len(self.anchor_ids): raise MisconfigurationException( "The model predicts {} bounding boxes per cell, but {} anchor boxes are defined " - "for this layer.".format(boxes_per_cell, len(self.anchor_ids))) + "for this layer.".format(boxes_per_cell, len(self.anchor_ids)) + ) # Reshape the output to have the bounding box attributes of each grid cell on its own row. x = x.permute(0, 2, 3, 1) # [batch_size, height, width, boxes_per_cell * num_attrs] @@ -255,10 +254,7 @@ def _low_confidence_mask(self, boxes, targets): num_preds = height * width * boxes_per_cell boxes = boxes.view(batch_size, num_preds, num_coords) - scale = torch.tensor([self.image_width, - self.image_height, - self.image_width, - self.image_height], + scale = torch.tensor([self.image_width, self.image_height, self.image_width, self.image_height], device=boxes.device) boxes = boxes * scale @@ -299,12 +295,9 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): assert batch_size == len(targets) # Divisor for converting targets from image coordinates to feature map coordinates - image_to_feature_map = torch.tensor([self.image_width / width, - self.image_height / height], - device=device) + image_to_feature_map = torch.tensor([self.image_width / width, self.image_height / height], device=device) # Divisor for converting targets from image coordinates to [0, 1] range - image_to_unit = torch.tensor([self.image_width, self.image_height], - device=device) + image_to_unit = torch.tensor([self.image_width, self.image_height], device=device) anchor_wh = torch.tensor(self.anchor_dims, dtype=wh.dtype, device=device) anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) @@ -444,8 +437,7 @@ def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int): self.chunk_idx = chunk_idx def forward(self, x, outputs): - chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] - for layer in self.source_layers] + chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers] return torch.cat(chunks, dim=1) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 019e130934..42988491d5 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -53,17 +53,19 @@ class Yolo(pl.LightningModule): python yolo_module.py --config yolov4-tiny.cfg --data_dir . --gpus 8 --batch-size 8 """ - def __init__(self, - network: nn.ModuleList, - optimizer: str = 'sgd', - momentum: float = 0.9, - weight_decay: float = 0.0005, - learning_rate: float = 0.0013, - warmup_epochs: int = 1, - warmup_start_lr: float = 0.0001, - annealing_epochs: int = 271, - confidence_threshold: float = 0.2, - nms_threshold: float = 0.45): + def __init__( + self, + network: nn.ModuleList, + optimizer: str = 'sgd', + momentum: float = 0.9, + weight_decay: float = 0.0005, + learning_rate: float = 0.0013, + warmup_epochs: int = 1, + warmup_start_lr: float = 0.0001, + annealing_epochs: int = 271, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45 + ): """ Args: network: A list of network modules. This can be obtained from a Darknet configuration @@ -101,11 +103,9 @@ def __init__(self, self.confidence_threshold = confidence_threshold self.nms_threshold = nms_threshold - def forward( - self, - images: Tensor, - targets: List[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: + def forward(self, + images: Tensor, + targets: List[Dict[str, Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: """ Runs a forward pass through the network (all layers listed in `self.network`), and if training targets are provided, computes the losses from the detection layers. @@ -130,9 +130,9 @@ def forward( dimension is the detection within the image. `boxes` contains the predicted (x1, y1, x2, y2) coordinates, normalized to [0, 1]. """ - outputs = [] # Outputs from all layers + outputs = [] # Outputs from all layers detections = [] # Outputs from detection layers - losses = [] # Losses from detection layers + losses = [] # Losses from detection layers x = images for module in self.network: @@ -170,27 +170,21 @@ def configure_optimizers(self) -> Tuple[List, List]: """Constructs the optimizer and learning rate scheduler.""" if self.optimizer == 'sgd': optimizer = optim.SGD( - self.parameters(), - lr=self.learning_rate, - momentum=self.momentum, - weight_decay=self.weight_decay) - elif self.optimizer == 'adam': - optimizer = optim.Adam( - self.parameters(), - lr=self.learning_rate + self.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay ) + elif self.optimizer == 'adam': + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) + lr_scheduler = LinearWarmupCosineAnnealingLR( optimizer, warmup_epochs=self.warmup_epochs, max_epochs=self.annealing_epochs, - warmup_start_lr=self.warmup_start_lr) + warmup_start_lr=self.warmup_start_lr + ) + return [optimizer], [lr_scheduler] - def training_step( - self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], - batch_idx: int - ) -> Dict[str, Tensor]: + def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: """ Computes the training loss. @@ -212,11 +206,7 @@ def training_step( return {'loss': total_loss} - def validation_step( - self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], - batch_idx: int - ) -> Dict[str, Tensor]: + def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: """ Evaluates a batch of data from the validation set. @@ -228,19 +218,14 @@ def validation_step( images, targets = self._validate_batch(batch) boxes, confidences, classprobs, losses = self(images, targets) classprobs, labels = torch.max(classprobs, -1) - boxes, confidences, classprobs, labels = self._filter_detections( - boxes, confidences, classprobs, labels) + boxes, confidences, classprobs, labels = self._filter_detections(boxes, confidences, classprobs, labels) total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): self.log('val/{}_loss'.format(name), value) self.log('val/total_loss', total_loss) - def test_step( - self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], - batch_idx: int - ) -> Dict[str, Tensor]: + def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: """ Evaluates a batch of data from the test set. @@ -252,8 +237,7 @@ def test_step( images, targets = self._validate_batch(batch) boxes, confidences, classprobs, losses = self(images, targets) classprobs, labels = torch.max(classprobs, -1) - boxes, confidences, classprobs, labels = self._filter_detections( - boxes, confidences, classprobs, labels) + boxes, confidences, classprobs, labels = self._filter_detections(boxes, confidences, classprobs, labels) total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): @@ -279,8 +263,7 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: self.eval() boxes, confidences, classprobs = self(network_input) classprobs, labels = torch.max(classprobs, -1) - boxes, confidences, classprobs, labels = self._filter_detections( - boxes, confidences, classprobs, labels) + boxes, confidences, classprobs, labels = self._filter_detections(boxes, confidences, classprobs, labels) assert len(boxes) == 1 boxes = boxes[0] confidences = confidences[0] @@ -309,8 +292,10 @@ def load_darknet_weights(self, weight_file): """ version = np.fromfile(weight_file, count=3, dtype=np.int32) images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) - print('Loading weights from Darknet model version {}.{}.{} that has been trained on {} ' - 'images.'.format(version[0], version[1], version[2], images_seen[0])) + print( + 'Loading weights from Darknet model version {}.{}.{} that has been trained on {} ' + 'images.'.format(version[0], version[1], version[2], images_seen[0]) + ) def read(tensor): """ @@ -354,10 +339,8 @@ def get_deprecated_arg_names(cls) -> List: depr_arg_names.extend(val) return depr_arg_names - def _validate_batch( - self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] - ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: + def _validate_batch(self, batch: Tuple[List[Tensor], List[Dict[str, + Tensor]]]) -> Tuple[Tensor, List[Dict[str, Tensor]]]: """ Reads a batch of data, validates the format, and stacks the images into a single tensor. @@ -370,40 +353,33 @@ def _validate_batch( images, targets = batch if len(images) != len(targets): - raise ValueError("Got {} images, but targets for {} images." - .format(len(images), len(targets))) + raise ValueError("Got {} images, but targets for {} images.".format(len(images), len(targets))) for image in images: if not isinstance(image, Tensor): - raise ValueError("Expected image to be of type Tensor, got {}." - .format(type(image))) + raise ValueError("Expected image to be of type Tensor, got {}.".format(type(image))) for target in targets: boxes = target['boxes'] if not isinstance(boxes, Tensor): - raise ValueError("Expected target boxes to be of type Tensor, got {}." - .format(type(boxes))) + raise ValueError("Expected target boxes to be of type Tensor, got {}.".format(type(boxes))) if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): - raise ValueError("Expected target boxes to be tensors of shape [N, 4], got {}." - .format(list(boxes.shape))) + raise ValueError( + "Expected target boxes to be tensors of shape [N, 4], got {}.".format(list(boxes.shape)) + ) labels = target['labels'] if not isinstance(labels, Tensor): - raise ValueError("Expected target labels to be of type Tensor, got {}." - .format(type(labels))) + raise ValueError("Expected target labels to be of type Tensor, got {}.".format(type(labels))) if len(labels.shape) != 1: - raise ValueError("Expected target labels to be tensors of shape [N], got {}." - .format(list(labels.shape))) + raise ValueError( + "Expected target labels to be tensors of shape [N], got {}.".format(list(labels.shape)) + ) images = torch.stack(images) return images, targets - def _filter_detections( - self, - boxes: Tensor, - confidences: Tensor, - classprobs: Tensor, - labels: Tensor - ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: + def _filter_detections(self, boxes: Tensor, confidences: Tensor, classprobs: Tensor, + labels: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: """ Filters detections based on confidence threshold. Then for every class performs non-maximum suppression (NMS). NMS iterates the bounding boxes that predict this class in descending @@ -481,11 +457,15 @@ def __call__(self, image, target): original_size = torch.tensor([height, width]) resize_ratio = torch.tensor(self.output_size) / original_size image = F.resize(image, self.output_size) - scale = torch.tensor([resize_ratio[1], # y - resize_ratio[0], # x - resize_ratio[1], # y - resize_ratio[0]], # x - device=target['boxes'].device) + scale = torch.tensor( + [ + resize_ratio[1], # y + resize_ratio[0], # x + resize_ratio[1], # y + resize_ratio[0] # x + ], + device=target['boxes'].device + ) target['boxes'] = target['boxes'] * scale return image, target @@ -525,9 +505,9 @@ def run_cli(): trainer = pl.Trainer.from_argparse_args(args) trainer.fit( - model, - datamodule.train_dataloader(args.batch_size, transforms, image_transforms), - datamodule.val_dataloader(args.batch_size, transforms, image_transforms)) + model, datamodule.train_dataloader(args.batch_size, transforms, image_transforms), + datamodule.val_dataloader(args.batch_size, transforms, image_transforms) + ) if __name__ == "__main__": diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index f85f31402a..7ad94db224 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -42,7 +42,8 @@ def test_fasterrcnn_bbone_train(tmpdir): def _create_yolo_config_file(config_path): config_file = open(config_path, 'w') - config_file.write('''[net] + config_file.write( + '''[net] width=256 height=256 channels=3 @@ -120,7 +121,8 @@ def _create_yolo_config_file(config_path): scale_x_y=1.05 cls_normalizer=1.0 iou_normalizer=0.07 -ignore_thresh=0.7''') +ignore_thresh=0.7''' + ) config_file.close() @@ -148,15 +150,10 @@ def test_yolo_train(tmpdir): @pytest.mark.parametrize( - "dims1, dims2, expected_ious", - [(torch.tensor([[1.0, 1.0], - [10.0, 1.0], - [100.0, 10.0]]), - torch.tensor([[1.0, 10.0], - [2.0, 20.0]]), - torch.tensor([[1.0 / 10.0, 1.0 / 40.0], - [1.0 / 19.0, 2.0 / 48.0], - [10.0 / 1000.0, 20.0 / 1020.0]]))] + "dims1, dims2, expected_ious", [( + torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]) + )] ) def test_aligned_iou(dims1, dims2, expected_ious): torch.testing.assert_allclose(_aligned_iou(dims1, dims2), expected_ious) From 9c155a951897d0718cd3ecbf16516e48ffb89eea Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 15:50:05 +0200 Subject: [PATCH 13/61] Trying to fix Python 3.6 import problem. --- pl_bolts/models/detection/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 17862f170e..03e667ee06 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,6 +1,6 @@ from pl_bolts.models.detection import components # noqa: F401 from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -from pl_bolts.models.detection.yolo import Yolo # noqa: F401 -from pl_bolts.models.detection.yolo import YoloConfiguration # noqa: F401 +from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration # noqa: F401 +from pl_bolts.models.detection.yolo.yolo_module import Yolo # noqa: F401 __all__ = ["components", "FasterRCNN", "YoloConfiguration", "Yolo"] From efeb1c84db5bc8e4dd9fcbd0056a178ae0f81c7e Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 17:55:38 +0200 Subject: [PATCH 14/61] Fixed Python 3.6 import error. --- pl_bolts/models/detection/__init__.py | 3 +-- pl_bolts/models/detection/yolo/yolo_config.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 03e667ee06..bb37ad76a3 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,6 +1,5 @@ from pl_bolts.models.detection import components # noqa: F401 from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration # noqa: F401 -from pl_bolts.models.detection.yolo.yolo_module import Yolo # noqa: F401 +from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration # noqa: F401 __all__ = ["components", "FasterRCNN", "YoloConfiguration", "Yolo"] diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 10226815dd..48ca42095e 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -5,7 +5,7 @@ import torch.nn as nn from pytorch_lightning.utilities.exceptions import MisconfigurationException -import pl_bolts.models.detection.yolo.yolo_layers as yolo +from pl_bolts.models.detection.yolo import yolo_layers class YoloConfiguration: @@ -190,7 +190,7 @@ def _create_convolutional(config, num_inputs): leakyrelu = nn.LeakyReLU(0.1, inplace=True) module.add_module('leakyrelu', leakyrelu) elif config['activation'] == 'mish': - mish = yolo.Mish() + mish = yolo_layers.Mish() module.add_module('mish', mish) return module, config['filters'] @@ -210,7 +210,7 @@ def _create_route(config, num_inputs): last = len(num_inputs) - 1 source_layers = [layer if layer >= 0 else last + layer for layer in config['layers']] - module = yolo.RouteLayer(source_layers, num_chunks, chunk_idx) + module = yolo_layers.RouteLayer(source_layers, num_chunks, chunk_idx) # The number of outputs of a source layer is the number of inputs of the next layer. num_outputs = sum(num_inputs[layer + 1] // num_chunks for layer in source_layers) @@ -219,7 +219,7 @@ def _create_route(config, num_inputs): def _create_shortcut(config, num_inputs): - module = yolo.ShortcutLayer(config['from']) + module = yolo_layers.ShortcutLayer(config['from']) return module, num_inputs[-1] @@ -239,7 +239,7 @@ def _create_yolo(config, num_inputs): class_loss_multiplier = config.get('cls_normalizer', 1.0) confidence_loss_multiplier = config.get('obj_normalizer', 1.0) - module = yolo.DetectionLayer( + module = yolo_layers.DetectionLayer( num_classes=config['classes'], image_width=config['width'], image_height=config['height'], From 1a1ecd3791fcbd7fb8cee07f57700fed4283ede1 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 18:10:45 +0200 Subject: [PATCH 15/61] Added YOLO to CHANGELOG. --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63f2717538..a18a67abcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#323](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/323)) - Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285)) - Added DCGAN module ([#403](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/403)) +- Added YOLO module ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552)) - Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`, and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400)) - Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347)) From 26ff9796c9f274873c9cdaba0324e0f9b62d36fe Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 9 Feb 2021 18:24:15 +0200 Subject: [PATCH 16/61] Use torch.min() instead of torch.minimum() to avoid error with older PyTorch versions. --- pl_bolts/models/detection/yolo/yolo_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index e629cfe578..2e85590c24 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -369,7 +369,7 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): # mapped to the last class. labels = image_targets['labels'] labels = labels[selected] - labels = torch.minimum(labels, torch.tensor(self.num_classes - 1, device=device)) + labels = torch.min(labels, torch.tensor(self.num_classes - 1, device=device)) target_label.append(labels) pred_xy.append(xy[image_idx, cell_j, cell_i, predictors]) From 3e9bddeb09a71fafc55769fa9a5c0cd7033cd330 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 12 Feb 2021 20:13:35 +0200 Subject: [PATCH 17/61] Generalized interface for custom losses * IoU loss functions take image space coordinates as input. --- pl_bolts/models/detection/yolo/yolo_config.py | 10 + pl_bolts/models/detection/yolo/yolo_layers.py | 222 +++++++++++++----- pl_bolts/models/detection/yolo/yolo_module.py | 2 + 3 files changed, 171 insertions(+), 63 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 48ca42095e..c3e247a79a 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -239,6 +239,14 @@ def _create_yolo(config, num_inputs): class_loss_multiplier = config.get('cls_normalizer', 1.0) confidence_loss_multiplier = config.get('obj_normalizer', 1.0) + overlap_loss_name = config.get('iou_loss', 'mse') + if overlap_loss_name == 'mse': + overlap_loss_func = nn.MSELoss(reduction='none') + elif overlap_loss_name == 'giou': + overlap_loss_func = yolo_layers.GIoULoss() + else: + overlap_loss_func = yolo_layers.IoULoss() + module = yolo_layers.DetectionLayer( num_classes=config['classes'], image_width=config['width'], @@ -247,6 +255,8 @@ def _create_yolo(config, num_inputs): anchor_ids=config['mask'], xy_scale=xy_scale, ignore_threshold=ignore_threshold, + overlap_loss_func=overlap_loss_func, + image_space_loss=overlap_loss_name != 'mse', overlap_loss_multiplier=overlap_loss_multiplier, class_loss_multiplier=class_loss_multiplier, confidence_loss_multiplier=confidence_loss_multiplier diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 2e85590c24..3e6aff85ff 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -7,7 +7,7 @@ from pl_bolts.utils.warnings import warn_missing_pkg try: - from torchvision.ops import box_iou + from torchvision.ops import box_area, box_iou except ModuleNotFoundError: warn_missing_pkg('torchvision') # pragma: no-cover _TORCHVISION_AVAILABLE = False @@ -15,6 +15,23 @@ _TORCHVISION_AVAILABLE = True +def _corner_coordinates(xy, wh): + """ + Converts box center points and sizes to corner coordinates. + + Args: + xy (Tensor): Center coordinates. Tensor of size `[..., 2]`. + wh (Tensor): Width and height. Tensor of size `[..., 2]`. + + Returns: + boxes (Tensor): A matrix of (x1, y1, x2, y2) coordinates. + """ + half_wh = wh / 2 + top_left = xy - half_wh + bottom_right = xy + half_wh + return torch.cat((top_left, bottom_right), -1) + + def _aligned_iou(dims1, dims2): """ Calculates a matrix of intersections over union from box dimensions, assuming that the boxes @@ -38,6 +55,84 @@ def _aligned_iou(dims1, dims2): return inter / union +def _elementwise_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """ + Returns the elementwise intersection-over-union between two sets of boxes. + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + + Args: + boxes1 (Tensor[N, 4]) + boxes2 (Tensor[N, 4]) + + Returns: + iou (Tensor[N]): the vector containing the elementwise IoU values for every element in + boxes1 and boxes2 + """ + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] + + wh = (rb - lt).clamp(min=0) # [N,2] + inter = wh[:, 0] * wh[:, 1] # [N] + + iou = inter / (area1 + area2 - inter) + return iou + + +def _elementwise_generalized_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """ + Returns the elementwise generalized intersection-over-union between two sets of boxes. + + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + + Args: + boxes1 (Tensor[N, 4]) + boxes2 (Tensor[N, 4]) + + Returns: + generalized_iou (Tensor[N]): the vector containing the elementwise generalized IoU values + for every element in boxes1 and boxes2 + """ + + # Degenerate boxes give inf / nan results, so do an early check. + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] + + wh = (rb - lt).clamp(min=0) # [N,2] + inter = wh[:, 0] * wh[:, 1] # [N] + + union = area1 + area2 - inter + + iou = inter / union + + lti = torch.min(boxes1[:, :2], boxes2[:, :2]) + rbi = torch.max(boxes1[:, 2:], boxes2[:, 2:]) + + whi = (rbi - lti).clamp(min=0) # [N,2] + areai = whi[:, 0] * whi[:, 1] + + return iou - (areai - union) / areai + + +class IoULoss(nn.Module): + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return 1.0 - _elementwise_iou(inputs, target) + + +class GIoULoss(nn.Module): + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return 1.0 - _elementwise_generalized_iou(inputs, target) + + class DetectionLayer(nn.Module): """ A YOLO detection layer. A YOLO model has usually 1 - 3 detection layers at different @@ -56,6 +151,7 @@ def __init__( overlap_loss_func: Callable = None, class_loss_func: Callable = None, confidence_loss_func: Callable = None, + image_space_loss: bool = False, overlap_loss_multiplier: float = 1.0, class_loss_multiplier: float = 1.0, confidence_loss_multiplier: float = 1.0 @@ -77,12 +173,17 @@ def __init__( ignore_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_loss_func: Loss function for (x, y, w, h) coordinates. Default is the sum of + overlap_loss_func: Loss function for bounding box coordinates. Default is the sum of squared errors. class_loss_func: Loss function for class probability distribution. Default is the sum of squared errors. confidence_loss_func: Loss function for confidence score. Default is the sum of squared errors. + image_space_loss: If set to `True`, the overlap loss function will receive the bounding + box (x1, y1, x2, y2) coordinate normalized to the [0, 1] range. This is needed for + the IoU losses introduced in YOLOv4. Otherwise the loss will be computed from the x, + y, width, and height values, as predicted by the network (i.e. relative to the + anchor box, and width and height are logarithmic). coord_loss_multiplier: Multiply the coordinate/size loss by this factor. class_loss_multiplier: Multiply the classification loss by this factor. confidence_loss_multiplier: Multiply the confidence loss by this factor. @@ -103,14 +204,14 @@ def __init__( self.xy_scale = xy_scale self.ignore_threshold = ignore_threshold - self.overlap_loss_multiplier = overlap_loss_multiplier - self.class_loss_multiplier = class_loss_multiplier - self.confidence_loss_multiplier = confidence_loss_multiplier - se_loss = nn.MSELoss(reduction='none') self.overlap_loss_func = overlap_loss_func or se_loss self.class_loss_func = class_loss_func or se_loss self.confidence_loss_func = confidence_loss_func or se_loss + self.image_space_loss = image_space_loss + self.overlap_loss_multiplier = overlap_loss_multiplier + self.class_loss_multiplier = class_loss_multiplier + self.confidence_loss_multiplier = confidence_loss_multiplier def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Tensor, Dict[str, Tensor]]: """ @@ -127,8 +228,9 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) dictionaries, one for each image. Returns: - result: Layer output, and if training targets were provided, a dictionary of losses. - Layer output is sized `[batch_size, num_anchors * height * width, num_classes + 5]`. + output (Tensor), losses (Dict[str, Tensor]): Layer output, and if training targets were + provided, a dictionary of losses. Layer output is sized + `[batch_size, num_anchors * height * width, num_classes + 5]`. """ batch_size, num_features, height, width = x.shape num_attrs = self.num_classes + 5 @@ -160,15 +262,17 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) image_xy = self._global_xy(xy) image_wh = self._scale_wh(wh) - corners = self._corner_coordinates(image_xy, image_wh) - output = torch.cat((corners, confidence.unsqueeze(-1), classprob), -1) + boxes = _corner_coordinates(image_xy, image_wh) + output = torch.cat((boxes, confidence.unsqueeze(-1), classprob), -1) output = output.reshape(batch_size, height * width * boxes_per_cell, num_attrs) if targets is None: return output - lc_mask = self._low_confidence_mask(corners, targets) - losses = self._calculate_losses(xy, wh, confidence, classprob, targets, lc_mask) + lc_mask = self._low_confidence_mask(boxes, targets) + if not self.image_space_loss: + boxes = torch.cat((xy, wh), -1) + losses = self._calculate_losses(boxes, confidence, classprob, targets, lc_mask) return output, losses def _global_xy(self, xy): @@ -217,22 +321,6 @@ def _scale_wh(self, wh): anchor_wh = torch.tensor(anchor_wh, dtype=wh.dtype, device=wh.device) return torch.exp(wh) * anchor_wh / image_size - def _corner_coordinates(self, xy, wh): - """ - Converts box center points and sizes to corner coordinates. - - Args: - xy (Tensor): Center coordinates. Tensor of size `[..., 2]`. - wh (Tensor): Width and height. Tensor of size `[..., 2]`. - - Returns: - corners (Tensor): A matrix of (x1, y1, x2, y2) coordinates. - """ - half_wh = wh / 2 - top_left = xy - half_wh - bottom_right = xy + half_wh - return torch.cat((top_left, bottom_right), -1) - def _low_confidence_mask(self, boxes, targets): """ Initializes the mask that will be used to select predictors that are not predicting any @@ -268,16 +356,14 @@ def _low_confidence_mask(self, boxes, targets): return results.view((batch_size, height, width, boxes_per_cell)) - def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): + def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): """ From the targets that are in the image space calculates the actual targets for the network predictions, and returns a dictionary of training losses. Args: - xy (Tensor): The predicted center coordinates before scaling. Values from zero to one - in a tensor sized `[batch_size, height, width, boxes_per_cell, 2]`. - wh (Tensor): The unnormalized width and height predictions. Tensor of size - `[batch_size, height, width, boxes_per_cell, 2]`. + boxes (Tensor): The predicted bounding boxes. A tensor sized + `[batch_size, height, width, boxes_per_cell, 4]`. confidence (Tensor): The confidence predictions, normalized to [0, 1]. A tensor sized `[batch_size, height, width, boxes_per_cell]`. classprob (Tensor): The class probability predictions, normalized to [0, 1]. A tensor @@ -290,8 +376,8 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): Returns: predicted (Dict[str, Tensor]): A dictionary of training losses. """ - batch_size, height, width, boxes_per_cell, _ = xy.shape - device = xy.device + batch_size, height, width, boxes_per_cell, _ = boxes.shape + device = boxes.device assert batch_size == len(targets) # Divisor for converting targets from image coordinates to feature map coordinates @@ -299,7 +385,7 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): # Divisor for converting targets from image coordinates to [0, 1] range image_to_unit = torch.tensor([self.image_width, self.image_height], device=device) - anchor_wh = torch.tensor(self.anchor_dims, dtype=wh.dtype, device=device) + anchor_wh = torch.tensor(self.anchor_dims, dtype=boxes.dtype, device=device) anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) # List of predicted and target values for the predictors that are responsible for @@ -308,33 +394,34 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): target_wh = [] target_label = [] size_compensation = [] - pred_xy = [] - pred_wh = [] + pred_boxes = [] pred_classprob = [] pred_confidence = [] for image_idx, image_targets in enumerate(targets): - boxes = image_targets['boxes'] - if boxes.shape[0] < 1: + target_boxes = image_targets['boxes'] + if target_boxes.shape[0] < 1: continue # Bounding box corner coordinates are converted to center coordinates, width, and - # height. - box_wh = boxes[:, 2:4] - boxes[:, 0:2] - box_xy = boxes[:, 0:2] + (box_wh / 2) + # height, and normalized to [0, 1] range. + wh = target_boxes[:, 2:4] - target_boxes[:, 0:2] + xy = target_boxes[:, 0:2] + (wh / 2) + unit_xy = xy / image_to_unit + unit_wh = wh / image_to_unit # The center coordinates are converted to the feature map dimensions so that the whole # number tells the cell index and the fractional part tells the location inside the cell. - box_xy = box_xy / image_to_feature_map - cell_i = box_xy[:, 0].to(torch.int64).clamp(0, width - 1) - cell_j = box_xy[:, 1].to(torch.int64).clamp(0, height - 1) + xy = xy / image_to_feature_map + cell_i = xy[:, 0].to(torch.int64).clamp(0, width - 1) + cell_j = xy[:, 1].to(torch.int64).clamp(0, height - 1) # We want to know which anchor box overlaps a ground truth box more than any other # anchor box. We know that the anchor box is located in the same grid cell as the # ground truth box. For each prior shape (width, height), we calculate the IoU with # all ground truth boxes, assuming the boxes are at the same location. Then for each # target, we select the prior shape that gives the highest IoU. - ious = _aligned_iou(box_wh, anchor_wh) + ious = _aligned_iou(wh, anchor_wh) best_anchors = ious.max(1).indices # `anchor_map` maps the anchor indices to the predictors in this layer, or to -1 if @@ -342,8 +429,8 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): # another layer. predictors = anchor_map[best_anchors] selected = predictors >= 0 - box_xy = box_xy[selected] - box_wh = box_wh[selected] + unit_xy = unit_xy[selected] + unit_wh = unit_wh[selected] cell_i = cell_i[selected] cell_j = cell_j[selected] predictors = predictors[selected] @@ -354,14 +441,21 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): # the target confidence. lc_mask[image_idx, cell_j, cell_i, predictors] = False - # Bounding box targets - relative_xy = box_xy - box_xy.floor() - relative_wh = torch.log(box_wh / anchor_wh[best_anchors] + 1e-16) - target_xy.append(relative_xy) - target_wh.append(relative_wh) - - # Size compensation factor for bounding box loss - unit_wh = box_wh / image_to_unit + # IoU losses are calculated from the image space coordinates normalized to [0, 1] + # range. The squared-error loss is calculated from the raw predicted values. + if self.image_space_loss: + target_xy.append(unit_xy) + target_wh.append(unit_wh) + else: + xy = xy[selected] + wh = wh[selected] + relative_xy = xy - xy.floor() + relative_wh = torch.log(wh / anchor_wh[best_anchors] + 1e-16) + target_xy.append(relative_xy) + target_wh.append(relative_wh) + + # Size compensation factor for bounding box overlap loss is calculated from image space + # width and height. size_compensation.append(2 - (unit_wh[:, 0] * unit_wh[:, 1])) # The data may contain a different number of classes than this detection layer. In case @@ -372,18 +466,20 @@ def _calculate_losses(self, xy, wh, confidence, classprob, targets, lc_mask): labels = torch.min(labels, torch.tensor(self.num_classes - 1, device=device)) target_label.append(labels) - pred_xy.append(xy[image_idx, cell_j, cell_i, predictors]) - pred_wh.append(wh[image_idx, cell_j, cell_i, predictors]) + pred_boxes.append(boxes[image_idx, cell_j, cell_i, predictors]) pred_classprob.append(classprob[image_idx, cell_j, cell_i, predictors]) pred_confidence.append(confidence[image_idx, cell_j, cell_i, predictors]) losses = dict() - if pred_xy and pred_wh and target_xy and target_wh: + if pred_boxes and target_xy and target_wh: size_compensation = torch.cat(size_compensation).unsqueeze(1) - pred_xywh = torch.cat((torch.cat(pred_xy), torch.cat(pred_wh)), -1) - target_xywh = torch.cat((torch.cat(target_xy), torch.cat(target_wh)), -1) - overlap_loss = self.overlap_loss_func(pred_xywh, target_xywh) + pred_boxes = torch.cat(pred_boxes) + if self.image_space_loss: + target_boxes = _corner_coordinates(torch.cat(target_xy), torch.cat(target_wh)) + else: + target_boxes = torch.cat((torch.cat(target_xy), torch.cat(target_wh)), -1) + overlap_loss = self.overlap_loss_func(pred_boxes, target_boxes) overlap_loss = overlap_loss * size_compensation overlap_loss = overlap_loss.sum() / batch_size losses['overlap'] = overlap_loss * self.overlap_loss_multiplier diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 42988491d5..93568109a1 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -174,6 +174,8 @@ def configure_optimizers(self) -> Tuple[List, List]: ) elif self.optimizer == 'adam': optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) + else: + raise ValueError("Unknown optimizer: {}".format(self.optimizer)) lr_scheduler = LinearWarmupCosineAnnealingLR( optimizer, From c34861949e3d57a708c33a5cf0ce9f4a281cf6da Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 12 Feb 2021 21:03:40 +0200 Subject: [PATCH 18/61] box_area() implementation copied from torchvision * box_area() is not found from the version of torchvision used by the test suite. --- pl_bolts/models/detection/yolo/yolo_layers.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 3e6aff85ff..161b53b101 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -7,7 +7,7 @@ from pl_bolts.utils.warnings import warn_missing_pkg try: - from torchvision.ops import box_area, box_iou + from torchvision.ops import box_iou except ModuleNotFoundError: warn_missing_pkg('torchvision') # pragma: no-cover _TORCHVISION_AVAILABLE = False @@ -32,6 +32,21 @@ def _corner_coordinates(xy, wh): return torch.cat((top_left, bottom_right), -1) +def _area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its + (x1, y1, x2, y2) coordinates. + + Arguments: + boxes (Tensor[N, 4]): boxes for which the area will be computed. They + are expected to be in (x1, y1, x2, y2) format + + Returns: + area (Tensor[N]): area for each box + """ + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + def _aligned_iou(dims1, dims2): """ Calculates a matrix of intersections over union from box dimensions, assuming that the boxes @@ -69,8 +84,8 @@ def _elementwise_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: iou (Tensor[N]): the vector containing the elementwise IoU values for every element in boxes1 and boxes2 """ - area1 = box_area(boxes1) - area2 = box_area(boxes2) + area1 = _area(boxes1) + area2 = _area(boxes2) lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] @@ -101,8 +116,8 @@ def _elementwise_generalized_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() - area1 = box_area(boxes1) - area2 = box_area(boxes2) + area1 = _area(boxes1) + area2 = _area(boxes2) lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] From c2d7907938f65d099b6cc7b665857efd8e9eb1c7 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 12 Feb 2021 21:35:53 +0200 Subject: [PATCH 19/61] Confirm to yapf formatter rules. --- pl_bolts/models/detection/yolo/yolo_layers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 161b53b101..1f80371283 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -139,11 +139,13 @@ def _elementwise_generalized_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: class IoULoss(nn.Module): + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - _elementwise_iou(inputs, target) class GIoULoss(nn.Module): + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - _elementwise_generalized_iou(inputs, target) From 3d7f4400474e7964193ff744da1cfc6d48bea36a Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 09:13:51 +0200 Subject: [PATCH 20/61] Removed the unnecessary linter instructions. Co-authored-by: Akihiro Nitta --- pl_bolts/models/detection/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index bb37ad76a3..f79aa31207 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,5 +1,5 @@ -from pl_bolts.models.detection import components # noqa: F401 -from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration # noqa: F401 +from pl_bolts.models.detection import components +from pl_bolts.models.detection.faster_rcnn import FasterRCNN +from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration __all__ = ["components", "FasterRCNN", "YoloConfiguration", "Yolo"] From eb6be462adb867a7cde1a18f9cf247dc556c7c52 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 10:15:46 +0200 Subject: [PATCH 21/61] IoU losses use torchvision * IoU losses take the diagnoal of torchvision iou ops instead of implementing their own elementwise ops. * Quotes replaced with double quotes in docstrings. * Import _TORCHVISION_AVAILABLE from pl_bolts. --- pl_bolts/models/detection/yolo/yolo_config.py | 8 +- pl_bolts/models/detection/yolo/yolo_layers.py | 135 ++++-------------- pl_bolts/models/detection/yolo/yolo_module.py | 40 +++--- 3 files changed, 48 insertions(+), 135 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index c3e247a79a..30e745fa5d 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -11,14 +11,14 @@ class YoloConfiguration: """ This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. - The `get_network()` method returns a PyTorch module list that can be used to construct a YOLO + The ``get_network()`` method returns a PyTorch module list that can be used to construct a YOLO model. """ def __init__(self, path: str): """ Saves the variables from the first configuration section to attributes of this object, and - the rest of the sections to the `layer_configs` list. + the rest of the sections to the ``layer_configs`` list. Args: path: Path to a configuration file @@ -39,7 +39,7 @@ def get_network(self) -> nn.ModuleList: modules. Returns the network structure that can be used to create a YOLO model. Returns: - modules: A `nn.ModuleList` that defines the YOLO network. + modules: A ``nn.ModuleList`` that defines the YOLO network. """ result = nn.ModuleList() num_inputs = [3] # Number of channels in the input of every layer up to the current layer @@ -149,7 +149,7 @@ def convert(key, value): def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: """ - Calls one of the `_create_(config, num_inputs)` functions to create a PyTorch + Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer config. Args: diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 1f80371283..bfbbafcef8 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -4,15 +4,13 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: - from torchvision.ops import box_iou -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False +if _TORCHVISION_AVAILABLE: + from torchvision.ops import box_iou, generalized_box_iou else: - _TORCHVISION_AVAILABLE = True + warn_missing_pkg('torchvision') def _corner_coordinates(xy, wh): @@ -20,8 +18,8 @@ def _corner_coordinates(xy, wh): Converts box center points and sizes to corner coordinates. Args: - xy (Tensor): Center coordinates. Tensor of size `[..., 2]`. - wh (Tensor): Width and height. Tensor of size `[..., 2]`. + xy (Tensor): Center coordinates. Tensor of size ``[..., 2]``. + wh (Tensor): Width and height. Tensor of size ``[..., 2]``. Returns: boxes (Tensor): A matrix of (x1, y1, x2, y2) coordinates. @@ -32,21 +30,6 @@ def _corner_coordinates(xy, wh): return torch.cat((top_left, bottom_right), -1) -def _area(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified by its - (x1, y1, x2, y2) coordinates. - - Arguments: - boxes (Tensor[N, 4]): boxes for which the area will be computed. They - are expected to be in (x1, y1, x2, y2) format - - Returns: - area (Tensor[N]): area for each box - """ - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - def _aligned_iou(dims1, dims2): """ Calculates a matrix of intersections over union from box dimensions, assuming that the boxes @@ -58,7 +41,7 @@ def _aligned_iou(dims1, dims2): Returns: iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in - `dims1` and `dims2` + ``dims1`` and ``dims2`` """ area1 = dims1[:, 0] * dims1[:, 1] # [N] area2 = dims2[:, 0] * dims2[:, 1] # [M] @@ -70,84 +53,16 @@ def _aligned_iou(dims1, dims2): return inter / union -def _elementwise_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: - """ - Returns the elementwise intersection-over-union between two sets of boxes. - - Both sets of boxes are expected to be in (x1, y1, x2, y2) format. - - Args: - boxes1 (Tensor[N, 4]) - boxes2 (Tensor[N, 4]) - - Returns: - iou (Tensor[N]): the vector containing the elementwise IoU values for every element in - boxes1 and boxes2 - """ - area1 = _area(boxes1) - area2 = _area(boxes2) - - lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] - rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] - - wh = (rb - lt).clamp(min=0) # [N,2] - inter = wh[:, 0] * wh[:, 1] # [N] - - iou = inter / (area1 + area2 - inter) - return iou - - -def _elementwise_generalized_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: - """ - Returns the elementwise generalized intersection-over-union between two sets of boxes. - - Both sets of boxes are expected to be in (x1, y1, x2, y2) format. - - Args: - boxes1 (Tensor[N, 4]) - boxes2 (Tensor[N, 4]) - - Returns: - generalized_iou (Tensor[N]): the vector containing the elementwise generalized IoU values - for every element in boxes1 and boxes2 - """ - - # Degenerate boxes give inf / nan results, so do an early check. - assert (boxes1[:, 2:] >= boxes1[:, :2]).all() - assert (boxes2[:, 2:] >= boxes2[:, :2]).all() - - area1 = _area(boxes1) - area2 = _area(boxes2) - - lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] - rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] - - wh = (rb - lt).clamp(min=0) # [N,2] - inter = wh[:, 0] * wh[:, 1] # [N] - - union = area1 + area2 - inter - - iou = inter / union - - lti = torch.min(boxes1[:, :2], boxes2[:, :2]) - rbi = torch.max(boxes1[:, 2:], boxes2[:, 2:]) - - whi = (rbi - lti).clamp(min=0) # [N,2] - areai = whi[:, 0] * whi[:, 1] - - return iou - (areai - union) / areai - - class IoULoss(nn.Module): def forward(self, inputs: Tensor, target: Tensor) -> Tensor: - return 1.0 - _elementwise_iou(inputs, target) + return 1.0 - box_iou(inputs, target).diagonal() class GIoULoss(nn.Module): def forward(self, inputs: Tensor, target: Tensor) -> Tensor: - return 1.0 - _elementwise_generalized_iou(inputs, target) + return 1.0 - generalized_box_iou(inputs, target).diagonal() class DetectionLayer(nn.Module): @@ -183,7 +98,7 @@ def __init__( anchor_dims: A list of all the predefined anchor box dimensions. The list should contain (width, height) tuples in the network input resolution (relative to the width and height defined in the configuration file). - anchor_ids: List of indices to `anchor_dims` that is used to select the (usually 3) + anchor_ids: List of indices to ``anchor_dims`` that is used to select the (usually 3) anchors that this layer uses. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps to produce coordinate values close to one. @@ -196,7 +111,7 @@ def __init__( of squared errors. confidence_loss_func: Loss function for confidence score. Default is the sum of squared errors. - image_space_loss: If set to `True`, the overlap loss function will receive the bounding + image_space_loss: If set to ``True``, the overlap loss function will receive the bounding box (x1, y1, x2, y2) coordinate normalized to the [0, 1] range. This is needed for the IoU losses introduced in YOLOv4. Otherwise the loss will be computed from the x, y, width, and height values, as predicted by the network (i.e. relative to the @@ -240,14 +155,14 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) Args: x: The output from the previous layer. Tensor of size - `[batch_size, boxes_per_cell * (num_classes + 5), height, width]`. + ``[batch_size, boxes_per_cell * (num_classes + 5), height, width]``. targets: If set, computes losses from detection layers against these targets. A list of dictionaries, one for each image. Returns: output (Tensor), losses (Dict[str, Tensor]): Layer output, and if training targets were provided, a dictionary of losses. Layer output is sized - `[batch_size, num_anchors * height * width, num_classes + 5]`. + ``[batch_size, num_anchors * height * width, num_classes + 5]``. """ batch_size, num_features, height, width = x.shape num_attrs = self.num_classes + 5 @@ -303,7 +218,7 @@ def _global_xy(self, xy): Args: xy (Tensor): The predicted center coordinates before scaling. Values from zero to one - in a tensor sized `[batch_size, height, width, boxes_per_cell, 2]`. + in a tensor sized ``[batch_size, height, width, boxes_per_cell, 2]``. Returns: result (Tensor): Global coordinates from zero to one, in a tensor with the same shape @@ -327,7 +242,7 @@ def _scale_wh(self, wh): Args: wh (Tensor): The unnormalized width and height predictions. Tensor of size - `[..., boxes_per_cell, 2]`. + ``[..., boxes_per_cell, 2]``. Returns: result (Tensor): A tensor with the same shape as the input tensor, but scaled sizes @@ -341,18 +256,18 @@ def _scale_wh(self, wh): def _low_confidence_mask(self, boxes, targets): """ Initializes the mask that will be used to select predictors that are not predicting any - ground-truth target. The value will be `True`, unless the predicted box overlaps any target - significantly (IoU greater than `self.ignore_threshold`). + ground-truth target. The value will be ``True``, unless the predicted box overlaps any target + significantly (IoU greater than ``self.ignore_threshold``). Args: boxes (Tensor): The predicted corner coordinates, normalized to the [0, 1] range. - Tensor of size `[batch_size, height, width, boxes_per_cell, 4]`. + Tensor of size ``[batch_size, height, width, boxes_per_cell, 4]``. targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one dictionary for each image. Returns: - results (Tensor): A boolean tensor shaped `[batch_size, height, width, boxes_per_cell]` - with `False` where the predicted box overlaps a target significantly and `True` + results (Tensor): A boolean tensor shaped ``[batch_size, height, width, boxes_per_cell]`` + with ``False`` where the predicted box overlaps a target significantly and ``True`` elsewhere. """ batch_size, height, width, boxes_per_cell, num_coords = boxes.shape @@ -380,14 +295,14 @@ def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): Args: boxes (Tensor): The predicted bounding boxes. A tensor sized - `[batch_size, height, width, boxes_per_cell, 4]`. + ``[batch_size, height, width, boxes_per_cell, 4]``. confidence (Tensor): The confidence predictions, normalized to [0, 1]. A tensor sized - `[batch_size, height, width, boxes_per_cell]`. + ``[batch_size, height, width, boxes_per_cell]``. classprob (Tensor): The class probability predictions, normalized to [0, 1]. A tensor - sized `[batch_size, height, width, boxes_per_cell, num_classes]`. + sized ``[batch_size, height, width, boxes_per_cell, num_classes]``. targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one dictionary for each image. - lc_mask (Tensor): A boolean mask containing `True` where the predicted box does not + lc_mask (Tensor): A boolean mask containing ``True`` where the predicted box does not overlap any target significantly. Returns: @@ -441,7 +356,7 @@ def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): ious = _aligned_iou(wh, anchor_wh) best_anchors = ious.max(1).indices - # `anchor_map` maps the anchor indices to the predictors in this layer, or to -1 if + # ``anchor_map`` maps the anchor indices to the predictors in this layer, or to -1 if # it's not an anchor of this layer. We ignore the predictions if the best anchor is in # another layer. predictors = anchor_map[best_anchors] diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 93568109a1..dc27e71a9b 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -10,17 +10,15 @@ from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: import torchvision.transforms as T from torchvision.ops import nms from torchvision.transforms import functional as F -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - _TORCHVISION_AVAILABLE = False else: - _TORCHVISION_AVAILABLE = True + warn_missing_pkg('torchvision') class Yolo(pl.LightningModule): @@ -43,8 +41,8 @@ class Yolo(pl.LightningModule): of dictionaries). The target dictionaries should contain: - - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. - - labels (`LongTensor[N]`): the class label for each ground truh box + - boxes (``FloatTensor[N, 4]``): the ground truth boxes in ``[x1, y1, x2, y2]`` format. + - labels (``LongTensor[N]``): the class label for each ground truh box CLI command:: @@ -69,14 +67,14 @@ def __init__( """ Args: network: A list of network modules. This can be obtained from a Darknet configuration - using the `YoloConfiguration.get_network()` method. + using the ``YoloConfiguration.get_network()`` method. optimizer: Which optimizer to use for training; either 'sgd' or 'adam'. momentum: Momentum factor for SGD with momentum. weight_decay: Weight decay (L2 penalty). learning_rate: Learning rate after the warmup period. warmup_epochs: Length of the learning rate warmup period in the beginning of training. During this number of epochs, the learning rate will be raised from - `warmup_start_lr` to `learning_rate`. + ``warmup_start_lr`` to ``learning_rate``. warmup_start_lr: Learning rate in the beginning of the warmup period. annealing_epochs: Length of the learning rate annealing period, during which the learning rate will go to zero. @@ -107,7 +105,7 @@ def forward(self, images: Tensor, targets: List[Dict[str, Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: """ - Runs a forward pass through the network (all layers listed in `self.network`), and if + Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. Detections are concatenated from the detection layers. Each image will produce @@ -119,7 +117,7 @@ def forward(self, Args: images: Images to be processed. Tensor of size - `[batch_size, num_channels, height, width]`. + ``[batch_size, num_channels, height, width]``. targets: If set, computes losses from detection layers against these targets. A list of dictionaries, one for each image. @@ -127,7 +125,7 @@ def forward(self, boxes (Tensor), confidences (Tensor), classprobs (Tensor), losses (Dict[str, Tensor]): Detections, and if targets were provided, a dictionary of losses. The first dimension of the detections is the index of the image in the batch and the second - dimension is the detection within the image. `boxes` contains the predicted + dimension is the detection within the image. ``boxes`` contains the predicted (x1, y1, x2, y2) coordinates, normalized to [0, 1]. """ outputs = [] # Outputs from all layers @@ -252,7 +250,7 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: detected bounding boxes, confidences, and class labels. Args: - image: An input image, a tensor of uint8 values sized `[channels, height, width]`. + image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. Returns: boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), labels (:class:`~torch.Tensor`): @@ -301,8 +299,8 @@ def load_darknet_weights(self, weight_file): def read(tensor): """ - Reads the contents of `tensor` from the current position of `weight_file`. - If there's no more data in `weight_file`, returns without error. + Reads the contents of ``tensor`` from the current position of ``weight_file``. + If there's no more data in ``weight_file``, returns without error. """ x = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) if x.shape[0] == 0: @@ -390,10 +388,10 @@ def _filter_detections(self, boxes: Tensor, confidences: Tensor, classprobs: Ten Args: boxes: Detected bounding box (x1, y1, x2, y2) coordinates in a tensor sized - `[batch_size, N, 4]`. - confidences: Detection confidences in a tensor sized `[batch_size, N]`. - classprobs: Probabilities of the best classes in a tensor sized `[batch_size, N]`. - labels: Indices of the best classes in a tensor sized `[batch_size, N]`. + ``[batch_size, N, 4]``. + confidences: Detection confidences in a tensor sized ``[batch_size, N]``. + classprobs: Probabilities of the best classes in a tensor sized ``[batch_size, N]``. + labels: Indices of the best classes in a tensor sized ``[batch_size, N]``. Returns: boxes (List[Tensor]), confidences (List[Tensor]), classprobs (List[Tensor]), labels (List[Tensor]): @@ -447,8 +445,8 @@ class Resize: Args: output_size (tuple or int): Desired output size. If tuple (height, width), the output is - matched to `output_size`. If int, the smaller of the image edges is matched to - `output_size`, keeping the aspect ratio the same. + matched to ``output_size``. If int, the smaller of the image edges is matched to + ``output_size``, keeping the aspect ratio the same. """ def __init__(self, output_size: tuple): From cf7420ca699eb6a6df23ce010314d9d6b4524cf2 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 10:33:06 +0200 Subject: [PATCH 22/61] Improved strange yapf formatting Co-authored-by: Akihiro Nitta --- pl_bolts/models/detection/yolo/yolo_module.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index dc27e71a9b..4c186ab494 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -339,8 +339,10 @@ def get_deprecated_arg_names(cls) -> List: depr_arg_names.extend(val) return depr_arg_names - def _validate_batch(self, batch: Tuple[List[Tensor], List[Dict[str, - Tensor]]]) -> Tuple[Tensor, List[Dict[str, Tensor]]]: + def _validate_batch( + self, + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: """ Reads a batch of data, validates the format, and stacks the images into a single tensor. From 6c90cd4d5399f4929eb8e0e267c1199c974c4a6a Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 19:53:54 +0200 Subject: [PATCH 23/61] Refactoring * YOLO written with all caps in class names * Generic way to specify optimizer and LR scheduler --- pl_bolts/models/detection/__init__.py | 5 +- pl_bolts/models/detection/yolo/__init__.py | 4 - pl_bolts/models/detection/yolo/yolo_config.py | 2 +- pl_bolts/models/detection/yolo/yolo_layers.py | 15 ++- pl_bolts/models/detection/yolo/yolo_module.py | 92 +++++++++---------- tests/models/test_detection.py | 10 +- 6 files changed, 66 insertions(+), 62 deletions(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index f79aa31207..3a8034f9f0 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,5 +1,6 @@ from pl_bolts.models.detection import components from pl_bolts.models.detection.faster_rcnn import FasterRCNN -from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration +from pl_bolts.models.detection.yolo.yolo_module import YOLO +from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration -__all__ = ["components", "FasterRCNN", "YoloConfiguration", "Yolo"] +__all__ = ["components", "FasterRCNN", "YOLOConfiguration", "YOLO"] diff --git a/pl_bolts/models/detection/yolo/__init__.py b/pl_bolts/models/detection/yolo/__init__.py index a2785f5882..e69de29bb2 100644 --- a/pl_bolts/models/detection/yolo/__init__.py +++ b/pl_bolts/models/detection/yolo/__init__.py @@ -1,4 +0,0 @@ -from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration -from pl_bolts.models.detection.yolo.yolo_module import Yolo - -__all__ = ["YoloConfiguration", "Yolo"] diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 30e745fa5d..58f513eac2 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -8,7 +8,7 @@ from pl_bolts.models.detection.yolo import yolo_layers -class YoloConfiguration: +class YOLOConfiguration: """ This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. The ``get_network()`` method returns a PyTorch module list that can be used to construct a YOLO diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index bfbbafcef8..84bac064b6 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -8,7 +8,13 @@ from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: - from torchvision.ops import box_iou, generalized_box_iou + from torchvision.ops import box_iou + try: + from torchvision.ops import generalized_box_iou + except ImportError: + _GIOU_AVAILABLE = False + else: + _GIOU_AVAILABLE = True else: warn_missing_pkg('torchvision') @@ -61,6 +67,13 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: class GIoULoss(nn.Module): + def __init__(self): + super().__init__() + if not _GIOU_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'A more recent version of `torchvision` is needed for generalized IoU loss.' + ) + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - generalized_box_iou(inputs, target).diagonal() diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index dc27e71a9b..f07dcf8ed5 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -1,5 +1,5 @@ import inspect -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Type import numpy as np import pytorch_lightning as pl @@ -7,7 +7,7 @@ import torch.nn as nn from torch import optim, Tensor -from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration +from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.utils import _TORCHVISION_AVAILABLE @@ -21,7 +21,7 @@ warn_missing_pkg('torchvision') -class Yolo(pl.LightningModule): +class YOLO(pl.LightningModule): """ PyTorch Lightning implementation of `YOLOv3 `_ with some improvements from `YOLOv4 `_. @@ -54,13 +54,10 @@ class Yolo(pl.LightningModule): def __init__( self, network: nn.ModuleList, - optimizer: str = 'sgd', - momentum: float = 0.9, - weight_decay: float = 0.0005, - learning_rate: float = 0.0013, - warmup_epochs: int = 1, - warmup_start_lr: float = 0.0001, - annealing_epochs: int = 271, + optimizer: Type[optim.Optimizer] = optim.SGD, + optimizer_params: Dict[str, Any] = {'lr': 0.0013, 'momentum': 0.9, 'weight_decay': 0.0005}, + lr_scheduler: Type[optim.lr_scheduler._LRScheduler] = LinearWarmupCosineAnnealingLR, + lr_scheduler_params: Dict[str, Any] = {'warmup_epochs': 1, 'max_epochs': 271, 'warmup_start_lr': 0.0}, confidence_threshold: float = 0.2, nms_threshold: float = 0.45 ): @@ -68,16 +65,10 @@ def __init__( Args: network: A list of network modules. This can be obtained from a Darknet configuration using the ``YoloConfiguration.get_network()`` method. - optimizer: Which optimizer to use for training; either 'sgd' or 'adam'. - momentum: Momentum factor for SGD with momentum. - weight_decay: Weight decay (L2 penalty). - learning_rate: Learning rate after the warmup period. - warmup_epochs: Length of the learning rate warmup period in the beginning of - training. During this number of epochs, the learning rate will be raised from - ``warmup_start_lr`` to ``learning_rate``. - warmup_start_lr: Learning rate in the beginning of the warmup period. - annealing_epochs: Length of the learning rate annealing period, during which the - learning rate will go to zero. + optimizer: Which optimizer class to use for training. + optimizer_params: Parameters to pass to the optimizer constructor. + lr_scheduler: Which learning rate scheduler class to use for training. + lr_scheduler_params: Parameters to pass to the learning rate scheduler constructor. confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this threshold. nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU @@ -91,13 +82,10 @@ def __init__( ) self.network = network - self.optimizer = optimizer - self.momentum = momentum - self.weight_decay = weight_decay - self.learning_rate = learning_rate - self.warmup_epochs = warmup_epochs - self.warmup_start_lr = warmup_start_lr - self.annealing_epochs = annealing_epochs + self.optimizer_class = optimizer + self.optimizer_params = optimizer_params + self.lr_scheduler_class = lr_scheduler + self.lr_scheduler_params = lr_scheduler_params self.confidence_threshold = confidence_threshold self.nms_threshold = nms_threshold @@ -166,22 +154,8 @@ def mean_loss(loss_name): def configure_optimizers(self) -> Tuple[List, List]: """Constructs the optimizer and learning rate scheduler.""" - if self.optimizer == 'sgd': - optimizer = optim.SGD( - self.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay - ) - elif self.optimizer == 'adam': - optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) - else: - raise ValueError("Unknown optimizer: {}".format(self.optimizer)) - - lr_scheduler = LinearWarmupCosineAnnealingLR( - optimizer, - warmup_epochs=self.warmup_epochs, - max_epochs=self.annealing_epochs, - warmup_start_lr=self.warmup_start_lr - ) - + optimizer = self.optimizer_class(self.parameters(), **self.optimizer_params) + lr_scheduler = self.lr_scheduler_class(optimizer, **self.lr_scheduler_params) return [optimizer], [lr_scheduler] def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: @@ -483,22 +457,42 @@ def run_cli(): parser.add_argument('--config', type=str, help='model configuration file', required=True) parser.add_argument('--darknet-weights', type=str, help='initialize the model weights from this Darknet model file') parser.add_argument('--batch-size', type=int, help='number of images in one batch', default=16) + parser.add_argument('--lr', type=float, help='learning ratea after warmup', default=0.0013) + parser.add_argument('--momentum', type=float, help='optimizer momentum factor', default=0.9) + parser.add_argument('--weight-decay', type=float, help='weight decay (L2 penalty)', default=0.0005) + parser.add_argument('--warmup-epochs', type=int, help='length of the learning rate warmup', default=1) + parser.add_argument('--max-epochs', type=int, help='maximum number of epochs to train', default=271) + parser.add_argument('--initial-lr', type=float, help='learning rate before warmup', default=0.0) + parser.add_argument('--confidence-threshold', type=float, help='threshold for prediction confidence', default=0.01) + parser.add_argument('--nms-threshold', type=float, help='non-maximum suppression threshold', default=0.45) parser = VOCDetectionDataModule.add_argparse_args(parser) - parser = argparse_utils.add_argparse_args(Yolo, parser) parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() - config = YoloConfiguration(args.config) + config = YOLOConfiguration(args.config) transforms = [Resize((config.height, config.width))] image_transforms = T.ToTensor() datamodule = VOCDetectionDataModule.from_argparse_args(args) datamodule.prepare_data() - params = vars(args) - valid_kwargs = inspect.signature(Yolo.__init__).parameters - kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) - model = Yolo(network=config.get_network(), **kwargs) + optimizer_params = { + 'lr': args.lr, + 'momentum': args.momentum, + 'weight_decay': args.weight_decay + } + lr_scheduler_params = { + 'warmup_epochs': args.warmup_epochs, + 'max_epochs': args.max_epochs, + 'warmup_start_lr': args.initial_lr + } + model = YOLO( + network=config.get_network(), + optimizer_params=optimizer_params, + lr_scheduler_params=lr_scheduler_params, + confidence_threshold=args.confidence_threshold, + nms_threshold=args.nms_threshold + ) if args.darknet_weights is not None: with open(args.darknet_weights, 'r') as weight_file: model.load_darknet_weights(weight_file) diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 7ad94db224..448943ef3e 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDetectionDataset -from pl_bolts.models.detection import FasterRCNN, Yolo, YoloConfiguration +from pl_bolts.models.detection import FasterRCNN, YOLO, YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou @@ -129,8 +129,8 @@ def _create_yolo_config_file(config_path): def test_yolo(tmpdir): config_path = Path(tmpdir) / 'yolo.cfg' _create_yolo_config_file(config_path) - config = YoloConfiguration(config_path) - model = Yolo(config.get_network()) + config = YOLOConfiguration(config_path) + model = YOLO(config.get_network()) image = torch.rand(1, 3, 256, 256) model(image) @@ -139,8 +139,8 @@ def test_yolo(tmpdir): def test_yolo_train(tmpdir): config_path = Path(tmpdir) / 'yolo.cfg' _create_yolo_config_file(config_path) - config = YoloConfiguration(config_path) - model = Yolo(config.get_network()) + config = YOLOConfiguration(config_path) + model = YOLO(config.get_network()) train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) From 60fda75775f0e2e663fd9615e2abf7d45747d3a8 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 19:57:35 +0200 Subject: [PATCH 24/61] get_deprecated_arg_names() is not needed anymore. --- pl_bolts/models/detection/yolo/yolo_module.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index cf0d2fe0c5..ed07e31fed 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -304,18 +304,9 @@ def read(tensor): read(conv.weight) - @classmethod - def get_deprecated_arg_names(cls) -> List: - """Returns a list with deprecated constructor arguments.""" - depr_arg_names = [] - for name, val in cls.__dict__.items(): - if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): - depr_arg_names.extend(val) - return depr_arg_names - def _validate_batch( self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], + batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: """ Reads a batch of data, validates the format, and stacks the images into a single tensor. From b2e3e84a9f43f0eea9121edfdd7751acfb34f107 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 20:00:32 +0200 Subject: [PATCH 25/61] Fixed yapf formatting. Co-authored-by: Akihiro Nitta --- pl_bolts/models/detection/yolo/yolo_module.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index ed07e31fed..01a7315f56 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -345,8 +345,13 @@ def _validate_batch( images = torch.stack(images) return images, targets - def _filter_detections(self, boxes: Tensor, confidences: Tensor, classprobs: Tensor, - labels: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: + def _filter_detections( + self, + boxes: Tensor, + confidences: Tensor, + classprobs: Tensor, + labels: Tensor, + ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: """ Filters detections based on confidence threshold. Then for every class performs non-maximum suppression (NMS). NMS iterates the bounding boxes that predict this class in descending From 940947f694528e1acd7b604f11427b0d2a53126c Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 20:01:17 +0200 Subject: [PATCH 26/61] Fixed formatting. Co-authored-by: Akihiro Nitta --- pl_bolts/models/detection/yolo/yolo_layers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 84bac064b6..47ce33bf38 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -135,10 +135,8 @@ def __init__( """ super().__init__() - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover - 'YOLO model uses `torchvision`, which is not installed yet.' - ) + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('YOLO model uses `torchvision`, which is not installed yet.') self.num_classes = num_classes self.image_width = image_width From 6e3d5bf65dd96a013460a8dc28f88b6de49c4b40 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 15 Feb 2021 20:03:13 +0200 Subject: [PATCH 27/61] Removed unused imports. --- pl_bolts/models/detection/yolo/yolo_module.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index ed07e31fed..89dad36157 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -1,4 +1,3 @@ -import inspect from typing import Any, Dict, List, Tuple, Type import numpy as np @@ -440,8 +439,6 @@ def __call__(self, image, target): def run_cli(): from argparse import ArgumentParser - from pytorch_lightning.utilities import argparse_utils - from pl_bolts.datamodules import VOCDetectionDataModule pl.seed_everything(42) From b2ea49786aee5afdec5062afe6cad80c04e4efb3 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 16 Feb 2021 10:45:37 +0200 Subject: [PATCH 28/61] Fixed some type hints. --- pl_bolts/models/detection/yolo/yolo_config.py | 2 +- pl_bolts/models/detection/yolo/yolo_layers.py | 14 ++++++------ pl_bolts/models/detection/yolo/yolo_module.py | 22 +++++++++---------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 58f513eac2..9d2814aa61 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -15,7 +15,7 @@ class YOLOConfiguration: model. """ - def __init__(self, path: str): + def __init__(self, path: str) -> None: """ Saves the variables from the first configuration section to attributes of this object, and the rest of the sections to the ``layer_configs`` list. diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 47ce33bf38..e9f14f89b2 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -67,7 +67,7 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: class GIoULoss(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() if not _GIOU_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -93,14 +93,14 @@ def __init__( anchor_ids: List[int], xy_scale: float = 1.0, ignore_threshold: float = 0.5, - overlap_loss_func: Callable = None, - class_loss_func: Callable = None, - confidence_loss_func: Callable = None, + overlap_loss_func: Optional[Callable] = None, + class_loss_func: Optional[Callable] = None, + confidence_loss_func: Optional[Callable] = None, image_space_loss: bool = False, overlap_loss_multiplier: float = 1.0, class_loss_multiplier: float = 1.0, confidence_loss_multiplier: float = 1.0 - ): + ) -> None: """ Args: num_classes: Number of different classes that this layer predicts. @@ -463,7 +463,7 @@ def forward(self, x): class RouteLayer(nn.Module): """Route layer concatenates the output (or part of it) from given layers.""" - def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int): + def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None: """ Args: source_layers: Indices of the layers whose output will be concatenated. @@ -483,7 +483,7 @@ def forward(self, x, outputs): class ShortcutLayer(nn.Module): """Shortcut layer adds a residual connection from the source layer.""" - def __init__(self, source_layer: int): + def __init__(self, source_layer: int) -> None: """ Args: source_layer: Index of the layer whose output will be added to the output of the diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 8cc7af4e64..9177e430e9 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np import pytorch_lightning as pl @@ -13,7 +13,6 @@ from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: - import torchvision.transforms as T from torchvision.ops import nms from torchvision.transforms import functional as F else: @@ -59,7 +58,7 @@ def __init__( lr_scheduler_params: Dict[str, Any] = {'warmup_epochs': 1, 'max_epochs': 271, 'warmup_start_lr': 0.0}, confidence_threshold: float = 0.2, nms_threshold: float = 0.45 - ): + ) -> None: """ Args: network: A list of network modules. This can be obtained from a Darknet configuration @@ -88,9 +87,11 @@ def __init__( self.confidence_threshold = confidence_threshold self.nms_threshold = nms_threshold - def forward(self, - images: Tensor, - targets: List[Dict[str, Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: + def forward( + self, + images: Tensor, + targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: """ Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. @@ -349,7 +350,7 @@ def _filter_detections( boxes: Tensor, confidences: Tensor, classprobs: Tensor, - labels: Tensor, + labels: Tensor ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: """ Filters detections based on confidence threshold. Then for every class performs non-maximum @@ -420,7 +421,7 @@ class Resize: ``output_size``, keeping the aspect ratio the same. """ - def __init__(self, output_size: tuple): + def __init__(self, output_size: tuple) -> None: self.output_size = output_size def __call__(self, image, target): @@ -467,7 +468,6 @@ def run_cli(): config = YOLOConfiguration(args.config) transforms = [Resize((config.height, config.width))] - image_transforms = T.ToTensor() datamodule = VOCDetectionDataModule.from_argparse_args(args) datamodule.prepare_data() @@ -494,8 +494,8 @@ def run_cli(): trainer = pl.Trainer.from_argparse_args(args) trainer.fit( - model, datamodule.train_dataloader(args.batch_size, transforms, image_transforms), - datamodule.val_dataloader(args.batch_size, transforms, image_transforms) + model, datamodule.train_dataloader(args.batch_size, transforms), + datamodule.val_dataloader(args.batch_size, transforms) ) From 6d8fa7d156497669c3ceaa40761a311d5aea3435 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 16 Feb 2021 11:02:31 +0200 Subject: [PATCH 29/61] Sorted imports. --- pl_bolts/models/detection/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 3a8034f9f0..bcb97d7269 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,6 +1,6 @@ from pl_bolts.models.detection import components from pl_bolts.models.detection.faster_rcnn import FasterRCNN -from pl_bolts.models.detection.yolo.yolo_module import YOLO from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration +from pl_bolts.models.detection.yolo.yolo_module import YOLO __all__ = ["components", "FasterRCNN", "YOLOConfiguration", "YOLO"] From e68df7ad02479ebf9eb0033a9334a9775293bdf4 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 24 Feb 2021 19:14:23 +0200 Subject: [PATCH 30/61] Possible to limit the number of predictions per image --- pl_bolts/models/detection/yolo/yolo_module.py | 86 +++++++++++++++---- 1 file changed, 67 insertions(+), 19 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 9177e430e9..7051e03549 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -57,7 +57,8 @@ def __init__( lr_scheduler: Type[optim.lr_scheduler._LRScheduler] = LinearWarmupCosineAnnealingLR, lr_scheduler_params: Dict[str, Any] = {'warmup_epochs': 1, 'max_epochs': 271, 'warmup_start_lr': 0.0}, confidence_threshold: float = 0.2, - nms_threshold: float = 0.45 + nms_threshold: float = 0.45, + max_predictions_per_image: int = -1 ) -> None: """ Args: @@ -69,8 +70,10 @@ def __init__( lr_scheduler_params: Parameters to pass to the learning rate scheduler constructor. confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this threshold. - nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU - with the next best bounding box in that class is higher than this threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher + confidence box is higher than this threshold, if the predicted categories are equal. + max_predictions_per_image: If non-negative, keep at most this number of + highest-confidence predictions per image. """ super().__init__() @@ -86,6 +89,7 @@ def __init__( self.lr_scheduler_params = lr_scheduler_params self.confidence_threshold = confidence_threshold self.nms_threshold = nms_threshold + self.max_predictions_per_image = max_predictions_per_image def forward( self, @@ -404,10 +408,14 @@ def _filter_detections( img_out_classprobs = torch.cat((img_out_classprobs, cls_classprobs[selected])) img_out_labels = torch.cat((img_out_labels, cls_labels[selected])) - out_boxes.append(img_out_boxes) - out_confidences.append(img_out_confidences) - out_classprobs.append(img_out_classprobs) - out_labels.append(img_out_labels) + # Sort by descending confidence and limit the maximum number of predictions. + indices = torch.argsort(img_out_confidences, descending=True) + if self.max_predictions_per_image >= 0: + indices = indices[:self.max_predictions_per_image] + out_boxes.append(img_out_boxes[indices]) + out_confidences.append(img_out_confidences[indices]) + out_classprobs.append(img_out_classprobs[indices]) + out_labels.append(img_out_labels[indices]) return out_boxes, out_confidences, out_classprobs, out_labels @@ -450,17 +458,56 @@ def run_cli(): pl.seed_everything(42) parser = ArgumentParser() - parser.add_argument('--config', type=str, help='model configuration file', required=True) - parser.add_argument('--darknet-weights', type=str, help='initialize the model weights from this Darknet model file') - parser.add_argument('--batch-size', type=int, help='number of images in one batch', default=16) - parser.add_argument('--lr', type=float, help='learning ratea after warmup', default=0.0013) - parser.add_argument('--momentum', type=float, help='optimizer momentum factor', default=0.9) - parser.add_argument('--weight-decay', type=float, help='weight decay (L2 penalty)', default=0.0005) - parser.add_argument('--warmup-epochs', type=int, help='length of the learning rate warmup', default=1) - parser.add_argument('--max-epochs', type=int, help='maximum number of epochs to train', default=271) - parser.add_argument('--initial-lr', type=float, help='learning rate before warmup', default=0.0) - parser.add_argument('--confidence-threshold', type=float, help='threshold for prediction confidence', default=0.01) - parser.add_argument('--nms-threshold', type=float, help='non-maximum suppression threshold', default=0.45) + parser.add_argument( + '--config', type=str, metavar='PATH', required=True, + help='read model configuration from PATH' + ) + parser.add_argument( + '--darknet-weights', type=str, metavar='PATH', + help='read the initial model weights from PATH in Darknet format' + ) + parser.add_argument( + '--batch-size', type=int, metavar='N', default=16, + help='batch size is N image' + ) + parser.add_argument( + '--lr', type=float, metavar='LR', default=0.0013, + help='learning rate after the warmup period' + ) + parser.add_argument( + '--momentum', type=float, metavar='GAMMA', default=0.9, + help='if nonzero, the optimizer uses momentum with factor GAMMA' + ) + parser.add_argument( + '--weight-decay', type=float, metavar='LAMBDA', default=0.0005, + help='if nonzero, the optimizer uses weight decay (L2 penalty) with factor LAMBDA' + ) + parser.add_argument( + '--warmup-epochs', type=int, metavar='N', default=1, + help='learning rate warmup period is N epochs' + ) + parser.add_argument( + '--max-epochs', type=int, metavar='N', default=300, + help='train at most N epochs' + ) + parser.add_argument( + '--initial-lr', type=float, metavar='LR', default=0.0, + help='learning rate before the warmup period' + ) + parser.add_argument( + '--confidence-threshold', type=float, metavar='THRESHOLD', default=0.001, + help='keep predictions only if the confidence is above THRESHOLD' + ) + parser.add_argument( + '--nms-threshold', type=float, metavar='THRESHOLD', default=0.45, + help='non-maximum suppression removes predicted boxes that have IoU greater than ' + 'THRESHOLD with a higher scoring box' + ) + parser.add_argument( + '--max-predictions-per-image', type=int, metavar='N', default=100, + help='keep at most N best predictions' + ) + parser = VOCDetectionDataModule.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -486,7 +533,8 @@ def run_cli(): optimizer_params=optimizer_params, lr_scheduler_params=lr_scheduler_params, confidence_threshold=args.confidence_threshold, - nms_threshold=args.nms_threshold + nms_threshold=args.nms_threshold, + max_predictions_per_image=args.max_predictions_per_image ) if args.darknet_weights is not None: with open(args.darknet_weights, 'r') as weight_file: From f895530cbcf3c6bab9d5092edb701a69d1c5c2f7 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 24 Feb 2021 19:22:29 +0200 Subject: [PATCH 31/61] None instead of an empty list as default argument --- pl_bolts/datamodules/vocdetection_datamodule.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index b34ba48da3..6d2d1fc8e2 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -154,7 +154,7 @@ def prepare_data(self) -> None: def train_dataloader( self, batch_size: int = 1, - transforms: List[Callable] = [], + transforms: List[Callable] = None, image_transforms: Optional[Callable] = None ) -> DataLoader: """ @@ -165,9 +165,13 @@ def train_dataloader( transforms: custom transforms for image and target image_transforms: custom image-only transforms """ - transforms = [_prepare_voc_instance] + transforms + if transforms is None: + transforms = [_prepare_voc_instance] + else: + transforms = [_prepare_voc_instance] + transforms image_transforms = image_transforms or self.train_transforms or self._default_transforms() transforms = Compose(transforms, image_transforms) + dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms) loader = DataLoader( dataset, @@ -183,7 +187,7 @@ def train_dataloader( def val_dataloader( self, batch_size: int = 1, - transforms: List[Callable] = [], + transforms: List[Callable] = None, image_transforms: Optional[Callable] = None ) -> DataLoader: """ @@ -194,9 +198,13 @@ def val_dataloader( transforms: custom transforms for image and target image_transforms: custom image-only transforms """ - transforms = [_prepare_voc_instance] + transforms + if transforms is None: + transforms = [_prepare_voc_instance] + else: + transforms = [_prepare_voc_instance] + transforms image_transforms = image_transforms or self.train_transforms or self._default_transforms() transforms = Compose(transforms, image_transforms) + dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms) loader = DataLoader( dataset, From 58f1456a136106d41d261ce8ff30297863f864b5 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 24 Feb 2021 23:25:10 +0200 Subject: [PATCH 32/61] Fixed capitalization of YOLO class. --- docs/source/object_detection.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/object_detection.rst b/docs/source/object_detection.rst index cdff88f9e0..37ace12a88 100644 --- a/docs/source/object_detection.rst +++ b/docs/source/object_detection.rst @@ -16,5 +16,5 @@ Faster R-CNN YOLO ---- -.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.Yolo +.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO :noindex: From 4e6d4cfb37f93b17ea395609ddee77eb2096b253 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 8 Mar 2021 08:25:13 +0200 Subject: [PATCH 33/61] No need to check for NaN values as Trainer has terminate_on_nan argument. --- pl_bolts/models/detection/yolo/yolo_layers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index e9f14f89b2..bfd345666b 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -200,9 +200,6 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) # x/y coordinates. xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) - if not torch.isfinite(x).all(): - raise ValueError('Detection layer output contains nan or inf values.') - image_xy = self._global_xy(xy) image_wh = self._scale_wh(wh) boxes = _corner_coordinates(image_xy, image_wh) From af3e0e67d66e2f697768e7210c7ba7ec6f823b00 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 8 Mar 2021 09:10:30 +0200 Subject: [PATCH 34/61] YOLO test configuration moved to tests/data/yolo.cfg --- tests/data/yolo.cfg | 79 +++++++++++++++++++++++++++++ tests/models/test_detection.py | 93 ++-------------------------------- 2 files changed, 82 insertions(+), 90 deletions(-) create mode 100644 tests/data/yolo.cfg diff --git a/tests/data/yolo.cfg b/tests/data/yolo.cfg new file mode 100644 index 0000000000..d7596ef4b5 --- /dev/null +++ b/tests/data/yolo.cfg @@ -0,0 +1,79 @@ +[net] +width=256 +height=256 +channels=3 + +[convolutional] +batch_normalize=1 +filters=8 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=2 +size=1 +stride=1 +pad=1 +activation=mish + +[convolutional] +batch_normalize=1 +filters=4 +size=3 +stride=1 +pad=1 +activation=mish + +[shortcut] +from=-3 +activation=linear + +[convolutional] +size=1 +stride=1 +pad=1 +filters=14 +activation=linear + +[yolo] +mask=2,3 +anchors=1,2, 3,4, 5,6, 9,10 +classes=2 +scale_x_y=1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +ignore_thresh=0.7 + +[route] +layers = -4 + +[upsample] +stride=2 + +[convolutional] +size=1 +stride=1 +pad=1 +filters=14 +activation=linear + +[yolo] +mask=0,1 +anchors=1,2, 3,4, 5,6, 9,10 +classes=2 +scale_x_y=1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +ignore_thresh=0.7 diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 448943ef3e..ed42f55229 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -8,6 +8,7 @@ from pl_bolts.datasets import DummyDetectionDataset from pl_bolts.models.detection import FasterRCNN, YOLO, YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou +from tests import TEST_ROOT def _collate_fn(batch): @@ -40,95 +41,8 @@ def test_fasterrcnn_bbone_train(tmpdir): trainer.fit(model, train_dl, valid_dl) -def _create_yolo_config_file(config_path): - config_file = open(config_path, 'w') - config_file.write( - '''[net] -width=256 -height=256 -channels=3 - -[convolutional] -batch_normalize=1 -filters=8 -size=3 -stride=1 -pad=1 -activation=leaky - -[route] -layers=-1 -groups=2 -group_id=1 - -[maxpool] -size=2 -stride=2 - -[convolutional] -batch_normalize=1 -filters=2 -size=1 -stride=1 -pad=1 -activation=mish - -[convolutional] -batch_normalize=1 -filters=4 -size=3 -stride=1 -pad=1 -activation=mish - -[shortcut] -from=-3 -activation=linear - -[convolutional] -size=1 -stride=1 -pad=1 -filters=14 -activation=linear - -[yolo] -mask=2,3 -anchors=1,2, 3,4, 5,6, 9,10 -classes=2 -scale_x_y=1.05 -cls_normalizer=1.0 -iou_normalizer=0.07 -ignore_thresh=0.7 - -[route] -layers = -4 - -[upsample] -stride=2 - -[convolutional] -size=1 -stride=1 -pad=1 -filters=14 -activation=linear - -[yolo] -mask=0,1 -anchors=1,2, 3,4, 5,6, 9,10 -classes=2 -scale_x_y=1.05 -cls_normalizer=1.0 -iou_normalizer=0.07 -ignore_thresh=0.7''' - ) - config_file.close() - - def test_yolo(tmpdir): - config_path = Path(tmpdir) / 'yolo.cfg' - _create_yolo_config_file(config_path) + config_path = Path(TEST_ROOT) / 'data' / 'yolo.cfg' config = YOLOConfiguration(config_path) model = YOLO(config.get_network()) @@ -137,8 +51,7 @@ def test_yolo(tmpdir): def test_yolo_train(tmpdir): - config_path = Path(tmpdir) / 'yolo.cfg' - _create_yolo_config_file(config_path) + config_path = Path(TEST_ROOT) / 'data' / 'yolo.cfg' config = YOLOConfiguration(config_path) model = YOLO(config.get_network()) From 40122474b481b4b768477acfe32464b4a4c62f60 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 8 Mar 2021 09:19:10 +0200 Subject: [PATCH 35/61] Use Optional[] as the default value for transforms is now None --- pl_bolts/datamodules/vocdetection_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 6d2d1fc8e2..bbd3ae3299 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -154,7 +154,7 @@ def prepare_data(self) -> None: def train_dataloader( self, batch_size: int = 1, - transforms: List[Callable] = None, + transforms: Optional[List[Callable]] = None, image_transforms: Optional[Callable] = None ) -> DataLoader: """ @@ -187,7 +187,7 @@ def train_dataloader( def val_dataloader( self, batch_size: int = 1, - transforms: List[Callable] = None, + transforms: Optional[List[Callable]] = None, image_transforms: Optional[Callable] = None ) -> DataLoader: """ From c8b76a5879b25140937f00aae809e69d2436fbfa Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 23 Mar 2021 18:20:34 +0200 Subject: [PATCH 36/61] Refactoring and documentation improvements * Synchronize validation and test step logging calls * Log losses to progress bar --- pl_bolts/models/detection/yolo/yolo_layers.py | 22 +-- pl_bolts/models/detection/yolo/yolo_module.py | 174 +++++++++++------- 2 files changed, 114 insertions(+), 82 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index bfd345666b..f5a44d3588 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -28,7 +28,7 @@ def _corner_coordinates(xy, wh): wh (Tensor): Width and height. Tensor of size ``[..., 2]``. Returns: - boxes (Tensor): A matrix of (x1, y1, x2, y2) coordinates. + boxes (Tensor): A matrix of `(x1, y1, x2, y2)` coordinates. """ half_wh = wh / 2 top_left = xy - half_wh @@ -125,7 +125,7 @@ def __init__( confidence_loss_func: Loss function for confidence score. Default is the sum of squared errors. image_space_loss: If set to ``True``, the overlap loss function will receive the bounding - box (x1, y1, x2, y2) coordinate normalized to the [0, 1] range. This is needed for + box `(x1, y1, x2, y2)` coordinate normalized to the `[0, 1]` range. This is needed for the IoU losses introduced in YOLOv4. Otherwise the loss will be computed from the x, y, width, and height values, as predicted by the network (i.e. relative to the anchor box, and width and height are logarithmic). @@ -160,7 +160,7 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) """ Runs a forward pass through this YOLO detection layer. - Maps cell-local coordinates to global coordinates in the [0, 1] range, scales the bounding + Maps cell-local coordinates to global coordinates in the `[0, 1]` range, scales the bounding boxes with the anchors, converts the center coordinates to corner coordinates, and maps probabilities to ]0, 1[ range using sigmoid. @@ -222,7 +222,7 @@ def _global_xy(self, xy): The predicted coordinates are interpreted as coordinates inside a grid cell whose width and height is 1. Adding offset to the cell and dividing by the grid size, we get global - coordinates in the [0, 1] range. + coordinates in the `[0, 1]` range. Args: xy (Tensor): The predicted center coordinates before scaling. Values from zero to one @@ -254,7 +254,7 @@ def _scale_wh(self, wh): Returns: result (Tensor): A tensor with the same shape as the input tensor, but scaled sizes - normalized to the [0, 1] range. + normalized to the `[0, 1]` range. """ image_size = torch.tensor([self.image_width, self.image_height], device=wh.device) anchor_wh = [self.anchor_dims[i] for i in self.anchor_ids] @@ -268,7 +268,7 @@ def _low_confidence_mask(self, boxes, targets): significantly (IoU greater than ``self.ignore_threshold``). Args: - boxes (Tensor): The predicted corner coordinates, normalized to the [0, 1] range. + boxes (Tensor): The predicted corner coordinates, normalized to the `[0, 1]` range. Tensor of size ``[batch_size, height, width, boxes_per_cell, 4]``. targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one dictionary for each image. @@ -304,9 +304,9 @@ def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): Args: boxes (Tensor): The predicted bounding boxes. A tensor sized ``[batch_size, height, width, boxes_per_cell, 4]``. - confidence (Tensor): The confidence predictions, normalized to [0, 1]. A tensor sized + confidence (Tensor): The confidence predictions, normalized to `[0, 1]`. A tensor sized ``[batch_size, height, width, boxes_per_cell]``. - classprob (Tensor): The class probability predictions, normalized to [0, 1]. A tensor + classprob (Tensor): The class probability predictions, normalized to `[0, 1]`. A tensor sized ``[batch_size, height, width, boxes_per_cell, num_classes]``. targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one dictionary for each image. @@ -322,7 +322,7 @@ def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): # Divisor for converting targets from image coordinates to feature map coordinates image_to_feature_map = torch.tensor([self.image_width / width, self.image_height / height], device=device) - # Divisor for converting targets from image coordinates to [0, 1] range + # Divisor for converting targets from image coordinates to `[0, 1]` range image_to_unit = torch.tensor([self.image_width, self.image_height], device=device) anchor_wh = torch.tensor(self.anchor_dims, dtype=boxes.dtype, device=device) @@ -344,7 +344,7 @@ def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): continue # Bounding box corner coordinates are converted to center coordinates, width, and - # height, and normalized to [0, 1] range. + # height, and normalized to `[0, 1]` range. wh = target_boxes[:, 2:4] - target_boxes[:, 0:2] xy = target_boxes[:, 0:2] + (wh / 2) unit_xy = xy / image_to_unit @@ -381,7 +381,7 @@ def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): # the target confidence. lc_mask[image_idx, cell_j, cell_i, predictors] = False - # IoU losses are calculated from the image space coordinates normalized to [0, 1] + # IoU losses are calculated from the image space coordinates normalized to `[0, 1]` # range. The squared-error loss is calculated from the raw predicted values. if self.image_space_loss: target_xy.append(unit_xy) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 7051e03549..2b0420f2d8 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np @@ -18,6 +19,8 @@ else: warn_missing_pkg('torchvision') +log = logging.getLogger(__name__) + class YOLO(pl.LightningModule): """ @@ -33,14 +36,31 @@ class YOLO(pl.LightningModule): The network architecture can be read from a Darknet configuration file using the :class:`~pl_bolts.models.detection.yolo.yolo_config.YoloConfiguration` class, or created by - some other means, and provided as a list of PyTorch modules. Supports loading weights from a - Darknet model file too, if you don't want to start training from a randomly initialized model. - During training, the model expects both the images (list of tensors), as well as targets (list - of dictionaries). + some other means, and provided as a list of PyTorch modules. + + The input from the data loader is expected to be a list of images. Each image is a tensor with + shape ``[channels, height, width]``. The images from each batch are concatenated into a single + tensor, so the sizes have to match. Different batches can have different image sizes, as long + as the size is divisible by the ratio in which the network downsamples the input. + + During training, the model expects both the input tensors and a list of targets. Each target is + a dictionary containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format + - labels (``Int64Tensor[N]``): the class label for each ground-truth box - The target dictionaries should contain: - - boxes (``FloatTensor[N, 4]``): the ground truth boxes in ``[x1, y1, x2, y2]`` format. - - labels (``LongTensor[N]``): the class label for each ground truh box + ``forward()`` method returns all predictions from all detection layers in all images in one + tensor with shape ``[images, predictors, classes + 5]``. The coordinates are in the `[0, 1]` + range. During training it also returns a dictionary containing the classification, box overlap, + and confidence losses. + + During inference, the model requires only the input tensors. ``infer()`` method filters and + processes the predictions, producing the following tensors: + - boxes (``FloatTensor[N, 4]``): predicted bounding box `(x1, y1, x2, y2)` coordinates in image + space + - scores (``FloatTensor[N]``): detection confidences + - labels (``Int64Tensor[N]``): the predicted labels for each image + + Weights can be loaded from a Darknet model file using ``load_darknet_weights()``. CLI command:: @@ -114,11 +134,12 @@ def forward( dictionaries, one for each image. Returns: - boxes (Tensor), confidences (Tensor), classprobs (Tensor), losses (Dict[str, Tensor]): - Detections, and if targets were provided, a dictionary of losses. The first - dimension of the detections is the index of the image in the batch and the second - dimension is the detection within the image. ``boxes`` contains the predicted - (x1, y1, x2, y2) coordinates, normalized to [0, 1]. + detections (Tensor), losses (Dict[str, Tensor]): + Detections, and if targets were provided, a dictionary of losses. Detections are + shaped ``[batch_size, num_predictors, num_classes + 5]``, where ``num_predictors`` + is the total number of cells in all detection layers times the number of boxes + predicted by one cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format + and normalized to `[0, 1]`. """ outputs = [] # Outputs from all layers detections = [] # Outputs from detection layers @@ -146,15 +167,11 @@ def mean_loss(loss_name): return torch.stack(loss_tuple).sum() / images.shape[0] detections = torch.cat(detections, 1) - boxes = detections[..., :4] - confidences = detections[..., 4] - classprobs = detections[..., 5:] - if targets is None: - return boxes, confidences, classprobs + return detections losses = {loss_name: mean_loss(loss_name) for loss_name in losses[0].keys()} - return boxes, confidences, classprobs, losses + return detections, losses def configure_optimizers(self) -> Tuple[List, List]: """Constructs the optimizer and learning rate scheduler.""" @@ -175,11 +192,11 @@ def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], bat A dictionary that includes the training loss in 'loss'. """ images, targets = self._validate_batch(batch) - _, _, _, losses = self(images, targets) + _, losses = self(images, targets) total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): - self.log('train/{}_loss'.format(name), value) + self.log(f'train/{name}_loss', value, prog_bar=True) self.log('train/total_loss', total_loss) return {'loss': total_loss} @@ -194,14 +211,14 @@ def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], b batch_idx: The index of this batch """ images, targets = self._validate_batch(batch) - boxes, confidences, classprobs, losses = self(images, targets) - classprobs, labels = torch.max(classprobs, -1) - boxes, confidences, classprobs, labels = self._filter_detections(boxes, confidences, classprobs, labels) + detections, losses = self(images, targets) + detections = self._split_detections(detections) + detections = self._filter_detections(detections) total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): - self.log('val/{}_loss'.format(name), value) - self.log('val/total_loss', total_loss) + self.log(f'val/{name}_loss', value, sync_dist=True) + self.log('val/total_loss', total_loss, sync_dist=True) def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: """ @@ -213,14 +230,14 @@ def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_i batch_idx: The index of this batch. """ images, targets = self._validate_batch(batch) - boxes, confidences, classprobs, losses = self(images, targets) - classprobs, labels = torch.max(classprobs, -1) - boxes, confidences, classprobs, labels = self._filter_detections(boxes, confidences, classprobs, labels) + detections, losses = self(images, targets) + detections = self._split_detections(detections) + detections = self._filter_detections(detections) total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): - self.log('test/{}_loss'.format(name), value) - self.log('test/total_loss', total_loss) + self.log(f'test/{name}_loss', value, sync_dist=True) + self.log('test/total_loss', total_loss, sync_dist=True) def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ @@ -232,27 +249,26 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: Returns: boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), labels (:class:`~torch.Tensor`): - A matrix of detected bounding box (x1, y1, x2, y2) coordinates, a vector of + A matrix of detected bounding box `(x1, y1, x2, y2)` coordinates, a vector of confidences for the bounding box detections, and a vector of predicted class labels. """ network_input = image.float().div(255.0) network_input = network_input.unsqueeze(0) self.eval() - boxes, confidences, classprobs = self(network_input) - classprobs, labels = torch.max(classprobs, -1) - boxes, confidences, classprobs, labels = self._filter_detections(boxes, confidences, classprobs, labels) - assert len(boxes) == 1 - boxes = boxes[0] - confidences = confidences[0] - labels = labels[0] + detections = self(network_input) + detections = self._split_detections(detections) + detections = self._filter_detections(detections) + boxes = detections['boxes'][0] + scores = detections['scores'][0] + labels = detections['labels'][0] height = image.shape[1] width = image.shape[2] scale = torch.tensor([width, height, width, height], device=boxes.device) boxes = boxes * scale boxes = torch.round(boxes).int() - return boxes, confidences, labels + return boxes, scores, labels def load_darknet_weights(self, weight_file): """ @@ -270,9 +286,9 @@ def load_darknet_weights(self, weight_file): """ version = np.fromfile(weight_file, count=3, dtype=np.int32) images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) - print( - 'Loading weights from Darknet model version {}.{}.{} that has been trained on {} ' - 'images.'.format(version[0], version[1], version[2], images_seen[0]) + log.info( + 'Loading weights from Darknet model version %d.%d.%d that has been trained on %d ' + 'images.', version[0], version[1], version[2], images_seen[0] ) def read(tensor): @@ -349,47 +365,63 @@ def _validate_batch( images = torch.stack(images) return images, targets - def _filter_detections( - self, - boxes: Tensor, - confidences: Tensor, - classprobs: Tensor, - labels: Tensor - ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: + def _split_detections(self, detections: Tensor) -> Dict[str, Tensor]: + """ + Splits the detection tensor returned by a forward pass into a dictionary. + + The fields of the dictionary are as follows: + - boxes (``Tensor[batch_size, N, 4]``): detected bounding box `(x1, y1, x2, y2)` coordinates + - scores (``Tensor[batch_size, N]``): detection confidences + - classprobs (``Tensor[batch_size, N]``): probabilities of the best classes + - labels (``Int64Tensor[batch_size, N]``): the predicted labels for each image + + Args: + detections: A tensor of detected bounding boxes and their attributes. + + Returns: + A dictionary of detection results. + """ + boxes = detections[..., :4] + scores = detections[..., 4] + classprobs = detections[..., 5:] + classprobs, labels = torch.max(classprobs, -1) + return {'boxes': boxes, 'scores': scores, 'classprobs': classprobs, 'labels': labels} + + def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Tensor]]: """ Filters detections based on confidence threshold. Then for every class performs non-maximum suppression (NMS). NMS iterates the bounding boxes that predict this class in descending - order of confidence score, and removes the bounding box, if its IoU with the next one is - higher than the NMS threshold. + order of confidence score, and removes lower scoring boxes that have an IoU greater than + the NMS threshold with a higher scoring box. Finally the detections are sorted by descending + confidence and possible truncated to the maximum number of predictions. Args: - boxes: Detected bounding box (x1, y1, x2, y2) coordinates in a tensor sized - ``[batch_size, N, 4]``. - confidences: Detection confidences in a tensor sized ``[batch_size, N]``. - classprobs: Probabilities of the best classes in a tensor sized ``[batch_size, N]``. - labels: Indices of the best classes in a tensor sized ``[batch_size, N]``. + detections: All detections. A dictionary of tensors, each containing the predictions + from all images. Returns: - boxes (List[Tensor]), confidences (List[Tensor]), classprobs (List[Tensor]), labels (List[Tensor]): - Four lists, each containing one tensor per image - bounding box (x1, y1, x2, y2) - coordinates, detection confidences, probabilities of the best class of each - prediction, and the predicted class labels. + Filtered detections. A dictionary of lists, each containing a tensor per image. """ + boxes = detections['boxes'] + scores = detections['scores'] + classprobs = detections['classprobs'] + labels = detections['labels'] + out_boxes = [] - out_confidences = [] + out_scores = [] out_classprobs = [] out_labels = [] - for img_boxes, img_confidences, img_classprobs, img_labels in zip(boxes, confidences, classprobs, labels): + for img_boxes, img_scores, img_classprobs, img_labels in zip(boxes, scores, classprobs, labels): # Select detections with high confidence score. - selected = img_confidences > self.confidence_threshold + selected = img_scores > self.confidence_threshold img_boxes = img_boxes[selected] - img_confidences = img_confidences[selected] + img_scores = img_scores[selected] img_classprobs = img_classprobs[selected] img_labels = img_labels[selected] img_out_boxes = boxes.new_zeros((0, 4)) - img_out_confidences = confidences.new_zeros(0) + img_out_scores = scores.new_zeros(0) img_out_classprobs = classprobs.new_zeros(0) img_out_labels = labels.new_zeros(0) @@ -398,26 +430,26 @@ def _filter_detections( for cls_label in labels.unique(): selected = img_labels == cls_label cls_boxes = img_boxes[selected] - cls_confidences = img_confidences[selected] + cls_scores = img_scores[selected] cls_classprobs = img_classprobs[selected] cls_labels = img_labels[selected] - selected = nms(cls_boxes, cls_confidences, self.nms_threshold) + selected = nms(cls_boxes, cls_scores, self.nms_threshold) img_out_boxes = torch.cat((img_out_boxes, cls_boxes[selected])) - img_out_confidences = torch.cat((img_out_confidences, cls_confidences[selected])) + img_out_scores = torch.cat((img_out_scores, cls_scores[selected])) img_out_classprobs = torch.cat((img_out_classprobs, cls_classprobs[selected])) img_out_labels = torch.cat((img_out_labels, cls_labels[selected])) # Sort by descending confidence and limit the maximum number of predictions. - indices = torch.argsort(img_out_confidences, descending=True) + indices = torch.argsort(img_out_scores, descending=True) if self.max_predictions_per_image >= 0: indices = indices[:self.max_predictions_per_image] out_boxes.append(img_out_boxes[indices]) - out_confidences.append(img_out_confidences[indices]) + out_scores.append(img_out_scores[indices]) out_classprobs.append(img_out_classprobs[indices]) out_labels.append(img_out_labels[indices]) - return out_boxes, out_confidences, out_classprobs, out_labels + return {'boxes': out_boxes, 'scores': out_scores, 'classprobs': out_classprobs, 'labels': out_labels} class Resize: From 71a4c3c3fc46861852c2fe6a0e046a1a649efb67 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 24 Mar 2021 14:11:30 +0200 Subject: [PATCH 37/61] Fixed documentation formatting --- pl_bolts/models/detection/yolo/yolo_config.py | 16 ++-- pl_bolts/models/detection/yolo/yolo_layers.py | 77 ++++++++--------- pl_bolts/models/detection/yolo/yolo_module.py | 86 ++++++++++--------- 3 files changed, 91 insertions(+), 88 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 9d2814aa61..fc8cd6aba4 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -1,5 +1,5 @@ import re -from typing import List, Tuple +from typing import Any, Dict, Iterable, List, Tuple from warnings import warn import torch.nn as nn @@ -11,8 +11,8 @@ class YOLOConfiguration: """ This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. - The ``get_network()`` method returns a PyTorch module list that can be used to construct a YOLO - model. + The :func:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration.get_network` method + returns a PyTorch module list that can be used to construct a YOLO model. """ def __init__(self, path: str) -> None: @@ -39,7 +39,7 @@ def get_network(self) -> nn.ModuleList: modules. Returns the network structure that can be used to create a YOLO model. Returns: - modules: A ``nn.ModuleList`` that defines the YOLO network. + A :class:`~torch.nn.ModuleList` that defines the YOLO network. """ result = nn.ModuleList() num_inputs = [3] # Number of channels in the input of every layer up to the current layer @@ -50,15 +50,15 @@ def get_network(self) -> nn.ModuleList: num_inputs.append(num_outputs) return result - def _read_file(self, config_file): + def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: """ Reads a YOLOv4 network configuration file and returns a list of configuration sections. Args: - config_file (iterable over lines): The configuration file to read. + config_file: The configuration file to read. Returns: - sections (List[dict]): A list of configuration sections. + A list of configuration sections. """ section_re = re.compile(r'\[([^]]+)\]') list_variables = ('layers', 'anchors', 'mask', 'scales') @@ -158,7 +158,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: Returns: module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the - number of channels in its output. + number of channels in its output. """ create_func = { 'convolutional': _create_convolutional, diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index f5a44d3588..4978bef1b9 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -19,16 +19,16 @@ warn_missing_pkg('torchvision') -def _corner_coordinates(xy, wh): +def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: """ Converts box center points and sizes to corner coordinates. Args: - xy (Tensor): Center coordinates. Tensor of size ``[..., 2]``. - wh (Tensor): Width and height. Tensor of size ``[..., 2]``. + xy: Center coordinates. Tensor of size ``[..., 2]``. + wh: Width and height. Tensor of size ``[..., 2]``. Returns: - boxes (Tensor): A matrix of `(x1, y1, x2, y2)` coordinates. + A matrix of `(x1, y1, x2, y2)` coordinates. """ half_wh = wh / 2 top_left = xy - half_wh @@ -36,18 +36,18 @@ def _corner_coordinates(xy, wh): return torch.cat((top_left, bottom_right), -1) -def _aligned_iou(dims1, dims2): +def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: """ Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at the same coordinates. Args: - dims1 (Tensor[N, 2]): width and height of N boxes - dims2 (Tensor[M, 2]): width and height of M boxes + dims1: Width and height of `N` boxes. Tensor of size ``[N, 2]``. + dims2: Width and height of `M` boxes. Tensor of size ``[M, 2]``. Returns: - iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in - ``dims1`` and ``dims2`` + Tensor of size ``[N, M]`` containing the pairwise IoU values for every element in + ``dims1`` and ``dims2`` """ area1 = dims1[:, 0] * dims1[:, 1] # [N] area2 = dims2[:, 0] * dims2[:, 1] # [M] @@ -162,7 +162,7 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) Maps cell-local coordinates to global coordinates in the `[0, 1]` range, scales the bounding boxes with the anchors, converts the center coordinates to corner coordinates, and maps - probabilities to ]0, 1[ range using sigmoid. + probabilities to the `]0, 1[` range using sigmoid. Args: x: The output from the previous layer. Tensor of size @@ -172,8 +172,8 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) Returns: output (Tensor), losses (Dict[str, Tensor]): Layer output, and if training targets were - provided, a dictionary of losses. Layer output is sized - ``[batch_size, num_anchors * height * width, num_classes + 5]``. + provided, a dictionary of losses. Layer output is sized + ``[batch_size, num_anchors * height * width, num_classes + 5]``. """ batch_size, num_features, height, width = x.shape num_attrs = self.num_classes + 5 @@ -215,7 +215,7 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) losses = self._calculate_losses(boxes, confidence, classprob, targets, lc_mask) return output, losses - def _global_xy(self, xy): + def _global_xy(self, xy: Tensor) -> Tensor: """ Adds offsets to the predicted box center coordinates to obtain global coordinates to the image. @@ -225,12 +225,12 @@ def _global_xy(self, xy): coordinates in the `[0, 1]` range. Args: - xy (Tensor): The predicted center coordinates before scaling. Values from zero to one - in a tensor sized ``[batch_size, height, width, boxes_per_cell, 2]``. + xy: The predicted center coordinates before scaling. Values from zero to one in a + tensor sized ``[batch_size, height, width, boxes_per_cell, 2]``. Returns: - result (Tensor): Global coordinates from zero to one, in a tensor with the same shape - as the input tensor. + Global coordinates in the `[0, 1]` range, in a tensor with the same shape as the input + tensor. """ height = xy.shape[1] width = xy.shape[2] @@ -244,39 +244,37 @@ def _global_xy(self, xy): return (xy + offset) / grid_size - def _scale_wh(self, wh): + def _scale_wh(self, wh: Tensor) -> Tensor: """ Scales the box size predictions by the prior dimensions from the anchors. Args: - wh (Tensor): The unnormalized width and height predictions. Tensor of size + wh: The unnormalized width and height predictions. Tensor of size ``[..., boxes_per_cell, 2]``. Returns: - result (Tensor): A tensor with the same shape as the input tensor, but scaled sizes - normalized to the `[0, 1]` range. + A tensor with the same shape as the input tensor, but scaled sizes normalized to the + `[0, 1]` range. """ image_size = torch.tensor([self.image_width, self.image_height], device=wh.device) anchor_wh = [self.anchor_dims[i] for i in self.anchor_ids] anchor_wh = torch.tensor(anchor_wh, dtype=wh.dtype, device=wh.device) return torch.exp(wh) * anchor_wh / image_size - def _low_confidence_mask(self, boxes, targets): + def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) -> Tensor: """ Initializes the mask that will be used to select predictors that are not predicting any ground-truth target. The value will be ``True``, unless the predicted box overlaps any target significantly (IoU greater than ``self.ignore_threshold``). Args: - boxes (Tensor): The predicted corner coordinates, normalized to the `[0, 1]` range. - Tensor of size ``[batch_size, height, width, boxes_per_cell, 4]``. - targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one - dictionary for each image. + boxes: The predicted corner coordinates, normalized to the `[0, 1]` range. Tensor of + size ``[batch_size, height, width, boxes_per_cell, 4]``. + targets: List of dictionaries of ground-truth targets, one dictionary per image. Returns: - results (Tensor): A boolean tensor shaped ``[batch_size, height, width, boxes_per_cell]`` - with ``False`` where the predicted box overlaps a target significantly and ``True`` - elsewhere. + A boolean tensor shaped ``[batch_size, height, width, boxes_per_cell]`` with ``False`` + where the predicted box overlaps a target significantly and ``True`` elsewhere. """ batch_size, height, width, boxes_per_cell, num_coords = boxes.shape num_preds = height * width * boxes_per_cell @@ -296,25 +294,26 @@ def _low_confidence_mask(self, boxes, targets): return results.view((batch_size, height, width, boxes_per_cell)) - def _calculate_losses(self, boxes, confidence, classprob, targets, lc_mask): + def _calculate_losses( + self, boxes: Tensor, confidence: Tensor, classprob: Tensor, targets: List[Dict[str, Tensor]], lc_mask: Tensor + ) -> Dict[str, Tensor]: """ From the targets that are in the image space calculates the actual targets for the network predictions, and returns a dictionary of training losses. Args: - boxes (Tensor): The predicted bounding boxes. A tensor sized + boxes: The predicted bounding boxes. A tensor sized ``[batch_size, height, width, boxes_per_cell, 4]``. - confidence (Tensor): The confidence predictions, normalized to `[0, 1]`. A tensor sized + confidence: The confidence predictions, normalized to `[0, 1]`. A tensor sized ``[batch_size, height, width, boxes_per_cell]``. - classprob (Tensor): The class probability predictions, normalized to `[0, 1]`. A tensor - sized ``[batch_size, height, width, boxes_per_cell, num_classes]``. - targets (List[Dict[str, Tensor]]): List of dictionaries of target values, one - dictionary for each image. - lc_mask (Tensor): A boolean mask containing ``True`` where the predicted box does not - overlap any target significantly. + classprob: The class probability predictions, normalized to `[0, 1]`. A tensor sized + ``[batch_size, height, width, boxes_per_cell, num_classes]``. + targets: List of dictionaries of target values, one dictionary for each image. + lc_mask: A boolean mask containing ``True`` where the predicted box does not overlap + any target significantly. Returns: - predicted (Dict[str, Tensor]): A dictionary of training losses. + A dictionary of training losses. """ batch_size, height, width, boxes_per_cell, _ = boxes.shape device = boxes.device diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 2b0420f2d8..653e8e4f1f 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -27,38 +27,42 @@ class YOLO(pl.LightningModule): PyTorch Lightning implementation of `YOLOv3 `_ with some improvements from `YOLOv4 `_. - YOLOv3 paper authors: Joseph Redmon and Ali Farhadi + *YOLOv3 paper authors*: Joseph Redmon and Ali Farhadi - YOLOv4 paper authors: Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao + *YOLOv4 paper authors*: Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao - Model implemented by: - - `Seppo Enarvi `_ + *Model implemented by*: + + - `Seppo Enarvi `_ The network architecture can be read from a Darknet configuration file using the - :class:`~pl_bolts.models.detection.yolo.yolo_config.YoloConfiguration` class, or created by + :class:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration` class, or created by some other means, and provided as a list of PyTorch modules. The input from the data loader is expected to be a list of images. Each image is a tensor with - shape ``[channels, height, width]``. The images from each batch are concatenated into a single - tensor, so the sizes have to match. Different batches can have different image sizes, as long - as the size is divisible by the ratio in which the network downsamples the input. - - During training, the model expects both the input tensors and a list of targets. Each target is - a dictionary containing: - - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format - - labels (``Int64Tensor[N]``): the class label for each ground-truth box - - ``forward()`` method returns all predictions from all detection layers in all images in one - tensor with shape ``[images, predictors, classes + 5]``. The coordinates are in the `[0, 1]` - range. During training it also returns a dictionary containing the classification, box overlap, - and confidence losses. - - During inference, the model requires only the input tensors. ``infer()`` method filters and - processes the predictions, producing the following tensors: - - boxes (``FloatTensor[N, 4]``): predicted bounding box `(x1, y1, x2, y2)` coordinates in image - space - - scores (``FloatTensor[N]``): detection confidences - - labels (``Int64Tensor[N]``): the predicted labels for each image + shape ``[channels, height, width]``. The images from a single batch will be stacked into a + single tensor, so the sizes have to match. Different batches can have different image sizes, as + long as the size is divisible by the ratio in which the network downsamples the input. + + During training, the model expects both the input tensors and a list of targets. *Each target is + a dictionary containing*: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + :func:`~pl_bolts.models.detection.yolo.yolo_module.YOLO.forward` method returns all + predictions from all detection layers in all images in one tensor with shape + ``[images, predictors, classes + 5]``. The coordinates are in the `[0, 1]` range. During + training it also returns a dictionary containing the classification, box overlap, and + confidence losses. + + During inference, the model requires only the input tensors. + :func:`~pl_bolts.models.detection.yolo.yolo_module.YOLO.infer` method filters and processes the + predictions. *The processed output includes the following tensors*: + + - boxes (``FloatTensor[N, 4]``): predicted bounding box `(x1, y1, x2, y2)` coordinates in image space + - scores (``FloatTensor[N]``): detection confidences + - labels (``Int64Tensor[N]``): the predicted labels for each image Weights can be loaded from a Darknet model file using ``load_darknet_weights()``. @@ -73,9 +77,9 @@ def __init__( self, network: nn.ModuleList, optimizer: Type[optim.Optimizer] = optim.SGD, - optimizer_params: Dict[str, Any] = {'lr': 0.0013, 'momentum': 0.9, 'weight_decay': 0.0005}, + optimizer_params: Dict[str, Any] = {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.0005}, lr_scheduler: Type[optim.lr_scheduler._LRScheduler] = LinearWarmupCosineAnnealingLR, - lr_scheduler_params: Dict[str, Any] = {'warmup_epochs': 1, 'max_epochs': 271, 'warmup_start_lr': 0.0}, + lr_scheduler_params: Dict[str, Any] = {'warmup_epochs': 1, 'max_epochs': 300, 'warmup_start_lr': 0.0}, confidence_threshold: float = 0.2, nms_threshold: float = 0.45, max_predictions_per_image: int = -1 @@ -83,7 +87,8 @@ def __init__( """ Args: network: A list of network modules. This can be obtained from a Darknet configuration - using the ``YoloConfiguration.get_network()`` method. + using the :func:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration.get_network` + method. optimizer: Which optimizer class to use for training. optimizer_params: Parameters to pass to the optimizer constructor. lr_scheduler: Which learning rate scheduler class to use for training. @@ -115,7 +120,7 @@ def forward( self, images: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None - ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Tensor]]: + ) -> Tuple[Tensor, Dict[str, Tensor]]: """ Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. @@ -134,12 +139,12 @@ def forward( dictionaries, one for each image. Returns: - detections (Tensor), losses (Dict[str, Tensor]): - Detections, and if targets were provided, a dictionary of losses. Detections are - shaped ``[batch_size, num_predictors, num_classes + 5]``, where ``num_predictors`` - is the total number of cells in all detection layers times the number of boxes - predicted by one cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format - and normalized to `[0, 1]`. + detections (:class:`~torch.Tensor`), losses (Dict[str, :class:`~torch.Tensor`]): + Detections, and if targets were provided, a dictionary of losses. Detections are shaped + ``[batch_size, num_predictors, num_classes + 5]``, where ``num_predictors`` is the + total number of cells in all detection layers times the number of boxes predicted by + one cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format and normalized + to `[0, 1]`. """ outputs = [] # Outputs from all layers detections = [] # Outputs from detection layers @@ -201,7 +206,7 @@ def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], bat return {'loss': total_loss} - def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: + def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int): """ Evaluates a batch of data from the validation set. @@ -220,7 +225,7 @@ def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], b self.log(f'val/{name}_loss', value, sync_dist=True) self.log('val/total_loss', total_loss, sync_dist=True) - def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: + def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int): """ Evaluates a batch of data from the test set. @@ -249,9 +254,8 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: Returns: boxes (:class:`~torch.Tensor`), confidences (:class:`~torch.Tensor`), labels (:class:`~torch.Tensor`): - A matrix of detected bounding box `(x1, y1, x2, y2)` coordinates, a vector of - confidences for the bounding box detections, and a vector of predicted class - labels. + A matrix of detected bounding box `(x1, y1, x2, y2)` coordinates, a vector of + confidences for the bounding box detections, and a vector of predicted class labels. """ network_input = image.float().div(255.0) network_input = network_input.unsqueeze(0) @@ -335,7 +339,7 @@ def _validate_batch( batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. Returns: - batch: The input batch with images stacked into a single tensor. + The input batch with images stacked into a single tensor. """ images, targets = batch From 70f14b099340814d32fb159ef3c1eafb15929410 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 31 Mar 2021 13:37:12 +0300 Subject: [PATCH 38/61] Coordinate predictions are in image scale --- pl_bolts/models/detection/yolo/yolo_layers.py | 81 +++++++++---------- pl_bolts/models/detection/yolo/yolo_module.py | 19 ++--- 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 4978bef1b9..64258d038a 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -125,10 +125,10 @@ def __init__( confidence_loss_func: Loss function for confidence score. Default is the sum of squared errors. image_space_loss: If set to ``True``, the overlap loss function will receive the bounding - box `(x1, y1, x2, y2)` coordinate normalized to the `[0, 1]` range. This is needed for - the IoU losses introduced in YOLOv4. Otherwise the loss will be computed from the x, - y, width, and height values, as predicted by the network (i.e. relative to the - anchor box, and width and height are logarithmic). + box `(x1, y1, x2, y2)` coordinates, scaled to the input image size. This is needed + for the IoU losses introduced in YOLOv4. Otherwise the loss will be computed from + the x, y, width, and height values, as predicted by the network (i.e. relative to + the anchor box, and width and height are logarithmic). coord_loss_multiplier: Multiply the coordinate/size loss by this factor. class_loss_multiplier: Multiply the classification loss by this factor. confidence_loss_multiplier: Multiply the confidence loss by this factor. @@ -160,10 +160,12 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) """ Runs a forward pass through this YOLO detection layer. - Maps cell-local coordinates to global coordinates in the `[0, 1]` range, scales the bounding + Maps cell-local coordinates to global coordinates in the image space, scales the bounding boxes with the anchors, converts the center coordinates to corner coordinates, and maps probabilities to the `]0, 1[` range using sigmoid. + If targets are given, computes also losses from the predictions and the targets. + Args: x: The output from the previous layer. Tensor of size ``[batch_size, boxes_per_cell * (num_classes + 5), height, width]``. @@ -221,16 +223,16 @@ def _global_xy(self, xy: Tensor) -> Tensor: image. The predicted coordinates are interpreted as coordinates inside a grid cell whose width and - height is 1. Adding offset to the cell and dividing by the grid size, we get global - coordinates in the `[0, 1]` range. + height is 1. Adding offset to the cell, dividing by the grid size, and multiplying by the + image size, we get global coordinates in the image scale. Args: xy: The predicted center coordinates before scaling. Values from zero to one in a tensor sized ``[batch_size, height, width, boxes_per_cell, 2]``. Returns: - Global coordinates in the `[0, 1]` range, in a tensor with the same shape as the input - tensor. + Global coordinates scaled to the size of the network input image, in a tensor with the + same shape as the input tensor. """ height = xy.shape[1] width = xy.shape[2] @@ -242,7 +244,9 @@ def _global_xy(self, xy: Tensor) -> Tensor: offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] offset = offset.unsqueeze(2) # [height, width, 1, 2] - return (xy + offset) / grid_size + image_size = torch.tensor([self.image_width, self.image_height], device=xy.device) + scale = image_size / grid_size + return (xy + offset) * scale def _scale_wh(self, wh: Tensor) -> Tensor: """ @@ -253,13 +257,12 @@ def _scale_wh(self, wh: Tensor) -> Tensor: ``[..., boxes_per_cell, 2]``. Returns: - A tensor with the same shape as the input tensor, but scaled sizes normalized to the - `[0, 1]` range. + A tensor with the same shape as the input tensor, containing final width and height in + the image space. """ - image_size = torch.tensor([self.image_width, self.image_height], device=wh.device) anchor_wh = [self.anchor_dims[i] for i in self.anchor_ids] anchor_wh = torch.tensor(anchor_wh, dtype=wh.dtype, device=wh.device) - return torch.exp(wh) * anchor_wh / image_size + return torch.exp(wh) * anchor_wh def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) -> Tensor: """ @@ -268,8 +271,8 @@ def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) significantly (IoU greater than ``self.ignore_threshold``). Args: - boxes: The predicted corner coordinates, normalized to the `[0, 1]` range. Tensor of - size ``[batch_size, height, width, boxes_per_cell, 4]``. + boxes: The predicted corner coordinates in the image space. Tensor of size + ``[batch_size, height, width, boxes_per_cell, 4]``. targets: List of dictionaries of ground-truth targets, one dictionary per image. Returns: @@ -280,10 +283,6 @@ def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) num_preds = height * width * boxes_per_cell boxes = boxes.view(batch_size, num_preds, num_coords) - scale = torch.tensor([self.image_width, self.image_height, self.image_width, self.image_height], - device=boxes.device) - boxes = boxes * scale - results = torch.ones((batch_size, num_preds), dtype=torch.bool, device=boxes.device) for image_idx, (image_boxes, image_targets) in enumerate(zip(boxes, targets)): target_boxes = image_targets['boxes'] @@ -319,10 +318,10 @@ def _calculate_losses( device = boxes.device assert batch_size == len(targets) - # Divisor for converting targets from image coordinates to feature map coordinates - image_to_feature_map = torch.tensor([self.image_width / width, self.image_height / height], device=device) - # Divisor for converting targets from image coordinates to `[0, 1]` range - image_to_unit = torch.tensor([self.image_width, self.image_height], device=device) + image_size = torch.tensor([self.image_width, self.image_height], device=device) + grid_size = torch.tensor([width, height], device=device) + # For scaling image coordinates to feature map coordinates + image_to_grid = grid_size / image_size anchor_wh = torch.tensor(self.anchor_dims, dtype=boxes.dtype, device=device) anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) @@ -343,17 +342,15 @@ def _calculate_losses( continue # Bounding box corner coordinates are converted to center coordinates, width, and - # height, and normalized to `[0, 1]` range. + # height. wh = target_boxes[:, 2:4] - target_boxes[:, 0:2] xy = target_boxes[:, 0:2] + (wh / 2) - unit_xy = xy / image_to_unit - unit_wh = wh / image_to_unit # The center coordinates are converted to the feature map dimensions so that the whole # number tells the cell index and the fractional part tells the location inside the cell. - xy = xy / image_to_feature_map - cell_i = xy[:, 0].to(torch.int64).clamp(0, width - 1) - cell_j = xy[:, 1].to(torch.int64).clamp(0, height - 1) + grid_xy = xy * image_to_grid + cell_i = grid_xy[:, 0].to(torch.int64).clamp(0, width - 1) + cell_j = grid_xy[:, 1].to(torch.int64).clamp(0, height - 1) # We want to know which anchor box overlaps a ground truth box more than any other # anchor box. We know that the anchor box is located in the same grid cell as the @@ -368,33 +365,33 @@ def _calculate_losses( # another layer. predictors = anchor_map[best_anchors] selected = predictors >= 0 - unit_xy = unit_xy[selected] - unit_wh = unit_wh[selected] cell_i = cell_i[selected] cell_j = cell_j[selected] predictors = predictors[selected] - best_anchors = best_anchors[selected] + wh = wh[selected] # The "low-confidence" mask is used to select predictors that are not responsible for # predicting any object, for calculating the part of the confidence loss with zero as # the target confidence. lc_mask[image_idx, cell_j, cell_i, predictors] = False - # IoU losses are calculated from the image space coordinates normalized to `[0, 1]` - # range. The squared-error loss is calculated from the raw predicted values. + # IoU losses are calculated from the image space coordinates. The squared-error loss is + # calculated from the raw predicted values. if self.image_space_loss: - target_xy.append(unit_xy) - target_wh.append(unit_wh) - else: xy = xy[selected] - wh = wh[selected] - relative_xy = xy - xy.floor() + target_xy.append(xy) + target_wh.append(wh) + else: + grid_xy = grid_xy[selected] + best_anchors = best_anchors[selected] + relative_xy = grid_xy - grid_xy.floor() relative_wh = torch.log(wh / anchor_wh[best_anchors] + 1e-16) target_xy.append(relative_xy) target_wh.append(relative_wh) - # Size compensation factor for bounding box overlap loss is calculated from image space - # width and height. + # Size compensation factor for bounding box overlap loss is calculated from unit width + # and height. + unit_wh = wh / image_size size_compensation.append(2 - (unit_wh[:, 0] * unit_wh[:, 1])) # The data may contain a different number of classes than this detection layer. In case diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 653e8e4f1f..705e238b6a 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -7,7 +7,6 @@ import torch.nn as nn from torch import optim, Tensor -from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from pl_bolts.utils import _TORCHVISION_AVAILABLE @@ -52,8 +51,8 @@ class YOLO(pl.LightningModule): :func:`~pl_bolts.models.detection.yolo.yolo_module.YOLO.forward` method returns all predictions from all detection layers in all images in one tensor with shape - ``[images, predictors, classes + 5]``. The coordinates are in the `[0, 1]` range. During - training it also returns a dictionary containing the classification, box overlap, and + ``[images, predictors, classes + 5]``. The coordinates are scaled to the input image size. + During training it also returns a dictionary containing the classification, box overlap, and confidence losses. During inference, the model requires only the input tensors. @@ -143,8 +142,8 @@ def forward( Detections, and if targets were provided, a dictionary of losses. Detections are shaped ``[batch_size, num_predictors, num_classes + 5]``, where ``num_predictors`` is the total number of cells in all detection layers times the number of boxes predicted by - one cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format and normalized - to `[0, 1]`. + one cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to + the input image size. """ outputs = [] # Outputs from all layers detections = [] # Outputs from detection layers @@ -246,8 +245,8 @@ def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_i def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ - Resizes given image to the network input size and feeds it to the network. Returns the - detected bounding boxes, confidences, and class labels. + Feeds an image to the network and returns the detected bounding boxes, confidence scores, + and class labels. Args: image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. @@ -266,11 +265,6 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: boxes = detections['boxes'][0] scores = detections['scores'][0] labels = detections['labels'][0] - - height = image.shape[1] - width = image.shape[2] - scale = torch.tensor([width, height, width, height], device=boxes.device) - boxes = boxes * scale boxes = torch.round(boxes).int() return boxes, scores, labels @@ -490,6 +484,7 @@ def run_cli(): from argparse import ArgumentParser from pl_bolts.datamodules import VOCDetectionDataModule + from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration pl.seed_everything(42) From 9587b4601be16eaf9c1c4df1bb5a3fe0402a74d3 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 1 Apr 2021 12:07:04 +0300 Subject: [PATCH 39/61] Use default dtype for torch.arange() to fix export to TensorRT --- pl_bolts/models/detection/yolo/yolo_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 64258d038a..af01b15c6e 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -238,8 +238,8 @@ def _global_xy(self, xy: Tensor) -> Tensor: width = xy.shape[2] grid_size = torch.tensor([width, height], device=xy.device) - x_range = torch.arange(width, dtype=xy.dtype, device=xy.device) - y_range = torch.arange(height, dtype=xy.dtype, device=xy.device) + x_range = torch.arange(width, device=xy.device) + y_range = torch.arange(height, device=xy.device) grid_y, grid_x = torch.meshgrid(y_range, x_range) offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] offset = offset.unsqueeze(2) # [height, width, 1, 2] From 2b6c5529313888bb34247eba026a4c66a9302ca2 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Sat, 10 Apr 2021 13:30:51 +0300 Subject: [PATCH 40/61] Network input size can differ from the image size specified in the configuration * Image size is given to detection layer forward() instead of the constructor to allow variable image sizes. * Use default data type for torch.arange() to fix export to TensorRT. --- pl_bolts/models/detection/yolo/yolo_config.py | 2 -- pl_bolts/models/detection/yolo/yolo_layers.py | 34 +++++++++---------- pl_bolts/models/detection/yolo/yolo_module.py | 31 +++++++++++------ 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index fc8cd6aba4..7e90fb3b80 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -249,8 +249,6 @@ def _create_yolo(config, num_inputs): module = yolo_layers.DetectionLayer( num_classes=config['classes'], - image_width=config['width'], - image_height=config['height'], anchor_dims=anchor_dims, anchor_ids=config['mask'], xy_scale=xy_scale, diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 64258d038a..22b4cab78e 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -87,8 +87,6 @@ class DetectionLayer(nn.Module): def __init__( self, num_classes: int, - image_width: int, - image_height: int, anchor_dims: List[Tuple[int, int]], anchor_ids: List[int], xy_scale: float = 1.0, @@ -104,10 +102,6 @@ def __init__( """ Args: num_classes: Number of different classes that this layer predicts. - image_width: Image width (defines the scale of the anchor box and target bounding box - dimensions). - image_height: Image height (defines the scale of the anchor box and target bounding box - dimensions). anchor_dims: A list of all the predefined anchor box dimensions. The list should contain (width, height) tuples in the network input resolution (relative to the width and height defined in the configuration file). @@ -139,8 +133,6 @@ def __init__( raise ModuleNotFoundError('YOLO model uses `torchvision`, which is not installed yet.') self.num_classes = num_classes - self.image_width = image_width - self.image_height = image_height self.anchor_dims = anchor_dims self.anchor_ids = anchor_ids self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(9)] @@ -156,7 +148,10 @@ def __init__( self.class_loss_multiplier = class_loss_multiplier self.confidence_loss_multiplier = confidence_loss_multiplier - def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Tensor, Dict[str, Tensor]]: + def forward(self, + x: Tensor, + image_size: Tensor, + targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Tensor, Dict[str, Tensor]]: """ Runs a forward pass through this YOLO detection layer. @@ -169,6 +164,8 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) Args: x: The output from the previous layer. Tensor of size ``[batch_size, boxes_per_cell * (num_classes + 5), height, width]``. + image_size: Image width and height in a vector (defines the scale of the predicted and + target coordinates). targets: If set, computes losses from detection layers against these targets. A list of dictionaries, one for each image. @@ -202,7 +199,7 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) # x/y coordinates. xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) - image_xy = self._global_xy(xy) + image_xy = self._global_xy(xy, image_size) image_wh = self._scale_wh(wh) boxes = _corner_coordinates(image_xy, image_wh) output = torch.cat((boxes, confidence.unsqueeze(-1), classprob), -1) @@ -214,10 +211,10 @@ def forward(self, x: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None) lc_mask = self._low_confidence_mask(boxes, targets) if not self.image_space_loss: boxes = torch.cat((xy, wh), -1) - losses = self._calculate_losses(boxes, confidence, classprob, targets, lc_mask) + losses = self._calculate_losses(boxes, confidence, classprob, targets, image_size, lc_mask) return output, losses - def _global_xy(self, xy: Tensor) -> Tensor: + def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: """ Adds offsets to the predicted box center coordinates to obtain global coordinates to the image. @@ -229,6 +226,7 @@ def _global_xy(self, xy: Tensor) -> Tensor: Args: xy: The predicted center coordinates before scaling. Values from zero to one in a tensor sized ``[batch_size, height, width, boxes_per_cell, 2]``. + image_size: Width and height in a vector that will be used to scale the coordinates. Returns: Global coordinates scaled to the size of the network input image, in a tensor with the @@ -238,13 +236,12 @@ def _global_xy(self, xy: Tensor) -> Tensor: width = xy.shape[2] grid_size = torch.tensor([width, height], device=xy.device) - x_range = torch.arange(width, dtype=xy.dtype, device=xy.device) - y_range = torch.arange(height, dtype=xy.dtype, device=xy.device) + x_range = torch.arange(width, device=xy.device) + y_range = torch.arange(height, device=xy.device) grid_y, grid_x = torch.meshgrid(y_range, x_range) offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] offset = offset.unsqueeze(2) # [height, width, 1, 2] - image_size = torch.tensor([self.image_width, self.image_height], device=xy.device) scale = image_size / grid_size return (xy + offset) * scale @@ -294,7 +291,8 @@ def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) return results.view((batch_size, height, width, boxes_per_cell)) def _calculate_losses( - self, boxes: Tensor, confidence: Tensor, classprob: Tensor, targets: List[Dict[str, Tensor]], lc_mask: Tensor + self, boxes: Tensor, confidence: Tensor, classprob: Tensor, targets: List[Dict[str, Tensor]], + image_size: Tensor, lc_mask: Tensor ) -> Dict[str, Tensor]: """ From the targets that are in the image space calculates the actual targets for the network @@ -308,6 +306,7 @@ def _calculate_losses( classprob: The class probability predictions, normalized to `[0, 1]`. A tensor sized ``[batch_size, height, width, boxes_per_cell, num_classes]``. targets: List of dictionaries of target values, one dictionary for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. lc_mask: A boolean mask containing ``True`` where the predicted box does not overlap any target significantly. @@ -318,9 +317,8 @@ def _calculate_losses( device = boxes.device assert batch_size == len(targets) - image_size = torch.tensor([self.image_width, self.image_height], device=device) + # A multiplier for scaling image coordinates to feature map coordinates grid_size = torch.tensor([width, height], device=device) - # For scaling image coordinates to feature map coordinates image_to_grid = grid_size / image_size anchor_wh = torch.tensor(self.anchor_dims, dtype=boxes.dtype, device=device) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 705e238b6a..3a9904d37d 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -76,9 +76,17 @@ def __init__( self, network: nn.ModuleList, optimizer: Type[optim.Optimizer] = optim.SGD, - optimizer_params: Dict[str, Any] = {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.0005}, + optimizer_params: Dict[str, Any] = { + 'lr': 0.001, + 'momentum': 0.9, + 'weight_decay': 0.0005 + }, lr_scheduler: Type[optim.lr_scheduler._LRScheduler] = LinearWarmupCosineAnnealingLR, - lr_scheduler_params: Dict[str, Any] = {'warmup_epochs': 1, 'max_epochs': 300, 'warmup_start_lr': 0.0}, + lr_scheduler_params: Dict[str, Any] = { + 'warmup_epochs': 1, + 'max_epochs': 300, + 'warmup_start_lr': 0.0 + }, confidence_threshold: float = 0.2, nms_threshold: float = 0.45, max_predictions_per_image: int = -1 @@ -103,8 +111,7 @@ def __init__( if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover - 'YOLO model uses `torchvision`, which is not installed yet.' - ) + 'YOLO model uses `torchvision`, which is not installed yet.') self.network = network self.optimizer_class = optimizer @@ -149,16 +156,20 @@ def forward( detections = [] # Outputs from detection layers losses = [] # Losses from detection layers + image_height = images.shape[2] + image_width = images.shape[3] + image_size = torch.tensor([image_width, image_height], device=images.device) + x = images for module in self.network: if isinstance(module, (RouteLayer, ShortcutLayer)): x = module(x, outputs) elif isinstance(module, DetectionLayer): if targets is None: - x = module(x) + x = module(x, image_size) detections.append(x) else: - x, layer_losses = module(x, targets) + x, layer_losses = module(x, image_size, targets) detections.append(x) losses.append(layer_losses) else: @@ -256,16 +267,16 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: A matrix of detected bounding box `(x1, y1, x2, y2)` coordinates, a vector of confidences for the bounding box detections, and a vector of predicted class labels. """ - network_input = image.float().div(255.0) - network_input = network_input.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = F.to_tensor(image) + self.eval() - detections = self(network_input) + detections = self(image.unsqueeze(0)) detections = self._split_detections(detections) detections = self._filter_detections(detections) boxes = detections['boxes'][0] scores = detections['scores'][0] labels = detections['labels'][0] - boxes = torch.round(boxes).int() return boxes, scores, labels def load_darknet_weights(self, weight_file): From 004d1ce4a240fb4575e8dd018393ecd0091745a0 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 3 May 2021 13:23:12 +0300 Subject: [PATCH 41/61] Use torch.true_divide() instead of / --- pl_bolts/models/detection/yolo/yolo_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 22b4cab78e..7ccba05c97 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -242,7 +242,7 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] offset = offset.unsqueeze(2) # [height, width, 1, 2] - scale = image_size / grid_size + scale = torch.true_divide(image_size, grid_size) return (xy + offset) * scale def _scale_wh(self, wh: Tensor) -> Tensor: From ad1e48e573c87ec9ffb7a237e8954c300da65079 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 5 May 2021 14:00:41 +0300 Subject: [PATCH 42/61] Use torch.true_divide() instead of / --- pl_bolts/models/detection/yolo/yolo_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 7ccba05c97..f028f49f0a 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -319,7 +319,7 @@ def _calculate_losses( # A multiplier for scaling image coordinates to feature map coordinates grid_size = torch.tensor([width, height], device=device) - image_to_grid = grid_size / image_size + image_to_grid = torch.true_divide(grid_size, image_size) anchor_wh = torch.tensor(self.anchor_dims, dtype=boxes.dtype, device=device) anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) From c237b370e9b439e1001b7329d52145f853783ba3 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 23 Jun 2021 19:29:20 +0300 Subject: [PATCH 43/61] Loss is normalized by batch size only once --- pl_bolts/models/detection/yolo/yolo_layers.py | 55 ++++++++----------- pl_bolts/models/detection/yolo/yolo_module.py | 33 +++++++---- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index f028f49f0a..fb49ee81e0 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -133,9 +133,9 @@ def __init__( raise ModuleNotFoundError('YOLO model uses `torchvision`, which is not installed yet.') self.num_classes = num_classes - self.anchor_dims = anchor_dims - self.anchor_ids = anchor_ids - self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(9)] + self.all_anchor_dims = anchor_dims + self.anchor_dims = [anchor_dims[i] for i in anchor_ids] + self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(len(anchor_dims))] self.xy_scale = xy_scale self.ignore_threshold = ignore_threshold @@ -159,7 +159,9 @@ def forward(self, boxes with the anchors, converts the center coordinates to corner coordinates, and maps probabilities to the `]0, 1[` range using sigmoid. - If targets are given, computes also losses from the predictions and the targets. + If targets are given, computes also losses from the predictions and the targets. This layer + is responsible only for the targets that best match one of the anchors assigned to this + layer. Args: x: The output from the previous layer. Tensor of size @@ -170,17 +172,18 @@ def forward(self, dictionaries, one for each image. Returns: - output (Tensor), losses (Dict[str, Tensor]): Layer output, and if training targets were - provided, a dictionary of losses. Layer output is sized - ``[batch_size, num_anchors * height * width, num_classes + 5]``. + output (Tensor), losses (Dict[str, Tensor]), hits (int): Layer output tensor, sized + ``[batch_size, num_anchors * height * width, num_classes + 5]``. If training targets + were provided, also returns a dictionary of losses and the number of targets that this + layer was responsible for. """ batch_size, num_features, height, width = x.shape num_attrs = self.num_classes + 5 boxes_per_cell = num_features // num_attrs - if boxes_per_cell != len(self.anchor_ids): + if boxes_per_cell != len(self.anchor_dims): raise MisconfigurationException( "The model predicts {} bounding boxes per cell, but {} anchor boxes are defined " - "for this layer.".format(boxes_per_cell, len(self.anchor_ids)) + "for this layer.".format(boxes_per_cell, len(self.anchor_dims)) ) # Reshape the output to have the bounding box attributes of each grid cell on its own row. @@ -200,7 +203,7 @@ def forward(self, xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) image_xy = self._global_xy(xy, image_size) - image_wh = self._scale_wh(wh) + image_wh = torch.exp(wh) * torch.tensor(self.anchor_dims, dtype=wh.dtype, device=wh.device) boxes = _corner_coordinates(image_xy, image_wh) output = torch.cat((boxes, confidence.unsqueeze(-1), classprob), -1) output = output.reshape(batch_size, height * width * boxes_per_cell, num_attrs) @@ -211,8 +214,8 @@ def forward(self, lc_mask = self._low_confidence_mask(boxes, targets) if not self.image_space_loss: boxes = torch.cat((xy, wh), -1) - losses = self._calculate_losses(boxes, confidence, classprob, targets, image_size, lc_mask) - return output, losses + losses, hits = self._calculate_losses(boxes, confidence, classprob, targets, image_size, lc_mask) + return output, losses, hits def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: """ @@ -245,22 +248,6 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: scale = torch.true_divide(image_size, grid_size) return (xy + offset) * scale - def _scale_wh(self, wh: Tensor) -> Tensor: - """ - Scales the box size predictions by the prior dimensions from the anchors. - - Args: - wh: The unnormalized width and height predictions. Tensor of size - ``[..., boxes_per_cell, 2]``. - - Returns: - A tensor with the same shape as the input tensor, containing final width and height in - the image space. - """ - anchor_wh = [self.anchor_dims[i] for i in self.anchor_ids] - anchor_wh = torch.tensor(anchor_wh, dtype=wh.dtype, device=wh.device) - return torch.exp(wh) * anchor_wh - def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) -> Tensor: """ Initializes the mask that will be used to select predictors that are not predicting any @@ -306,12 +293,14 @@ def _calculate_losses( classprob: The class probability predictions, normalized to `[0, 1]`. A tensor sized ``[batch_size, height, width, boxes_per_cell, num_classes]``. targets: List of dictionaries of target values, one dictionary for each image. - image_size: Width and height in a vector that defines the scale of the target coordinates. + image_size: Width and height in a vector that defines the scale of the target + coordinates. lc_mask: A boolean mask containing ``True`` where the predicted box does not overlap any target significantly. Returns: - A dictionary of training losses. + losses (Dict[str, Tensor]), hits (int): A dictionary of training losses and the number + of targets that this layer was responsible for. """ batch_size, height, width, boxes_per_cell, _ = boxes.shape device = boxes.device @@ -321,7 +310,7 @@ def _calculate_losses( grid_size = torch.tensor([width, height], device=device) image_to_grid = torch.true_divide(grid_size, image_size) - anchor_wh = torch.tensor(self.anchor_dims, dtype=boxes.dtype, device=device) + anchor_wh = torch.tensor(self.all_anchor_dims, dtype=boxes.dtype, device=device) anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=device) # List of predicted and target values for the predictors that are responsible for @@ -333,6 +322,7 @@ def _calculate_losses( pred_boxes = [] pred_classprob = [] pred_confidence = [] + hits = 0 for image_idx, image_targets in enumerate(targets): target_boxes = image_targets['boxes'] @@ -367,6 +357,7 @@ def _calculate_losses( cell_j = cell_j[selected] predictors = predictors[selected] wh = wh[selected] + hits += selected.count_nonzero() # The "low-confidence" mask is used to select predictors that are not responsible for # predicting any object, for calculating the part of the confidence loss with zero as @@ -441,7 +432,7 @@ def _calculate_losses( confidence_loss = confidence_loss.sum() / batch_size losses['confidence'] = confidence_loss * self.confidence_loss_multiplier - return losses + return losses, hits class Mish(nn.Module): diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 3a9904d37d..b1e74df467 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -68,8 +68,8 @@ class YOLO(pl.LightningModule): CLI command:: # PascalVOC - wget https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny.cfg - python yolo_module.py --config yolov4-tiny.cfg --data_dir . --gpus 8 --batch-size 8 + wget https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg + python yolo_module.py --config yolov4-tiny-3l.cfg --data_dir . --gpus 8 --batch-size 8 """ def __init__( @@ -155,6 +155,7 @@ def forward( outputs = [] # Outputs from all layers detections = [] # Outputs from detection layers losses = [] # Losses from detection layers + hits = [] # Number of targets each detection layer was responsible for image_height = images.shape[2] image_width = images.shape[3] @@ -169,23 +170,35 @@ def forward( x = module(x, image_size) detections.append(x) else: - x, layer_losses = module(x, image_size, targets) + x, layer_losses, layer_hits = module(x, image_size, targets) detections.append(x) losses.append(layer_losses) + hits.append(layer_hits) else: x = module(x) outputs.append(x) - def mean_loss(loss_name): - loss_tuple = tuple(layer_losses[loss_name] for layer_losses in losses) - return torch.stack(loss_tuple).sum() / images.shape[0] - detections = torch.cat(detections, 1) if targets is None: return detections - losses = {loss_name: mean_loss(loss_name) for loss_name in losses[0].keys()} + total_hits = sum(hits) + num_targets = sum(len(image_targets['boxes']) for image_targets in targets) + if total_hits != num_targets: + log.warning( + f'{num_targets} training targets were matched a total of {total_hits} times by detection layers. ' + 'Anchors may have been configured incorrectly.' + ) + for layer_idx, layer_hits in enumerate(hits): + self.log(f'train/layer_{layer_idx}_hit_rate', layer_hits / total_hits, sync_dist=True) + + def total_loss(loss_name): + """Returns the sum of the loss over detection layers.""" + loss_tuple = tuple(layer_losses[loss_name] for layer_losses in losses) + return torch.stack(loss_tuple).sum() + + losses = {loss_name: total_loss(loss_name) for loss_name in losses[0].keys()} return detections, losses def configure_optimizers(self) -> Tuple[List, List]: @@ -211,8 +224,8 @@ def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], bat total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): - self.log(f'train/{name}_loss', value, prog_bar=True) - self.log('train/total_loss', total_loss) + self.log(f'train/{name}_loss', value, prog_bar=True, sync_dist=True) + self.log('train/total_loss', total_loss, sync_dist=True) return {'loss': total_loss} From 9b010de71f3acfd67b5a3bf558ee069832ba88f2 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 23 Jun 2021 20:10:43 +0300 Subject: [PATCH 44/61] Fixed division by zero when there are no targets in a batch --- pl_bolts/models/detection/yolo/yolo_module.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index b1e74df467..ad3a7ca4d5 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -190,8 +190,9 @@ def forward( f'{num_targets} training targets were matched a total of {total_hits} times by detection layers. ' 'Anchors may have been configured incorrectly.' ) - for layer_idx, layer_hits in enumerate(hits): - self.log(f'train/layer_{layer_idx}_hit_rate', layer_hits / total_hits, sync_dist=True) + if total_hits > 0: + for layer_idx, layer_hits in enumerate(hits): + self.log(f'train/layer_{layer_idx}_hit_rate', layer_hits / total_hits, sync_dist=True) def total_loss(loss_name): """Returns the sum of the loss over detection layers.""" From f15282d3c79cafa78b2ea746336e93b317bda6e0 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 24 Jun 2021 11:06:53 +0300 Subject: [PATCH 45/61] Always return all losses to avoid deadlock with DDP when there are no targets --- pl_bolts/models/detection/yolo/yolo_layers.py | 4 ++++ pl_bolts/models/detection/yolo/yolo_module.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index fb49ee81e0..909e8dc6eb 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -408,6 +408,8 @@ def _calculate_losses( overlap_loss = overlap_loss * size_compensation overlap_loss = overlap_loss.sum() / batch_size losses['overlap'] = overlap_loss * self.overlap_loss_multiplier + else: + losses['overlap'] = torch.tensor(0.0, device=device) if pred_classprob and target_label: pred_classprob = torch.cat(pred_classprob) @@ -417,6 +419,8 @@ def _calculate_losses( class_loss = self.class_loss_func(pred_classprob, target_classprob) class_loss = class_loss.sum() / batch_size losses['class'] = class_loss * self.class_loss_multiplier + else: + losses['class'] = torch.tensor(0.0, device=device) pred_low_confidence = confidence[lc_mask] target_low_confidence = torch.zeros_like(pred_low_confidence) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index ad3a7ca4d5..8d49826f17 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -190,9 +190,9 @@ def forward( f'{num_targets} training targets were matched a total of {total_hits} times by detection layers. ' 'Anchors may have been configured incorrectly.' ) - if total_hits > 0: - for layer_idx, layer_hits in enumerate(hits): - self.log(f'train/layer_{layer_idx}_hit_rate', layer_hits / total_hits, sync_dist=True) + for layer_idx, layer_hits in enumerate(hits): + hit_rate = layer_hits / total_hits if total_hits > 0 else 1.0 + self.log(f'train/layer_{layer_idx}_hit_rate', hit_rate, sync_dist=True) def total_loss(loss_name): """Returns the sum of the loss over detection layers.""" From f6d3476cba7412bf2f7a7ebd956ec2afa9ff534a Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 1 Jul 2021 22:49:06 +0300 Subject: [PATCH 46/61] Hit rates are always logged so don't prefix the names --- pl_bolts/models/detection/yolo/yolo_module.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 8d49826f17..c18a336d0d 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -2,9 +2,10 @@ from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np -import pytorch_lightning as pl import torch import torch.nn as nn +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_info from torch import optim, Tensor from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer @@ -21,7 +22,7 @@ log = logging.getLogger(__name__) -class YOLO(pl.LightningModule): +class YOLO(LightningModule): """ PyTorch Lightning implementation of `YOLOv3 `_ with some improvements from `YOLOv4 `_. @@ -192,7 +193,7 @@ def forward( ) for layer_idx, layer_hits in enumerate(hits): hit_rate = layer_hits / total_hits if total_hits > 0 else 1.0 - self.log(f'train/layer_{layer_idx}_hit_rate', hit_rate, sync_dist=True) + self.log(f'layer_{layer_idx}_hit_rate', hit_rate, sync_dist=True) def total_loss(loss_name): """Returns the sum of the loss over detection layers.""" @@ -309,9 +310,9 @@ def load_darknet_weights(self, weight_file): """ version = np.fromfile(weight_file, count=3, dtype=np.int32) images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) - log.info( - 'Loading weights from Darknet model version %d.%d.%d that has been trained on %d ' - 'images.', version[0], version[1], version[2], images_seen[0] + rank_zero_info( + f'Loading weights from Darknet model version {version[0]}.{version[1]}.{version[2]} ' + f'that has been trained on {images_seen[0]} images.' ) def read(tensor): @@ -508,10 +509,12 @@ def __call__(self, image, target): def run_cli(): from argparse import ArgumentParser + from pytorch_lightning import seed_everything, Trainer + from pl_bolts.datamodules import VOCDetectionDataModule from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration - pl.seed_everything(42) + seed_everything(42) parser = ArgumentParser() parser.add_argument( @@ -565,7 +568,7 @@ def run_cli(): ) parser = VOCDetectionDataModule.add_argparse_args(parser) - parser = pl.Trainer.add_argparse_args(parser) + parser = Trainer.add_argparse_args(parser) args = parser.parse_args() config = YOLOConfiguration(args.config) @@ -596,7 +599,7 @@ def run_cli(): with open(args.darknet_weights, 'r') as weight_file: model.load_darknet_weights(weight_file) - trainer = pl.Trainer.from_argparse_args(args) + trainer = Trainer.from_argparse_args(args) trainer.fit( model, datamodule.train_dataloader(args.batch_size, transforms), datamodule.val_dataloader(args.batch_size, transforms) From 86a6b668364cdd28fe5a6aede7f6c0fc432c4aa8 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Sat, 31 Jul 2021 20:49:07 +0300 Subject: [PATCH 47/61] Fixed training loss * The vector of overlap losses was accidentally transformed to a square matrix. * Some versions of Lightning don't work correctly when logging losses with sync_dist=True. --- pl_bolts/models/detection/yolo/yolo_layers.py | 2 +- pl_bolts/models/detection/yolo/yolo_module.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 909e8dc6eb..dd45b91d1a 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -398,7 +398,7 @@ def _calculate_losses( losses = dict() if pred_boxes and target_xy and target_wh: - size_compensation = torch.cat(size_compensation).unsqueeze(1) + size_compensation = torch.cat(size_compensation) pred_boxes = torch.cat(pred_boxes) if self.image_space_loss: target_boxes = _corner_coordinates(torch.cat(target_xy), torch.cat(target_wh)) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index c18a336d0d..100a5e9959 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -193,7 +193,7 @@ def forward( ) for layer_idx, layer_hits in enumerate(hits): hit_rate = layer_hits / total_hits if total_hits > 0 else 1.0 - self.log(f'layer_{layer_idx}_hit_rate', hit_rate, sync_dist=True) + self.log(f'layer_{layer_idx}_hit_rate', hit_rate, sync_dist=False) def total_loss(loss_name): """Returns the sum of the loss over detection layers.""" @@ -225,9 +225,11 @@ def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], bat _, losses = self(images, targets) total_loss = torch.stack(tuple(losses.values())).sum() + # sync_dist=True is broken in some versions of Lightning and may cause the sum of the loss + # across GPUs to be returned. for name, value in losses.items(): - self.log(f'train/{name}_loss', value, prog_bar=True, sync_dist=True) - self.log('train/total_loss', total_loss, sync_dist=True) + self.log(f'train/{name}_loss', value, prog_bar=True, sync_dist=False) + self.log('train/total_loss', total_loss, sync_dist=False) return {'loss': total_loss} From 3286533f28f1bc3d6691c594bc2baae642f2a239 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 4 Aug 2021 22:37:50 +0300 Subject: [PATCH 48/61] Truncate nms() inputs to avoid it crashing when too many boxes are detected --- pl_bolts/models/detection/yolo/yolo_module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 100a5e9959..ebd4d953b1 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -460,7 +460,11 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te cls_classprobs = img_classprobs[selected] cls_labels = img_labels[selected] + # NMS will crash if there are too many boxes. + cls_boxes = cls_boxes[:100000] + cls_scores = cls_scores[:100000] selected = nms(cls_boxes, cls_scores, self.nms_threshold) + img_out_boxes = torch.cat((img_out_boxes, cls_boxes[selected])) img_out_scores = torch.cat((img_out_scores, cls_scores[selected])) img_out_classprobs = torch.cat((img_out_classprobs, cls_classprobs[selected])) From bb92076dea086ac42febf8bd95ab30f5a396ede4 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 11 Aug 2021 11:56:17 +0300 Subject: [PATCH 49/61] Use sum() instead of count_nonzero() as it's available already before PyTorch 1.7 --- pl_bolts/models/detection/yolo/yolo_layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index dd45b91d1a..19d42ffdc0 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -357,7 +357,8 @@ def _calculate_losses( cell_j = cell_j[selected] predictors = predictors[selected] wh = wh[selected] - hits += selected.count_nonzero() + # sum() is equivalent to count_nonzero() and available before PyTorch 1.7. + hits += selected.sum() # The "low-confidence" mask is used to select predictors that are not responsible for # predicting any object, for calculating the part of the confidence loss with zero as From b8961128b608f4d4012491776be24d2c26e03a07 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 17 Aug 2021 16:44:06 +0300 Subject: [PATCH 50/61] Squared error loss takes the sum over the predicted attributes --- pl_bolts/models/detection/yolo/yolo_config.py | 2 +- pl_bolts/models/detection/yolo/yolo_layers.py | 28 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 7e90fb3b80..3d5c71bc6d 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -241,7 +241,7 @@ def _create_yolo(config, num_inputs): overlap_loss_name = config.get('iou_loss', 'mse') if overlap_loss_name == 'mse': - overlap_loss_func = nn.MSELoss(reduction='none') + overlap_loss_func = yolo_layers.SELoss() elif overlap_loss_name == 'giou': overlap_loss_func = yolo_layers.GIoULoss() else: diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 19d42ffdc0..9335c2d8cb 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -59,6 +59,15 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: return inter / union +class SELoss(nn.MSELoss): + + def __init__(self): + super().__init__(reduction='none') + + def forward(self, inputs: Tensor, target: Tensor) -> Tensor: + return super().forward(inputs, target).sum(1) + + class IoULoss(nn.Module): def forward(self, inputs: Tensor, target: Tensor) -> Tensor: @@ -118,12 +127,12 @@ def __init__( of squared errors. confidence_loss_func: Loss function for confidence score. Default is the sum of squared errors. - image_space_loss: If set to ``True``, the overlap loss function will receive the bounding - box `(x1, y1, x2, y2)` coordinates, scaled to the input image size. This is needed - for the IoU losses introduced in YOLOv4. Otherwise the loss will be computed from - the x, y, width, and height values, as predicted by the network (i.e. relative to - the anchor box, and width and height are logarithmic). - coord_loss_multiplier: Multiply the coordinate/size loss by this factor. + image_space_loss: If set to ``True``, the overlap loss function will receive the + bounding box `(x1, y1, x2, y2)` coordinates, scaled to the input image size. This is + needed for the IoU losses introduced in YOLOv4. Otherwise the loss will be computed + from the x, y, width, and height values, as predicted by the network (i.e. relative + to the anchor box, and width and height are logarithmic). + overlap_loss_multiplier: Multiply the overlap loss by this factor. class_loss_multiplier: Multiply the classification loss by this factor. confidence_loss_multiplier: Multiply the confidence loss by this factor. """ @@ -139,10 +148,9 @@ def __init__( self.xy_scale = xy_scale self.ignore_threshold = ignore_threshold - se_loss = nn.MSELoss(reduction='none') - self.overlap_loss_func = overlap_loss_func or se_loss - self.class_loss_func = class_loss_func or se_loss - self.confidence_loss_func = confidence_loss_func or se_loss + self.overlap_loss_func = overlap_loss_func or SELoss() + self.class_loss_func = class_loss_func or SELoss() + self.confidence_loss_func = confidence_loss_func or nn.MSELoss(reduction='none') self.image_space_loss = image_space_loss self.overlap_loss_multiplier = overlap_loss_multiplier self.class_loss_multiplier = class_loss_multiplier From 55a11803bbd23671dd679a89b3dd8bd568b26ca3 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 17 Aug 2021 17:04:11 +0300 Subject: [PATCH 51/61] Swish and logistic activation functions --- pl_bolts/models/detection/yolo/yolo_config.py | 18 ++++++++++-- pl_bolts/models/detection/yolo/yolo_layers.py | 29 +++++++++++++++---- pl_bolts/models/detection/yolo/yolo_module.py | 4 +-- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 3d5c71bc6d..8e5ef909b7 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -93,6 +93,7 @@ def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: 'max_delta': float, 'momentum': float, 'mosaic': bool, + 'new_coords': int, 'nms_kind': str, 'num': int, 'obj_normalizer': float, @@ -186,12 +187,23 @@ def _create_convolutional(config, num_inputs): bn = nn.BatchNorm2d(config['filters']) module.add_module('bn', bn) - if config['activation'] == 'leaky': + activation_name = config['activation'] + if activation_name == 'leaky': leakyrelu = nn.LeakyReLU(0.1, inplace=True) module.add_module('leakyrelu', leakyrelu) - elif config['activation'] == 'mish': + elif activation_name == 'mish': mish = yolo_layers.Mish() module.add_module('mish', mish) + elif activation_name == 'swish': + swish = nn.SiLU(inplace=True) + module.add_module('swish', swish) + elif activation_name == 'logistic': + logistic = nn.Sigmoid() + module.add_module('logistic', logistic) + elif activation_name == 'linear': + pass + else: + raise ValueError('Unknown activation: ' + activation_name) return module, config['filters'] @@ -234,6 +246,7 @@ def _create_yolo(config, num_inputs): anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)] xy_scale = config.get('scale_x_y', 1.0) + input_is_normalized = config.get('new_coords', 0) > 0 ignore_threshold = config.get('ignore_thresh', 1.0) overlap_loss_multiplier = config.get('iou_normalizer', 1.0) class_loss_multiplier = config.get('cls_normalizer', 1.0) @@ -252,6 +265,7 @@ def _create_yolo(config, num_inputs): anchor_dims=anchor_dims, anchor_ids=config['mask'], xy_scale=xy_scale, + input_is_normalized=input_is_normalized, ignore_threshold=ignore_threshold, overlap_loss_func=overlap_loss_func, image_space_loss=overlap_loss_name != 'mse', diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 9335c2d8cb..98586efc19 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -99,6 +99,7 @@ def __init__( anchor_dims: List[Tuple[int, int]], anchor_ids: List[int], xy_scale: float = 1.0, + input_is_normalized: bool = False, ignore_threshold: float = 0.5, overlap_loss_func: Optional[Callable] = None, class_loss_func: Optional[Callable] = None, @@ -118,6 +119,10 @@ def __init__( anchors that this layer uses. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps to produce coordinate values close to one. + input_is_normalized: The input is normalized by logistic activation in the previous + layer. In this case the detection layer will not take the sigmoid of the coordinate + and probability predictions, and the width and height are scaled up so that the + maximum value is four times the anchor dimension ignore_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -146,6 +151,7 @@ def __init__( self.anchor_dims = [anchor_dims[i] for i in anchor_ids] self.anchor_map = [anchor_ids.index(i) if i in anchor_ids else -1 for i in range(len(anchor_dims))] self.xy_scale = xy_scale + self.input_is_normalized = input_is_normalized self.ignore_threshold = ignore_threshold self.overlap_loss_func = overlap_loss_func or SELoss() @@ -199,11 +205,16 @@ def forward(self, x = x.view(batch_size, height, width, boxes_per_cell, num_attrs) # Take the sigmoid of the bounding box coordinates, confidence score, and class - # probabilities. - xy = torch.sigmoid(x[..., :2]) + # probabilities, unless the input is normalized by the previous layer activation. + if self.input_is_normalized: + xy = x[..., :2] + confidence = x[..., 4] + classprob = x[..., 5:] + else: + xy = torch.sigmoid(x[..., :2]) + confidence = torch.sigmoid(x[..., 4]) + classprob = torch.sigmoid(x[..., 5:]) wh = x[..., 2:4] - confidence = torch.sigmoid(x[..., 4]) - classprob = torch.sigmoid(x[..., 5:]) # Eliminate grid sensitivity. The previous layer should output extremely high values for # the sigmoid to produce x/y coordinates close to one. YOLOv4 solves this by scaling the @@ -211,7 +222,10 @@ def forward(self, xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) image_xy = self._global_xy(xy, image_size) - image_wh = torch.exp(wh) * torch.tensor(self.anchor_dims, dtype=wh.dtype, device=wh.device) + if self.input_is_normalized: + image_wh = 4 * torch.square(wh) * torch.tensor(self.anchor_dims, dtype=wh.dtype, device=wh.device) + else: + image_wh = torch.exp(wh) * torch.tensor(self.anchor_dims, dtype=wh.dtype, device=wh.device) boxes = _corner_coordinates(image_xy, image_wh) output = torch.cat((boxes, confidence.unsqueeze(-1), classprob), -1) output = output.reshape(batch_size, height * width * boxes_per_cell, num_attrs) @@ -383,7 +397,10 @@ def _calculate_losses( grid_xy = grid_xy[selected] best_anchors = best_anchors[selected] relative_xy = grid_xy - grid_xy.floor() - relative_wh = torch.log(wh / anchor_wh[best_anchors] + 1e-16) + if self.input_is_normalized: + relative_wh = torch.sqrt(wh / (4 * anchor_wh[best_anchors] + 1e-16)) + else: + relative_wh = torch.log(wh / anchor_wh[best_anchors] + 1e-16) target_xy.append(relative_xy) target_wh.append(relative_wh) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index ebd4d953b1..be18b5d8f7 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -337,10 +337,8 @@ def read(tensor): conv = module[0] assert isinstance(conv, nn.Conv2d) - if len(module) > 1: + if len(module) > 1 and isinstance(module[1], nn.BatchNorm2d): bn = module[1] - assert isinstance(bn, nn.BatchNorm2d) - read(bn.bias) read(bn.weight) read(bn.running_mean) From 7857bea6b0da909fc1bed87f0192fd8cbab1f045 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 18 Aug 2021 12:34:21 +0300 Subject: [PATCH 52/61] Added a comment --- pl_bolts/models/detection/yolo/yolo_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index be18b5d8f7..0d1eb38341 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -337,6 +337,8 @@ def read(tensor): conv = module[0] assert isinstance(conv, nn.Conv2d) + # Convolution may be followed by batch normalization, in which case we read the batch + # normalization parameters and not the convolution bias. if len(module) > 1 and isinstance(module[1], nn.BatchNorm2d): bn = module[1] read(bn.bias) From 0804699e5c6346836aad8f29d71dba0abfb50e3a Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 19 Aug 2021 00:52:55 +0300 Subject: [PATCH 53/61] Fixed code formatting --- .../datamodules/vocdetection_datamodule.py | 4 +- pl_bolts/models/detection/__init__.py | 7 +- pl_bolts/models/detection/yolo/yolo_config.py | 228 ++++++++--------- pl_bolts/models/detection/yolo/yolo_layers.py | 88 +++---- pl_bolts/models/detection/yolo/yolo_module.py | 232 ++++++++---------- tests/models/test_detection.py | 18 +- 6 files changed, 275 insertions(+), 302 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 12d9253174..231c590a05 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -149,7 +149,7 @@ def train_dataloader( self, batch_size: int = 1, transforms: Optional[List[Callable]] = None, - image_transforms: Optional[Callable] = None + image_transforms: Optional[Callable] = None, ) -> DataLoader: """VOCDetection train set uses the `train` subset. @@ -181,7 +181,7 @@ def val_dataloader( self, batch_size: int = 1, transforms: Optional[List[Callable]] = None, - image_transforms: Optional[Callable] = None + image_transforms: Optional[Callable] = None, ) -> DataLoader: """VOCDetection val set uses the `val` subset diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 437a83eeb4..bcb97d7269 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -3,9 +3,4 @@ from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_module import YOLO -__all__ = [ - "components", - "FasterRCNN", - "YOLOConfiguration", - "YOLO" -] +__all__ = ["components", "FasterRCNN", "YOLOConfiguration", "YOLO"] diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 8e5ef909b7..ff6ea25a9f 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -9,21 +9,22 @@ class YOLOConfiguration: - """ - This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. + """This class can be used to parse the configuration files of the Darknet + YOLOv4 implementation. + The :func:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration.get_network` method returns a PyTorch module list that can be used to construct a YOLO model. """ def __init__(self, path: str) -> None: - """ - Saves the variables from the first configuration section to attributes of this object, and - the rest of the sections to the ``layer_configs`` list. + """Saves the variables from the first configuration section to + attributes of this object, and the rest of the sections to the + ``layer_configs`` list. Args: path: Path to a configuration file """ - with open(path, 'r') as config_file: + with open(path, "r") as config_file: sections = self._read_file(config_file) if len(sections) < 2: @@ -34,9 +35,9 @@ def __init__(self, path: str) -> None: self.layer_configs = sections[1:] def get_network(self) -> nn.ModuleList: - """ - Iterates through the layers from the configuration and creates corresponding PyTorch - modules. Returns the network structure that can be used to create a YOLO model. + """Iterates through the layers from the configuration and creates + corresponding PyTorch modules. Returns the network structure that can + be used to create a YOLO model. Returns: A :class:`~torch.nn.ModuleList` that defines the YOLO network. @@ -51,8 +52,8 @@ def get_network(self) -> nn.ModuleList: return result def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: - """ - Reads a YOLOv4 network configuration file and returns a list of configuration sections. + """Reads a YOLOv4 network configuration file and returns a list of + configuration sections. Args: config_file: The configuration file to read. @@ -60,56 +61,56 @@ def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: Returns: A list of configuration sections. """ - section_re = re.compile(r'\[([^]]+)\]') - list_variables = ('layers', 'anchors', 'mask', 'scales') + section_re = re.compile(r"\[([^]]+)\]") + list_variables = ("layers", "anchors", "mask", "scales") variable_types = { - 'activation': str, - 'anchors': int, - 'angle': float, - 'batch': int, - 'batch_normalize': bool, - 'beta_nms': float, - 'burn_in': int, - 'channels': int, - 'classes': int, - 'cls_normalizer': float, - 'decay': float, - 'exposure': float, - 'filters': int, - 'from': int, - 'groups': int, - 'group_id': int, - 'height': int, - 'hue': float, - 'ignore_thresh': float, - 'iou_loss': str, - 'iou_normalizer': float, - 'iou_thresh': float, - 'jitter': float, - 'layers': int, - 'learning_rate': float, - 'mask': int, - 'max_batches': int, - 'max_delta': float, - 'momentum': float, - 'mosaic': bool, - 'new_coords': int, - 'nms_kind': str, - 'num': int, - 'obj_normalizer': float, - 'pad': bool, - 'policy': str, - 'random': bool, - 'resize': float, - 'saturation': float, - 'scales': float, - 'scale_x_y': float, - 'size': int, - 'steps': str, - 'stride': int, - 'subdivisions': int, - 'truth_thresh': float, - 'width': int + "activation": str, + "anchors": int, + "angle": float, + "batch": int, + "batch_normalize": bool, + "beta_nms": float, + "burn_in": int, + "channels": int, + "classes": int, + "cls_normalizer": float, + "decay": float, + "exposure": float, + "filters": int, + "from": int, + "groups": int, + "group_id": int, + "height": int, + "hue": float, + "ignore_thresh": float, + "iou_loss": str, + "iou_normalizer": float, + "iou_thresh": float, + "jitter": float, + "layers": int, + "learning_rate": float, + "mask": int, + "max_batches": int, + "max_delta": float, + "momentum": float, + "mosaic": bool, + "new_coords": int, + "nms_kind": str, + "num": int, + "obj_normalizer": float, + "pad": bool, + "policy": str, + "random": bool, + "resize": float, + "saturation": float, + "scales": float, + "scale_x_y": float, + "size": int, + "steps": str, + "stride": int, + "subdivisions": int, + "truth_thresh": float, + "width": int, } section = None @@ -118,26 +119,26 @@ def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: def convert(key, value): """Converts a value to the correct type based on key.""" if key not in variable_types: - warn('Unknown YOLO configuration variable: ' + key) + warn("Unknown YOLO configuration variable: " + key) return key, value if key in list_variables: - value = [variable_types[key](v) for v in value.split(',')] + value = [variable_types[key](v) for v in value.split(",")] else: value = variable_types[key](value) return key, value for line in config_file: line = line.strip() - if (not line) or (line[0] == '#'): + if (not line) or (line[0] == "#"): continue section_match = section_re.match(line) if section_match: if section is not None: sections.append(section) - section = {'type': section_match.group(1)} + section = {"type": section_match.group(1)} else: - key, value = line.split('=') + key, value = line.split("=") key = key.rstrip() value = value.lstrip() key, value = convert(key, value) @@ -149,9 +150,8 @@ def convert(key, value): def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: - """ - Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch - module from the layer config. + """Calls one of the ``_create_(config, num_inputs)`` functions + to create a PyTorch module from the layer config. Args: config: Dictionary of configuration options for this layer. @@ -162,65 +162,65 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: number of channels in its output. """ create_func = { - 'convolutional': _create_convolutional, - 'maxpool': _create_maxpool, - 'route': _create_route, - 'shortcut': _create_shortcut, - 'upsample': _create_upsample, - 'yolo': _create_yolo + "convolutional": _create_convolutional, + "maxpool": _create_maxpool, + "route": _create_route, + "shortcut": _create_shortcut, + "upsample": _create_upsample, + "yolo": _create_yolo, } - return create_func[config['type']](config, num_inputs) + return create_func[config["type"]](config, num_inputs) def _create_convolutional(config, num_inputs): module = nn.Sequential() - batch_normalize = config.get('batch_normalize', False) - padding = (config['size'] - 1) // 2 if config['pad'] else 0 + batch_normalize = config.get("batch_normalize", False) + padding = (config["size"] - 1) // 2 if config["pad"] else 0 conv = nn.Conv2d( - num_inputs[-1], config['filters'], config['size'], config['stride'], padding, bias=not batch_normalize + num_inputs[-1], config["filters"], config["size"], config["stride"], padding, bias=not batch_normalize ) - module.add_module('conv', conv) + module.add_module("conv", conv) if batch_normalize: - bn = nn.BatchNorm2d(config['filters']) - module.add_module('bn', bn) + bn = nn.BatchNorm2d(config["filters"]) + module.add_module("bn", bn) - activation_name = config['activation'] - if activation_name == 'leaky': + activation_name = config["activation"] + if activation_name == "leaky": leakyrelu = nn.LeakyReLU(0.1, inplace=True) - module.add_module('leakyrelu', leakyrelu) - elif activation_name == 'mish': + module.add_module("leakyrelu", leakyrelu) + elif activation_name == "mish": mish = yolo_layers.Mish() - module.add_module('mish', mish) - elif activation_name == 'swish': + module.add_module("mish", mish) + elif activation_name == "swish": swish = nn.SiLU(inplace=True) - module.add_module('swish', swish) - elif activation_name == 'logistic': + module.add_module("swish", swish) + elif activation_name == "logistic": logistic = nn.Sigmoid() - module.add_module('logistic', logistic) - elif activation_name == 'linear': + module.add_module("logistic", logistic) + elif activation_name == "linear": pass else: - raise ValueError('Unknown activation: ' + activation_name) + raise ValueError("Unknown activation: " + activation_name) - return module, config['filters'] + return module, config["filters"] def _create_maxpool(config, num_inputs): - padding = (config['size'] - 1) // 2 - module = nn.MaxPool2d(config['size'], config['stride'], padding) + padding = (config["size"] - 1) // 2 + module = nn.MaxPool2d(config["size"], config["stride"], padding) return module, num_inputs[-1] def _create_route(config, num_inputs): - num_chunks = config.get('groups', 1) - chunk_idx = config.get('group_id', 0) + num_chunks = config.get("groups", 1) + chunk_idx = config.get("group_id", 0) # 0 is the first layer, -1 is the previous layer last = len(num_inputs) - 1 - source_layers = [layer if layer >= 0 else last + layer for layer in config['layers']] + source_layers = [layer if layer >= 0 else last + layer for layer in config["layers"]] module = yolo_layers.RouteLayer(source_layers, num_chunks, chunk_idx) @@ -231,47 +231,47 @@ def _create_route(config, num_inputs): def _create_shortcut(config, num_inputs): - module = yolo_layers.ShortcutLayer(config['from']) + module = yolo_layers.ShortcutLayer(config["from"]) return module, num_inputs[-1] def _create_upsample(config, num_inputs): - module = nn.Upsample(scale_factor=config["stride"], mode='nearest') + module = nn.Upsample(scale_factor=config["stride"], mode="nearest") return module, num_inputs[-1] def _create_yolo(config, num_inputs): # The "anchors" list alternates width and height. - anchor_dims = config['anchors'] + anchor_dims = config["anchors"] anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)] - xy_scale = config.get('scale_x_y', 1.0) - input_is_normalized = config.get('new_coords', 0) > 0 - ignore_threshold = config.get('ignore_thresh', 1.0) - overlap_loss_multiplier = config.get('iou_normalizer', 1.0) - class_loss_multiplier = config.get('cls_normalizer', 1.0) - confidence_loss_multiplier = config.get('obj_normalizer', 1.0) + xy_scale = config.get("scale_x_y", 1.0) + input_is_normalized = config.get("new_coords", 0) > 0 + ignore_threshold = config.get("ignore_thresh", 1.0) + overlap_loss_multiplier = config.get("iou_normalizer", 1.0) + class_loss_multiplier = config.get("cls_normalizer", 1.0) + confidence_loss_multiplier = config.get("obj_normalizer", 1.0) - overlap_loss_name = config.get('iou_loss', 'mse') - if overlap_loss_name == 'mse': + overlap_loss_name = config.get("iou_loss", "mse") + if overlap_loss_name == "mse": overlap_loss_func = yolo_layers.SELoss() - elif overlap_loss_name == 'giou': + elif overlap_loss_name == "giou": overlap_loss_func = yolo_layers.GIoULoss() else: overlap_loss_func = yolo_layers.IoULoss() module = yolo_layers.DetectionLayer( - num_classes=config['classes'], + num_classes=config["classes"], anchor_dims=anchor_dims, - anchor_ids=config['mask'], + anchor_ids=config["mask"], xy_scale=xy_scale, input_is_normalized=input_is_normalized, ignore_threshold=ignore_threshold, overlap_loss_func=overlap_loss_func, - image_space_loss=overlap_loss_name != 'mse', + image_space_loss=overlap_loss_name != "mse", overlap_loss_multiplier=overlap_loss_multiplier, class_loss_multiplier=class_loss_multiplier, - confidence_loss_multiplier=confidence_loss_multiplier + confidence_loss_multiplier=confidence_loss_multiplier, ) return module, num_inputs[-1] diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 98586efc19..36a00e74ef 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -2,13 +2,14 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import nn, Tensor +from torch import Tensor, nn from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision.ops import box_iou + try: from torchvision.ops import generalized_box_iou except ImportError: @@ -16,12 +17,11 @@ else: _GIOU_AVAILABLE = True else: - warn_missing_pkg('torchvision') + warn_missing_pkg("torchvision") def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: - """ - Converts box center points and sizes to corner coordinates. + """Converts box center points and sizes to corner coordinates. Args: xy: Center coordinates. Tensor of size ``[..., 2]``. @@ -37,9 +37,8 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: - """ - Calculates a matrix of intersections over union from box dimensions, assuming that the boxes - are located at the same coordinates. + """Calculates a matrix of intersections over union from box dimensions, + assuming that the boxes are located at the same coordinates. Args: dims1: Width and height of `N` boxes. Tensor of size ``[N, 2]``. @@ -60,27 +59,24 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: class SELoss(nn.MSELoss): - def __init__(self): - super().__init__(reduction='none') + super().__init__(reduction="none") def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return super().forward(inputs, target).sum(1) class IoULoss(nn.Module): - def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return 1.0 - box_iou(inputs, target).diagonal() class GIoULoss(nn.Module): - def __init__(self) -> None: super().__init__() if not _GIOU_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover - 'A more recent version of `torchvision` is needed for generalized IoU loss.' + "A more recent version of `torchvision` is needed for generalized IoU loss." ) def forward(self, inputs: Tensor, target: Tensor) -> Tensor: @@ -88,8 +84,9 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: class DetectionLayer(nn.Module): - """ - A YOLO detection layer. A YOLO model has usually 1 - 3 detection layers at different + """A YOLO detection layer. + + A YOLO model has usually 1 - 3 detection layers at different resolutions. The loss should be summed from all of them. """ @@ -107,7 +104,7 @@ def __init__( image_space_loss: bool = False, overlap_loss_multiplier: float = 1.0, class_loss_multiplier: float = 1.0, - confidence_loss_multiplier: float = 1.0 + confidence_loss_multiplier: float = 1.0, ) -> None: """ Args: @@ -144,7 +141,7 @@ def __init__( super().__init__() if not _TORCHVISION_AVAILABLE: # pragma: no cover - raise ModuleNotFoundError('YOLO model uses `torchvision`, which is not installed yet.') + raise ModuleNotFoundError("YOLO model uses `torchvision`, which is not installed yet.") self.num_classes = num_classes self.all_anchor_dims = anchor_dims @@ -156,18 +153,16 @@ def __init__( self.overlap_loss_func = overlap_loss_func or SELoss() self.class_loss_func = class_loss_func or SELoss() - self.confidence_loss_func = confidence_loss_func or nn.MSELoss(reduction='none') + self.confidence_loss_func = confidence_loss_func or nn.MSELoss(reduction="none") self.image_space_loss = image_space_loss self.overlap_loss_multiplier = overlap_loss_multiplier self.class_loss_multiplier = class_loss_multiplier self.confidence_loss_multiplier = confidence_loss_multiplier - def forward(self, - x: Tensor, - image_size: Tensor, - targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Tensor, Dict[str, Tensor]]: - """ - Runs a forward pass through this YOLO detection layer. + def forward( + self, x: Tensor, image_size: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Runs a forward pass through this YOLO detection layer. Maps cell-local coordinates to global coordinates in the image space, scales the bounding boxes with the anchors, converts the center coordinates to corner coordinates, and maps @@ -240,9 +235,8 @@ def forward(self, return output, losses, hits def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: - """ - Adds offsets to the predicted box center coordinates to obtain global coordinates to the - image. + """Adds offsets to the predicted box center coordinates to obtain + global coordinates to the image. The predicted coordinates are interpreted as coordinates inside a grid cell whose width and height is 1. Adding offset to the cell, dividing by the grid size, and multiplying by the @@ -271,10 +265,10 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: return (xy + offset) * scale def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) -> Tensor: - """ - Initializes the mask that will be used to select predictors that are not predicting any - ground-truth target. The value will be ``True``, unless the predicted box overlaps any target - significantly (IoU greater than ``self.ignore_threshold``). + """Initializes the mask that will be used to select predictors that are + not predicting any ground-truth target. The value will be ``True``, + unless the predicted box overlaps any target significantly (IoU greater + than ``self.ignore_threshold``). Args: boxes: The predicted corner coordinates in the image space. Tensor of size @@ -291,7 +285,7 @@ def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) results = torch.ones((batch_size, num_preds), dtype=torch.bool, device=boxes.device) for image_idx, (image_boxes, image_targets) in enumerate(zip(boxes, targets)): - target_boxes = image_targets['boxes'] + target_boxes = image_targets["boxes"] if target_boxes.shape[0] > 0: ious = box_iou(image_boxes, target_boxes) # [num_preds, num_targets] best_iou = ious.max(-1).values # [num_preds] @@ -300,12 +294,17 @@ def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) return results.view((batch_size, height, width, boxes_per_cell)) def _calculate_losses( - self, boxes: Tensor, confidence: Tensor, classprob: Tensor, targets: List[Dict[str, Tensor]], - image_size: Tensor, lc_mask: Tensor + self, + boxes: Tensor, + confidence: Tensor, + classprob: Tensor, + targets: List[Dict[str, Tensor]], + image_size: Tensor, + lc_mask: Tensor, ) -> Dict[str, Tensor]: - """ - From the targets that are in the image space calculates the actual targets for the network - predictions, and returns a dictionary of training losses. + """From the targets that are in the image space calculates the actual + targets for the network predictions, and returns a dictionary of + training losses. Args: boxes: The predicted bounding boxes. A tensor sized @@ -347,7 +346,7 @@ def _calculate_losses( hits = 0 for image_idx, image_targets in enumerate(targets): - target_boxes = image_targets['boxes'] + target_boxes = image_targets["boxes"] if target_boxes.shape[0] < 1: continue @@ -412,7 +411,7 @@ def _calculate_losses( # The data may contain a different number of classes than this detection layer. In case # a label is greater than the number of classes that this layer predicts, it will be # mapped to the last class. - labels = image_targets['labels'] + labels = image_targets["labels"] labels = labels[selected] labels = torch.min(labels, torch.tensor(self.num_classes - 1, device=device)) target_label.append(labels) @@ -433,9 +432,9 @@ def _calculate_losses( overlap_loss = self.overlap_loss_func(pred_boxes, target_boxes) overlap_loss = overlap_loss * size_compensation overlap_loss = overlap_loss.sum() / batch_size - losses['overlap'] = overlap_loss * self.overlap_loss_multiplier + losses["overlap"] = overlap_loss * self.overlap_loss_multiplier else: - losses['overlap'] = torch.tensor(0.0, device=device) + losses["overlap"] = torch.tensor(0.0, device=device) if pred_classprob and target_label: pred_classprob = torch.cat(pred_classprob) @@ -444,9 +443,9 @@ def _calculate_losses( target_classprob = target_classprob.to(dtype=pred_classprob.dtype) class_loss = self.class_loss_func(pred_classprob, target_classprob) class_loss = class_loss.sum() / batch_size - losses['class'] = class_loss * self.class_loss_multiplier + losses["class"] = class_loss * self.class_loss_multiplier else: - losses['class'] = torch.tensor(0.0, device=device) + losses["class"] = torch.tensor(0.0, device=device) pred_low_confidence = confidence[lc_mask] target_low_confidence = torch.zeros_like(pred_low_confidence) @@ -460,7 +459,7 @@ def _calculate_losses( target_confidence = target_low_confidence confidence_loss = self.confidence_loss_func(pred_confidence, target_confidence) confidence_loss = confidence_loss.sum() / batch_size - losses['confidence'] = confidence_loss * self.confidence_loss_multiplier + losses["confidence"] = confidence_loss * self.confidence_loss_multiplier return losses, hits @@ -473,7 +472,8 @@ def forward(self, x): class RouteLayer(nn.Module): - """Route layer concatenates the output (or part of it) from given layers.""" + """Route layer concatenates the output (or part of it) from given + layers.""" def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None: """ diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 0d1eb38341..c17cccf463 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -6,7 +6,7 @@ import torch.nn as nn from pytorch_lightning import LightningModule from pytorch_lightning.utilities import rank_zero_info -from torch import optim, Tensor +from torch import Tensor, optim from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR @@ -17,15 +17,15 @@ from torchvision.ops import nms from torchvision.transforms import functional as F else: - warn_missing_pkg('torchvision') + warn_missing_pkg("torchvision") log = logging.getLogger(__name__) class YOLO(LightningModule): - """ - PyTorch Lightning implementation of `YOLOv3 `_ with some - improvements from `YOLOv4 `_. + """PyTorch Lightning implementation of `YOLOv3 + `_ with some improvements from `YOLOv4 + `_. *YOLOv3 paper authors*: Joseph Redmon and Ali Farhadi @@ -77,20 +77,12 @@ def __init__( self, network: nn.ModuleList, optimizer: Type[optim.Optimizer] = optim.SGD, - optimizer_params: Dict[str, Any] = { - 'lr': 0.001, - 'momentum': 0.9, - 'weight_decay': 0.0005 - }, + optimizer_params: Dict[str, Any] = {"lr": 0.001, "momentum": 0.9, "weight_decay": 0.0005}, lr_scheduler: Type[optim.lr_scheduler._LRScheduler] = LinearWarmupCosineAnnealingLR, - lr_scheduler_params: Dict[str, Any] = { - 'warmup_epochs': 1, - 'max_epochs': 300, - 'warmup_start_lr': 0.0 - }, + lr_scheduler_params: Dict[str, Any] = {"warmup_epochs": 1, "max_epochs": 300, "warmup_start_lr": 0.0}, confidence_threshold: float = 0.2, nms_threshold: float = 0.45, - max_predictions_per_image: int = -1 + max_predictions_per_image: int = -1, ) -> None: """ Args: @@ -111,8 +103,7 @@ def __init__( super().__init__() if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover - 'YOLO model uses `torchvision`, which is not installed yet.') + raise ModuleNotFoundError("YOLO model uses `torchvision`, which is not installed yet.") # pragma: no-cover self.network = network self.optimizer_class = optimizer @@ -124,13 +115,11 @@ def __init__( self.max_predictions_per_image = max_predictions_per_image def forward( - self, - images: Tensor, - targets: Optional[List[Dict[str, Tensor]]] = None + self, images: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None ) -> Tuple[Tensor, Dict[str, Tensor]]: - """ - Runs a forward pass through the network (all layers listed in ``self.network``), and if - training targets are provided, computes the losses from the detection layers. + """Runs a forward pass through the network (all layers listed in + ``self.network``), and if training targets are provided, computes the + losses from the detection layers. Detections are concatenated from the detection layers. Each image will produce `N * num_anchors * grid_height * grid_width` detections, where `N` depends on the number of @@ -185,15 +174,15 @@ def forward( return detections total_hits = sum(hits) - num_targets = sum(len(image_targets['boxes']) for image_targets in targets) + num_targets = sum(len(image_targets["boxes"]) for image_targets in targets) if total_hits != num_targets: log.warning( - f'{num_targets} training targets were matched a total of {total_hits} times by detection layers. ' - 'Anchors may have been configured incorrectly.' + f"{num_targets} training targets were matched a total of {total_hits} times by detection layers. " + "Anchors may have been configured incorrectly." ) for layer_idx, layer_hits in enumerate(hits): hit_rate = layer_hits / total_hits if total_hits > 0 else 1.0 - self.log(f'layer_{layer_idx}_hit_rate', hit_rate, sync_dist=False) + self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False) def total_loss(loss_name): """Returns the sum of the loss over detection layers.""" @@ -210,8 +199,7 @@ def configure_optimizers(self) -> Tuple[List, List]: return [optimizer], [lr_scheduler] def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int) -> Dict[str, Tensor]: - """ - Computes the training loss. + """Computes the training loss. Args: batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. @@ -228,14 +216,13 @@ def training_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], bat # sync_dist=True is broken in some versions of Lightning and may cause the sum of the loss # across GPUs to be returned. for name, value in losses.items(): - self.log(f'train/{name}_loss', value, prog_bar=True, sync_dist=False) - self.log('train/total_loss', total_loss, sync_dist=False) + self.log(f"train/{name}_loss", value, prog_bar=True, sync_dist=False) + self.log("train/total_loss", total_loss, sync_dist=False) - return {'loss': total_loss} + return {"loss": total_loss} def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int): - """ - Evaluates a batch of data from the validation set. + """Evaluates a batch of data from the validation set. Args: batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. @@ -249,12 +236,11 @@ def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], b total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): - self.log(f'val/{name}_loss', value, sync_dist=True) - self.log('val/total_loss', total_loss, sync_dist=True) + self.log(f"val/{name}_loss", value, sync_dist=True) + self.log("val/total_loss", total_loss, sync_dist=True) def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int): - """ - Evaluates a batch of data from the test set. + """Evaluates a batch of data from the test set. Args: batch: A tuple of images and targets. Images is a list of 3-dimensional tensors. @@ -268,13 +254,12 @@ def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_i total_loss = torch.stack(tuple(losses.values())).sum() for name, value in losses.items(): - self.log(f'test/{name}_loss', value, sync_dist=True) - self.log('test/total_loss', total_loss, sync_dist=True) + self.log(f"test/{name}_loss", value, sync_dist=True) + self.log("test/total_loss", total_loss, sync_dist=True) def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """ - Feeds an image to the network and returns the detected bounding boxes, confidence scores, - and class labels. + """Feeds an image to the network and returns the detected bounding + boxes, confidence scores, and class labels. Args: image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. @@ -291,14 +276,13 @@ def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: detections = self(image.unsqueeze(0)) detections = self._split_detections(detections) detections = self._filter_detections(detections) - boxes = detections['boxes'][0] - scores = detections['scores'][0] - labels = detections['labels'][0] + boxes = detections["boxes"][0] + scores = detections["scores"][0] + labels = detections["labels"][0] return boxes, scores, labels def load_darknet_weights(self, weight_file): - """ - Loads weights to layer modules from a pretrained Darknet model. + """Loads weights to layer modules from a pretrained Darknet model. One may want to continue training from the pretrained weights, on a dataset with a different number of object categories. The number of kernels in the convolutional layers @@ -313,14 +297,16 @@ def load_darknet_weights(self, weight_file): version = np.fromfile(weight_file, count=3, dtype=np.int32) images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) rank_zero_info( - f'Loading weights from Darknet model version {version[0]}.{version[1]}.{version[2]} ' - f'that has been trained on {images_seen[0]} images.' + f"Loading weights from Darknet model version {version[0]}.{version[1]}.{version[2]} " + f"that has been trained on {images_seen[0]} images." ) def read(tensor): - """ - Reads the contents of ``tensor`` from the current position of ``weight_file``. - If there's no more data in ``weight_file``, returns without error. + """Reads the contents of ``tensor`` from the current position of + ``weight_file``. + + If there's no more data in ``weight_file``, returns without + error. """ x = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) if x.shape[0] == 0: @@ -351,11 +337,10 @@ def read(tensor): read(conv.weight) def _validate_batch( - self, - batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] + self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: - """ - Reads a batch of data, validates the format, and stacks the images into a single tensor. + """Reads a batch of data, validates the format, and stacks the images + into a single tensor. Args: batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. @@ -373,14 +358,14 @@ def _validate_batch( raise ValueError("Expected image to be of type Tensor, got {}.".format(type(image))) for target in targets: - boxes = target['boxes'] + boxes = target["boxes"] if not isinstance(boxes, Tensor): raise ValueError("Expected target boxes to be of type Tensor, got {}.".format(type(boxes))) if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): raise ValueError( "Expected target boxes to be tensors of shape [N, 4], got {}.".format(list(boxes.shape)) ) - labels = target['labels'] + labels = target["labels"] if not isinstance(labels, Tensor): raise ValueError("Expected target labels to be of type Tensor, got {}.".format(type(labels))) if len(labels.shape) != 1: @@ -392,8 +377,8 @@ def _validate_batch( return images, targets def _split_detections(self, detections: Tensor) -> Dict[str, Tensor]: - """ - Splits the detection tensor returned by a forward pass into a dictionary. + """Splits the detection tensor returned by a forward pass into a + dictionary. The fields of the dictionary are as follows: - boxes (``Tensor[batch_size, N, 4]``): detected bounding box `(x1, y1, x2, y2)` coordinates @@ -411,15 +396,16 @@ def _split_detections(self, detections: Tensor) -> Dict[str, Tensor]: scores = detections[..., 4] classprobs = detections[..., 5:] classprobs, labels = torch.max(classprobs, -1) - return {'boxes': boxes, 'scores': scores, 'classprobs': classprobs, 'labels': labels} + return {"boxes": boxes, "scores": scores, "classprobs": classprobs, "labels": labels} def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Tensor]]: - """ - Filters detections based on confidence threshold. Then for every class performs non-maximum - suppression (NMS). NMS iterates the bounding boxes that predict this class in descending - order of confidence score, and removes lower scoring boxes that have an IoU greater than - the NMS threshold with a higher scoring box. Finally the detections are sorted by descending - confidence and possible truncated to the maximum number of predictions. + """Filters detections based on confidence threshold. Then for every + class performs non-maximum suppression (NMS). NMS iterates the bounding + boxes that predict this class in descending order of confidence score, + and removes lower scoring boxes that have an IoU greater than the NMS + threshold with a higher scoring box. Finally the detections are sorted + by descending confidence and possible truncated to the maximum number + of predictions. Args: detections: All detections. A dictionary of tensors, each containing the predictions @@ -428,10 +414,10 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te Returns: Filtered detections. A dictionary of lists, each containing a tensor per image. """ - boxes = detections['boxes'] - scores = detections['scores'] - classprobs = detections['classprobs'] - labels = detections['labels'] + boxes = detections["boxes"] + scores = detections["scores"] + classprobs = detections["classprobs"] + labels = detections["labels"] out_boxes = [] out_scores = [] @@ -473,13 +459,13 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te # Sort by descending confidence and limit the maximum number of predictions. indices = torch.argsort(img_out_scores, descending=True) if self.max_predictions_per_image >= 0: - indices = indices[:self.max_predictions_per_image] + indices = indices[: self.max_predictions_per_image] out_boxes.append(img_out_boxes[indices]) out_scores.append(img_out_scores[indices]) out_classprobs.append(img_out_classprobs[indices]) out_labels.append(img_out_labels[indices]) - return {'boxes': out_boxes, 'scores': out_scores, 'classprobs': out_classprobs, 'labels': out_labels} + return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels} class Resize: @@ -500,22 +486,17 @@ def __call__(self, image, target): resize_ratio = torch.tensor(self.output_size) / original_size image = F.resize(image, self.output_size) scale = torch.tensor( - [ - resize_ratio[1], # y - resize_ratio[0], # x - resize_ratio[1], # y - resize_ratio[0] # x - ], - device=target['boxes'].device + [resize_ratio[1], resize_ratio[0], resize_ratio[1], resize_ratio[0]], # y, x, y, x + device=target["boxes"].device, ) - target['boxes'] = target['boxes'] * scale + target["boxes"] = target["boxes"] * scale return image, target def run_cli(): from argparse import ArgumentParser - from pytorch_lightning import seed_everything, Trainer + from pytorch_lightning import Trainer, seed_everything from pl_bolts.datamodules import VOCDetectionDataModule from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration @@ -523,54 +504,50 @@ def run_cli(): seed_everything(42) parser = ArgumentParser() + parser.add_argument("--config", type=str, metavar="PATH", required=True, help="read model configuration from PATH") parser.add_argument( - '--config', type=str, metavar='PATH', required=True, - help='read model configuration from PATH' - ) - parser.add_argument( - '--darknet-weights', type=str, metavar='PATH', - help='read the initial model weights from PATH in Darknet format' - ) - parser.add_argument( - '--batch-size', type=int, metavar='N', default=16, - help='batch size is N image' + "--darknet-weights", type=str, metavar="PATH", help="read the initial model weights from PATH in Darknet format" ) + parser.add_argument("--batch-size", type=int, metavar="N", default=16, help="batch size is N image") + parser.add_argument("--lr", type=float, metavar="LR", default=0.0013, help="learning rate after the warmup period") parser.add_argument( - '--lr', type=float, metavar='LR', default=0.0013, - help='learning rate after the warmup period' + "--momentum", + type=float, + metavar="GAMMA", + default=0.9, + help="if nonzero, the optimizer uses momentum with factor GAMMA", ) parser.add_argument( - '--momentum', type=float, metavar='GAMMA', default=0.9, - help='if nonzero, the optimizer uses momentum with factor GAMMA' + "--weight-decay", + type=float, + metavar="LAMBDA", + default=0.0005, + help="if nonzero, the optimizer uses weight decay (L2 penalty) with factor LAMBDA", ) parser.add_argument( - '--weight-decay', type=float, metavar='LAMBDA', default=0.0005, - help='if nonzero, the optimizer uses weight decay (L2 penalty) with factor LAMBDA' + "--warmup-epochs", type=int, metavar="N", default=1, help="learning rate warmup period is N epochs" ) + parser.add_argument("--max-epochs", type=int, metavar="N", default=300, help="train at most N epochs") parser.add_argument( - '--warmup-epochs', type=int, metavar='N', default=1, - help='learning rate warmup period is N epochs' + "--initial-lr", type=float, metavar="LR", default=0.0, help="learning rate before the warmup period" ) parser.add_argument( - '--max-epochs', type=int, metavar='N', default=300, - help='train at most N epochs' + "--confidence-threshold", + type=float, + metavar="THRESHOLD", + default=0.001, + help="keep predictions only if the confidence is above THRESHOLD", ) parser.add_argument( - '--initial-lr', type=float, metavar='LR', default=0.0, - help='learning rate before the warmup period' + "--nms-threshold", + type=float, + metavar="THRESHOLD", + default=0.45, + help="non-maximum suppression removes predicted boxes that have IoU greater than " + "THRESHOLD with a higher scoring box", ) parser.add_argument( - '--confidence-threshold', type=float, metavar='THRESHOLD', default=0.001, - help='keep predictions only if the confidence is above THRESHOLD' - ) - parser.add_argument( - '--nms-threshold', type=float, metavar='THRESHOLD', default=0.45, - help='non-maximum suppression removes predicted boxes that have IoU greater than ' - 'THRESHOLD with a higher scoring box' - ) - parser.add_argument( - '--max-predictions-per-image', type=int, metavar='N', default=100, - help='keep at most N best predictions' + "--max-predictions-per-image", type=int, metavar="N", default=100, help="keep at most N best predictions" ) parser = VOCDetectionDataModule.add_argparse_args(parser) @@ -583,15 +560,11 @@ def run_cli(): datamodule = VOCDetectionDataModule.from_argparse_args(args) datamodule.prepare_data() - optimizer_params = { - 'lr': args.lr, - 'momentum': args.momentum, - 'weight_decay': args.weight_decay - } + optimizer_params = {"lr": args.lr, "momentum": args.momentum, "weight_decay": args.weight_decay} lr_scheduler_params = { - 'warmup_epochs': args.warmup_epochs, - 'max_epochs': args.max_epochs, - 'warmup_start_lr': args.initial_lr + "warmup_epochs": args.warmup_epochs, + "max_epochs": args.max_epochs, + "warmup_start_lr": args.initial_lr, } model = YOLO( network=config.get_network(), @@ -599,16 +572,17 @@ def run_cli(): lr_scheduler_params=lr_scheduler_params, confidence_threshold=args.confidence_threshold, nms_threshold=args.nms_threshold, - max_predictions_per_image=args.max_predictions_per_image + max_predictions_per_image=args.max_predictions_per_image, ) if args.darknet_weights is not None: - with open(args.darknet_weights, 'r') as weight_file: + with open(args.darknet_weights, "r") as weight_file: model.load_darknet_weights(weight_file) trainer = Trainer.from_argparse_args(args) trainer.fit( - model, datamodule.train_dataloader(args.batch_size, transforms), - datamodule.val_dataloader(args.batch_size, transforms) + model, + datamodule.train_dataloader(args.batch_size, transforms), + datamodule.val_dataloader(args.batch_size, transforms), ) diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 54b062981d..82ca40c4d8 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDetectionDataset -from pl_bolts.models.detection import FasterRCNN, YOLO, YOLOConfiguration +from pl_bolts.models.detection import YOLO, FasterRCNN, YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou from tests import TEST_ROOT @@ -43,7 +43,7 @@ def test_fasterrcnn_bbone_train(tmpdir): def test_yolo(tmpdir): - config_path = Path(TEST_ROOT) / 'data' / 'yolo.cfg' + config_path = Path(TEST_ROOT) / "data" / "yolo.cfg" config = YOLOConfiguration(config_path) model = YOLO(config.get_network()) @@ -52,7 +52,7 @@ def test_yolo(tmpdir): def test_yolo_train(tmpdir): - config_path = Path(TEST_ROOT) / 'data' / 'yolo.cfg' + config_path = Path(TEST_ROOT) / "data" / "yolo.cfg" config = YOLOConfiguration(config_path) model = YOLO(config.get_network()) @@ -64,10 +64,14 @@ def test_yolo_train(tmpdir): @pytest.mark.parametrize( - "dims1, dims2, expected_ious", [( - torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), torch.tensor([[1.0, 10.0], [2.0, 20.0]]), - torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]) - )] + "dims1, dims2, expected_ious", + [ + ( + torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), + torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]), + ) + ], ) def test_aligned_iou(dims1, dims2, expected_ious): torch.testing.assert_allclose(_aligned_iou(dims1, dims2), expected_ious) From 7b32d64bfdc6e7cd6c7d368c1e59d6f378e2cb86 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 19 Aug 2021 03:28:44 +0300 Subject: [PATCH 54/61] Ran docformatter with correct config --- .../datamodules/vocdetection_datamodule.py | 2 +- pl_bolts/models/detection/yolo/yolo_config.py | 20 ++++----- pl_bolts/models/detection/yolo/yolo_layers.py | 20 ++++----- pl_bolts/models/detection/yolo/yolo_module.py | 45 +++++++------------ 4 files changed, 34 insertions(+), 53 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 231c590a05..20f84d541b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -183,7 +183,7 @@ def val_dataloader( transforms: Optional[List[Callable]] = None, image_transforms: Optional[Callable] = None, ) -> DataLoader: - """VOCDetection val set uses the `val` subset + """VOCDetection val set uses the `val` subset. Args: batch_size: size of batch diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index ff6ea25a9f..8c48292ff0 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -9,17 +9,15 @@ class YOLOConfiguration: - """This class can be used to parse the configuration files of the Darknet - YOLOv4 implementation. + """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. The :func:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration.get_network` method returns a PyTorch module list that can be used to construct a YOLO model. """ def __init__(self, path: str) -> None: - """Saves the variables from the first configuration section to - attributes of this object, and the rest of the sections to the - ``layer_configs`` list. + """Saves the variables from the first configuration section to attributes of this object, and the rest of + the sections to the ``layer_configs`` list. Args: path: Path to a configuration file @@ -35,9 +33,8 @@ def __init__(self, path: str) -> None: self.layer_configs = sections[1:] def get_network(self) -> nn.ModuleList: - """Iterates through the layers from the configuration and creates - corresponding PyTorch modules. Returns the network structure that can - be used to create a YOLO model. + """Iterates through the layers from the configuration and creates corresponding PyTorch modules. Returns + the network structure that can be used to create a YOLO model. Returns: A :class:`~torch.nn.ModuleList` that defines the YOLO network. @@ -52,8 +49,7 @@ def get_network(self) -> nn.ModuleList: return result def _read_file(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: - """Reads a YOLOv4 network configuration file and returns a list of - configuration sections. + """Reads a YOLOv4 network configuration file and returns a list of configuration sections. Args: config_file: The configuration file to read. @@ -150,8 +146,8 @@ def convert(key, value): def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]: - """Calls one of the ``_create_(config, num_inputs)`` functions - to create a PyTorch module from the layer config. + """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the + layer config. Args: config: Dictionary of configuration options for this layer. diff --git a/pl_bolts/models/detection/yolo/yolo_layers.py b/pl_bolts/models/detection/yolo/yolo_layers.py index 36a00e74ef..9b1ee891df 100644 --- a/pl_bolts/models/detection/yolo/yolo_layers.py +++ b/pl_bolts/models/detection/yolo/yolo_layers.py @@ -37,8 +37,8 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: - """Calculates a matrix of intersections over union from box dimensions, - assuming that the boxes are located at the same coordinates. + """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at + the same coordinates. Args: dims1: Width and height of `N` boxes. Tensor of size ``[N, 2]``. @@ -235,8 +235,7 @@ def forward( return output, losses, hits def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: - """Adds offsets to the predicted box center coordinates to obtain - global coordinates to the image. + """Adds offsets to the predicted box center coordinates to obtain global coordinates to the image. The predicted coordinates are interpreted as coordinates inside a grid cell whose width and height is 1. Adding offset to the cell, dividing by the grid size, and multiplying by the @@ -265,9 +264,8 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: return (xy + offset) * scale def _low_confidence_mask(self, boxes: Tensor, targets: List[Dict[str, Tensor]]) -> Tensor: - """Initializes the mask that will be used to select predictors that are - not predicting any ground-truth target. The value will be ``True``, - unless the predicted box overlaps any target significantly (IoU greater + """Initializes the mask that will be used to select predictors that are not predicting any ground-truth + target. The value will be ``True``, unless the predicted box overlaps any target significantly (IoU greater than ``self.ignore_threshold``). Args: @@ -302,9 +300,8 @@ def _calculate_losses( image_size: Tensor, lc_mask: Tensor, ) -> Dict[str, Tensor]: - """From the targets that are in the image space calculates the actual - targets for the network predictions, and returns a dictionary of - training losses. + """From the targets that are in the image space calculates the actual targets for the network predictions, + and returns a dictionary of training losses. Args: boxes: The predicted bounding boxes. A tensor sized @@ -472,8 +469,7 @@ def forward(self, x): class RouteLayer(nn.Module): - """Route layer concatenates the output (or part of it) from given - layers.""" + """Route layer concatenates the output (or part of it) from given layers.""" def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None: """ diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index c17cccf463..4cf27c4f88 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -23,17 +23,13 @@ class YOLO(LightningModule): - """PyTorch Lightning implementation of `YOLOv3 - `_ with some improvements from `YOLOv4 - `_. + """PyTorch Lightning implementation of YOLOv3 and YOLOv4. - *YOLOv3 paper authors*: Joseph Redmon and Ali Farhadi + *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `_ - *YOLOv4 paper authors*: Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao + *YOLOv4 paper*: `Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao `_ - *Model implemented by*: - - - `Seppo Enarvi `_ + *Implementation*: `Seppo Enarvi `_ The network architecture can be read from a Darknet configuration file using the :class:`~pl_bolts.models.detection.yolo.yolo_config.YOLOConfiguration` class, or created by @@ -117,9 +113,8 @@ def __init__( def forward( self, images: Tensor, targets: Optional[List[Dict[str, Tensor]]] = None ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Runs a forward pass through the network (all layers listed in - ``self.network``), and if training targets are provided, computes the - losses from the detection layers. + """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets + are provided, computes the losses from the detection layers. Detections are concatenated from the detection layers. Each image will produce `N * num_anchors * grid_height * grid_width` detections, where `N` depends on the number of @@ -258,8 +253,8 @@ def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_i self.log("test/total_loss", total_loss, sync_dist=True) def infer(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Feeds an image to the network and returns the detected bounding - boxes, confidence scores, and class labels. + """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class + labels. Args: image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. @@ -302,11 +297,9 @@ def load_darknet_weights(self, weight_file): ) def read(tensor): - """Reads the contents of ``tensor`` from the current position of - ``weight_file``. + """Reads the contents of ``tensor`` from the current position of ``weight_file``. - If there's no more data in ``weight_file``, returns without - error. + If there's no more data in ``weight_file``, returns without error. """ x = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) if x.shape[0] == 0: @@ -339,8 +332,7 @@ def read(tensor): def _validate_batch( self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]] ) -> Tuple[Tensor, List[Dict[str, Tensor]]]: - """Reads a batch of data, validates the format, and stacks the images - into a single tensor. + """Reads a batch of data, validates the format, and stacks the images into a single tensor. Args: batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. @@ -377,8 +369,7 @@ def _validate_batch( return images, targets def _split_detections(self, detections: Tensor) -> Dict[str, Tensor]: - """Splits the detection tensor returned by a forward pass into a - dictionary. + """Splits the detection tensor returned by a forward pass into a dictionary. The fields of the dictionary are as follows: - boxes (``Tensor[batch_size, N, 4]``): detected bounding box `(x1, y1, x2, y2)` coordinates @@ -399,13 +390,11 @@ def _split_detections(self, detections: Tensor) -> Dict[str, Tensor]: return {"boxes": boxes, "scores": scores, "classprobs": classprobs, "labels": labels} def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Tensor]]: - """Filters detections based on confidence threshold. Then for every - class performs non-maximum suppression (NMS). NMS iterates the bounding - boxes that predict this class in descending order of confidence score, - and removes lower scoring boxes that have an IoU greater than the NMS - threshold with a higher scoring box. Finally the detections are sorted - by descending confidence and possible truncated to the maximum number - of predictions. + """Filters detections based on confidence threshold. Then for every class performs non-maximum suppression + (NMS). NMS iterates the bounding boxes that predict this class in descending order of confidence score, and + removes lower scoring boxes that have an IoU greater than the NMS threshold with a higher scoring box. + Finally the detections are sorted by descending confidence and possible truncated to the maximum number of + predictions. Args: detections: All detections. A dictionary of tensors, each containing the predictions From 191e8650fe15e8f0ab6042ab9e04ce61ee225196 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 19 Aug 2021 03:41:38 +0300 Subject: [PATCH 55/61] Ran pyupgrade --- pl_bolts/models/detection/yolo/yolo_config.py | 2 +- pl_bolts/models/detection/yolo/yolo_module.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_config.py b/pl_bolts/models/detection/yolo/yolo_config.py index 8c48292ff0..fea56df7e8 100644 --- a/pl_bolts/models/detection/yolo/yolo_config.py +++ b/pl_bolts/models/detection/yolo/yolo_config.py @@ -22,7 +22,7 @@ def __init__(self, path: str) -> None: Args: path: Path to a configuration file """ - with open(path, "r") as config_file: + with open(path) as config_file: sections = self._read_file(config_file) if len(sections) < 2: diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 4cf27c4f88..0990493fb7 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -343,26 +343,26 @@ def _validate_batch( images, targets = batch if len(images) != len(targets): - raise ValueError("Got {} images, but targets for {} images.".format(len(images), len(targets))) + raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") for image in images: if not isinstance(image, Tensor): - raise ValueError("Expected image to be of type Tensor, got {}.".format(type(image))) + raise ValueError(f"Expected image to be of type Tensor, got {type(image)}.") for target in targets: boxes = target["boxes"] if not isinstance(boxes, Tensor): - raise ValueError("Expected target boxes to be of type Tensor, got {}.".format(type(boxes))) + raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.") if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): raise ValueError( - "Expected target boxes to be tensors of shape [N, 4], got {}.".format(list(boxes.shape)) + f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}." ) labels = target["labels"] if not isinstance(labels, Tensor): - raise ValueError("Expected target labels to be of type Tensor, got {}.".format(type(labels))) + raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels)}.") if len(labels.shape) != 1: raise ValueError( - "Expected target labels to be tensors of shape [N], got {}.".format(list(labels.shape)) + f"Expected target labels to be tensors of shape [N], got {list(labels.shape)}." ) images = torch.stack(images) @@ -564,7 +564,7 @@ def run_cli(): max_predictions_per_image=args.max_predictions_per_image, ) if args.darknet_weights is not None: - with open(args.darknet_weights, "r") as weight_file: + with open(args.darknet_weights) as weight_file: model.load_darknet_weights(weight_file) trainer = Trainer.from_argparse_args(args) From 66c8111c39073e4ac92237bf910adcb8b7907270 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 19 Aug 2021 03:49:28 +0300 Subject: [PATCH 56/61] Reformatted --- pl_bolts/models/detection/yolo/yolo_module.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 0990493fb7..fda96ed3f4 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -354,16 +354,12 @@ def _validate_batch( if not isinstance(boxes, Tensor): raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.") if (len(boxes.shape) != 2) or (boxes.shape[-1] != 4): - raise ValueError( - f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}." - ) + raise ValueError(f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}.") labels = target["labels"] if not isinstance(labels, Tensor): raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels)}.") if len(labels.shape) != 1: - raise ValueError( - f"Expected target labels to be tensors of shape [N], got {list(labels.shape)}." - ) + raise ValueError(f"Expected target labels to be tensors of shape [N], got {list(labels.shape)}.") images = torch.stack(images) return images, targets From 572ca4f69d5542afd077c13c4f590760fbb1a347 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 30 Aug 2021 12:45:59 +0300 Subject: [PATCH 57/61] VOCDetectionDataModule constructor takes batch size and the transforms take image and target --- .../datamodules/vocdetection_datamodule.py | 83 +++++++------------ .../faster_rcnn/faster_rcnn_module.py | 3 +- pl_bolts/models/detection/yolo/yolo_module.py | 82 ++++++++++++------ 3 files changed, 88 insertions(+), 80 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 20f84d541b..ff62313908 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,3 +1,4 @@ +import os from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -107,10 +108,11 @@ class VOCDetectionDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str] = None, year: str = "2012", num_workers: int = 0, normalize: bool = False, + batch_size: int = 16, shuffle: bool = True, pin_memory: bool = True, drop_last: bool = False, @@ -125,9 +127,10 @@ def __init__( super().__init__(*args, **kwargs) self.year = year - self.data_dir = data_dir + self.data_dir = data_dir if data_dir is not None else os.getcwd() self.num_workers = num_workers self.normalize = normalize + self.batch_size = batch_size self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last @@ -145,78 +148,50 @@ def prepare_data(self) -> None: VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader( - self, - batch_size: int = 1, - transforms: Optional[List[Callable]] = None, - image_transforms: Optional[Callable] = None, - ) -> DataLoader: + def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader: """VOCDetection train set uses the `train` subset. Args: - batch_size: size of batch - transforms: custom transforms for image and target image_transforms: custom image-only transforms """ - if transforms is None: - transforms = [_prepare_voc_instance] - else: - transforms = [_prepare_voc_instance] + transforms - image_transforms = image_transforms or self.train_transforms or self._default_transforms() + transforms = [ + _prepare_voc_instance, + self.default_transforms() if self.train_transforms is None else self.train_transforms, + ] transforms = Compose(transforms, image_transforms) dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms) - loader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - collate_fn=_collate_fn, - ) - return loader + return self._data_loader(dataset, shuffle=self.shuffle) - def val_dataloader( - self, - batch_size: int = 1, - transforms: Optional[List[Callable]] = None, - image_transforms: Optional[Callable] = None, - ) -> DataLoader: + def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader: """VOCDetection val set uses the `val` subset. Args: - batch_size: size of batch - transforms: custom transforms for image and target image_transforms: custom image-only transforms """ - if transforms is None: - transforms = [_prepare_voc_instance] - else: - transforms = [_prepare_voc_instance] + transforms - image_transforms = image_transforms or self.train_transforms or self._default_transforms() + transforms = [ + _prepare_voc_instance, + self.default_transforms() if self.val_transforms is None else self.val_transforms, + ] transforms = Compose(transforms, image_transforms) dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms) - loader = DataLoader( + return self._data_loader(dataset, shuffle=False) + + def default_transforms(self) -> Callable: + voc_transforms = [transform_lib.ToTensor()] + if self.normalize: + voc_transforms += [transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] + voc_transforms = transform_lib.Compose(voc_transforms) + return lambda image, target: (voc_transforms(image), target) + + def _data_loader(self, dataset: VOCDetection, shuffle: bool = False) -> DataLoader: + return DataLoader( dataset, - batch_size=batch_size, - shuffle=False, + batch_size=self.batch_size, + shuffle=shuffle, num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, collate_fn=_collate_fn, ) - return loader - - def _default_transforms(self) -> Callable: - if self.normalize: - voc_transforms = transform_lib.Compose( - [ - transform_lib.ToTensor(), - transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - else: - voc_transforms = transform_lib.Compose([transform_lib.ToTensor()]) - return voc_transforms diff --git a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index 2773473ad7..bb1c4c1051 100644 --- a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -147,9 +147,8 @@ def run_cli(): seed_everything(42) parser = ArgumentParser() + parser = VOCDetectionDataModule.add_argparse_args(parser) parser = Trainer.add_argparse_args(parser) - parser.add_argument("--data_dir", type=str, default=".") - parser.add_argument("--batch_size", type=int, default=1) parser = FasterRCNN.add_model_specific_args(parser) args = parser.parse_args() diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index fda96ed3f4..3e93210b3f 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -66,7 +66,7 @@ class YOLO(LightningModule): # PascalVOC wget https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg - python yolo_module.py --config yolov4-tiny-3l.cfg --data_dir . --gpus 8 --batch-size 8 + python yolo_module.py --config yolov4-tiny-3l.cfg --data_dir . --gpus 8 --batch_size 8 """ def __init__( @@ -465,15 +465,20 @@ class Resize: def __init__(self, output_size: tuple) -> None: self.output_size = output_size - def __call__(self, image, target): - width, height = image.size + def __call__(self, image: Tensor, target: Dict[str, Any]): + """ + Args: + tensor: Tensor image to be resized. + target: Dictionary of detection targets. + + Returns: + Resized Tensor image. + """ + height, width = image.shape[-2:] original_size = torch.tensor([height, width]) - resize_ratio = torch.tensor(self.output_size) / original_size + scale_y, scale_x = torch.tensor(self.output_size) / original_size + scale = torch.tensor([scale_x, scale_y, scale_x, scale_y], device=target["boxes"].device) image = F.resize(image, self.output_size) - scale = torch.tensor( - [resize_ratio[1], resize_ratio[0], resize_ratio[1], resize_ratio[0]], # y, x, y, x - device=target["boxes"].device, - ) target["boxes"] = target["boxes"] * scale return image, target @@ -484,17 +489,32 @@ def run_cli(): from pytorch_lightning import Trainer, seed_everything from pl_bolts.datamodules import VOCDetectionDataModule + from pl_bolts.datamodules.vocdetection_datamodule import Compose from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration seed_everything(42) parser = ArgumentParser() - parser.add_argument("--config", type=str, metavar="PATH", required=True, help="read model configuration from PATH") parser.add_argument( - "--darknet-weights", type=str, metavar="PATH", help="read the initial model weights from PATH in Darknet format" + "--config", + type=str, + metavar="PATH", + required=True, + help="read model configuration from PATH", + ) + parser.add_argument( + "--darknet-weights", + type=str, + metavar="PATH", + help="read the initial model weights from PATH in Darknet format", + ) + parser.add_argument( + "--lr", + type=float, + metavar="LR", + default=0.0013, + help="learning rate after the warmup period", ) - parser.add_argument("--batch-size", type=int, metavar="N", default=16, help="batch size is N image") - parser.add_argument("--lr", type=float, metavar="LR", default=0.0013, help="learning rate after the warmup period") parser.add_argument( "--momentum", type=float, @@ -510,11 +530,25 @@ def run_cli(): help="if nonzero, the optimizer uses weight decay (L2 penalty) with factor LAMBDA", ) parser.add_argument( - "--warmup-epochs", type=int, metavar="N", default=1, help="learning rate warmup period is N epochs" + "--warmup-epochs", + type=int, + metavar="N", + default=1, + help="learning rate warmup period is N epochs", ) - parser.add_argument("--max-epochs", type=int, metavar="N", default=300, help="train at most N epochs") parser.add_argument( - "--initial-lr", type=float, metavar="LR", default=0.0, help="learning rate before the warmup period" + "--max-epochs", + type=int, + metavar="N", + default=300, + help="train at most N epochs", + ) + parser.add_argument( + "--initial-lr", + type=float, + metavar="LR", + default=0.0, + help="learning rate before the warmup period", ) parser.add_argument( "--confidence-threshold", @@ -532,7 +566,11 @@ def run_cli(): "THRESHOLD with a higher scoring box", ) parser.add_argument( - "--max-predictions-per-image", type=int, metavar="N", default=100, help="keep at most N best predictions" + "--max-predictions-per-image", + type=int, + metavar="N", + default=100, + help="keep at most N best predictions", ) parser = VOCDetectionDataModule.add_argparse_args(parser) @@ -541,9 +579,9 @@ def run_cli(): config = YOLOConfiguration(args.config) - transforms = [Resize((config.height, config.width))] - datamodule = VOCDetectionDataModule.from_argparse_args(args) - datamodule.prepare_data() + transforms = [lambda image, target: (F.to_tensor(image), target), Resize((config.height, config.width))] + transforms = Compose(transforms) + datamodule = VOCDetectionDataModule.from_argparse_args(args, train_transforms=transforms, val_transforms=transforms) optimizer_params = {"lr": args.lr, "momentum": args.momentum, "weight_decay": args.weight_decay} lr_scheduler_params = { @@ -564,11 +602,7 @@ def run_cli(): model.load_darknet_weights(weight_file) trainer = Trainer.from_argparse_args(args) - trainer.fit( - model, - datamodule.train_dataloader(args.batch_size, transforms), - datamodule.val_dataloader(args.batch_size, transforms), - ) + trainer.fit(model, datamodule=datamodule) if __name__ == "__main__": From db278be75919633b42ab2100210ccc96526ae979 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 30 Aug 2021 14:02:07 +0300 Subject: [PATCH 58/61] Use true_divide() for integer division --- pl_bolts/models/detection/yolo/yolo_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/yolo/yolo_module.py b/pl_bolts/models/detection/yolo/yolo_module.py index 3e93210b3f..ebb494f5ef 100644 --- a/pl_bolts/models/detection/yolo/yolo_module.py +++ b/pl_bolts/models/detection/yolo/yolo_module.py @@ -176,7 +176,7 @@ def forward( "Anchors may have been configured incorrectly." ) for layer_idx, layer_hits in enumerate(hits): - hit_rate = layer_hits / total_hits if total_hits > 0 else 1.0 + hit_rate = torch.true_divide(layer_hits, total_hits) if total_hits > 0 else 1.0 self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False) def total_loss(loss_name): From 4731b24c80fc97e583ce22c7da2f0ea64c274eb9 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 30 Aug 2021 14:48:51 +0300 Subject: [PATCH 59/61] Fixed doc and package build without Torchvision --- pl_bolts/datamodules/vocdetection_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index ff62313908..15e3fc8a80 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -4,7 +4,7 @@ import torch from pytorch_lightning import LightningDataModule from torch import Tensor -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg @@ -185,7 +185,7 @@ def default_transforms(self) -> Callable: voc_transforms = transform_lib.Compose(voc_transforms) return lambda image, target: (voc_transforms(image), target) - def _data_loader(self, dataset: VOCDetection, shuffle: bool = False) -> DataLoader: + def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: return DataLoader( dataset, batch_size=self.batch_size, From 022192a4baca1efdd96ef3c6b7efe93955e08de2 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 10 Sep 2021 05:20:22 +0300 Subject: [PATCH 60/61] YOLO moved to unreleased --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6057523cb..a6a3312434 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - - Added Soft Actor Critic (SAC) Model [#627](https://github.com/PyTorchLightning/lightning-bolts/pull/627)) - Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676)) @@ -19,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added SparseML Callback [#724](https://github.com/PyTorchLightning/lightning-bolts/pull/724)) +- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552)) + ### Changed - Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701)) @@ -88,7 +89,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added Pix2Pix model ([#533](https://github.com/PyTorchLightning/lightning-bolts/pull/533)) -- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552)) ### Changed From bfc774c9d8c7bb658fbd4763387828926cbe3da5 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 10 Sep 2021 06:46:33 +0200 Subject: [PATCH 61/61] Code formatting Co-authored-by: Aki Nitta --- pl_bolts/models/detection/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index bcb97d7269..2d7a4a2d95 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -3,4 +3,9 @@ from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_module import YOLO -__all__ = ["components", "FasterRCNN", "YOLOConfiguration", "YOLO"] +__all__ = [ + "components", + "FasterRCNN", + "YOLOConfiguration", + "YOLO", +]