Skip to content

Commit

Permalink
Add YOLO object detection model (#552)
Browse files Browse the repository at this point in the history
* Add YOLO object detection model
* Reading Darknet weights works also with truncated files.
* Use torch.min() instead of torch.minimum() to avoid error with older PyTorch versions.
* Generalized interface for custom losses
* IoU loss functions take image space coordinates as input.
* box_area() implementation copied from torchvision
* IoU losses use torchvision
* IoU losses take the diagnoal of torchvision iou ops instead of implementing their own elementwise ops.
* YOLO written with all caps in class names
* Generic way to specify optimizer and LR scheduler
* Possible to limit the number of predictions per image
* No need to check for NaN values as Trainer has terminate_on_nan argument.
* YOLO test configuration moved to tests/data/yolo.cfg
* Synchronize validation and test step logging calls
* Log losses to progress bar
* Fixed documentation formatting
* Coordinate predictions are in image scale
* Use default dtype for torch.arange() to fix export to TensorRT
* Network input size can differ from the image size specified in the configuration
* Image size is given to detection layer forward() instead of the constructor to allow variable image sizes.
* Use default data type for torch.arange() to fix export to TensorRT.
* Loss is normalized by batch size only once
* Fixed division by zero when there are no targets in a batch
* Always return all losses to avoid deadlock with DDP when there are no targets
* Hit rates are always logged so don't prefix the names
* The vector of overlap losses was accidentally transformed to a square matrix.
* Some versions of Lightning don't work correctly when logging losses with sync_dist=True.
* Truncate nms() inputs to avoid it crashing when too many boxes are detected
* Use sum() instead of count_nonzero() as it's available already before PyTorch 1.7
* Squared error loss takes the sum over the predicted attributes
* Swish and logistic activation functions
* VOCDetectionDataModule constructor takes batch size and the transforms take image and target
* Use true_divide() for integer division
* Fixed doc and package build without Torchvision

Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
senarvi and akihironitta authored Sep 10, 2021
1 parent 031e880 commit b9b35c4
Show file tree
Hide file tree
Showing 12 changed files with 1,568 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552))

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ Lightning-Bolts documentation

autoencoders
convolutional
object_detection
gans
reinforce_learn
self_supervised_models
Expand Down
20 changes: 20 additions & 0 deletions docs/source/object_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Object Detection
================
This package lists contributed object detection models.

--------------


Faster R-CNN
------------

.. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN
:noindex:

-------------

YOLO
----

.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO
:noindex:
77 changes: 35 additions & 42 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down Expand Up @@ -107,10 +108,11 @@ class VOCDetectionDataModule(LightningDataModule):

def __init__(
self,
data_dir: str,
data_dir: Optional[str] = None,
year: str = "2012",
num_workers: int = 0,
normalize: bool = False,
batch_size: int = 16,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
Expand All @@ -125,9 +127,10 @@ def __init__(
super().__init__(*args, **kwargs)

self.year = year
self.data_dir = data_dir
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.num_workers = num_workers
self.normalize = normalize
self.batch_size = batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
Expand All @@ -145,60 +148,50 @@ def prepare_data(self) -> None:
VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)

def train_dataloader(
self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None
) -> DataLoader:
def train_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader:
"""VOCDetection train set uses the `train` subset.
Args:
batch_size: size of batch
transforms: custom transforms
image_transforms: custom image-only transforms
"""
transforms = [_prepare_voc_instance]
image_transforms = image_transforms or self.train_transforms or self._default_transforms()
transforms = [
_prepare_voc_instance,
self.default_transforms() if self.train_transforms is None else self.train_transforms,
]
transforms = Compose(transforms, image_transforms)

dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
collate_fn=_collate_fn,
)
return loader
return self._data_loader(dataset, shuffle=self.shuffle)

def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader:
def val_dataloader(self, image_transforms: Optional[Callable] = None) -> DataLoader:
"""VOCDetection val set uses the `val` subset.
Args:
batch_size: size of batch
transforms: custom transforms
image_transforms: custom image-only transforms
"""
transforms = [_prepare_voc_instance]
image_transforms = image_transforms or self.train_transforms or self._default_transforms()
transforms = [
_prepare_voc_instance,
self.default_transforms() if self.val_transforms is None else self.val_transforms,
]
transforms = Compose(transforms, image_transforms)

dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms)
loader = DataLoader(
return self._data_loader(dataset, shuffle=False)

def default_transforms(self) -> Callable:
voc_transforms = [transform_lib.ToTensor()]
if self.normalize:
voc_transforms += [transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
voc_transforms = transform_lib.Compose(voc_transforms)
return lambda image, target: (voc_transforms(image), target)

def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
batch_size=self.batch_size,
shuffle=shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
collate_fn=_collate_fn,
)
return loader

def _default_transforms(self) -> Callable:
if self.normalize:
voc_transforms = transform_lib.Compose(
[
transform_lib.ToTensor(),
transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
else:
voc_transforms = transform_lib.Compose([transform_lib.ToTensor()])
return voc_transforms
4 changes: 4 additions & 0 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from pl_bolts.models.detection import components
from pl_bolts.models.detection.faster_rcnn import FasterRCNN
from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration
from pl_bolts.models.detection.yolo.yolo_module import YOLO

__all__ = [
"components",
"FasterRCNN",
"YOLOConfiguration",
"YOLO",
]
3 changes: 1 addition & 2 deletions pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,8 @@ def run_cli():

seed_everything(42)
parser = ArgumentParser()
parser = VOCDetectionDataModule.add_argparse_args(parser)
parser = Trainer.add_argparse_args(parser)
parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--batch_size", type=int, default=1)
parser = FasterRCNN.add_model_specific_args(parser)

args = parser.parse_args()
Expand Down
Empty file.
Loading

0 comments on commit b9b35c4

Please sign in to comment.