Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add YOLO object detection model #552

Merged
merged 86 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
461de96
Add YOLO object detection model
senarvi Feb 2, 2021
2b9b073
Readability improvements
senarvi Feb 3, 2021
cc42540
Documentation improvements
senarvi Feb 3, 2021
876da0d
Fixed style issues.
senarvi Feb 3, 2021
f99930e
Refactoring
senarvi Feb 5, 2021
4415d41
Refactoring
senarvi Feb 8, 2021
7356fe0
Refactoring
senarvi Feb 8, 2021
39eb80d
Fixed YOLO test.
senarvi Feb 8, 2021
291f4be
Fixedd style issues
senarvi Feb 9, 2021
8db7947
Comply to isort rules.
senarvi Feb 9, 2021
2831755
Reading Darknet weights works also with truncated files.
senarvi Feb 9, 2021
eb26eba
Fixed code formatting.
senarvi Feb 9, 2021
9c155a9
Trying to fix Python 3.6 import problem.
senarvi Feb 9, 2021
efeb1c8
Fixed Python 3.6 import error.
senarvi Feb 9, 2021
1a1ecd3
Added YOLO to CHANGELOG.
senarvi Feb 9, 2021
26ff979
Use torch.min() instead of torch.minimum() to avoid error with older …
senarvi Feb 9, 2021
3e9bdde
Generalized interface for custom losses
senarvi Feb 12, 2021
c348619
box_area() implementation copied from torchvision
senarvi Feb 12, 2021
c2d7907
Confirm to yapf formatter rules.
senarvi Feb 12, 2021
3d7f440
Removed the unnecessary linter instructions.
senarvi Feb 15, 2021
eb6be46
IoU losses use torchvision
senarvi Feb 15, 2021
3d940f2
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Feb 15, 2021
cf7420c
Improved strange yapf formatting
senarvi Feb 15, 2021
6c90cd4
Refactoring
senarvi Feb 15, 2021
7aeea63
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Feb 15, 2021
60fda75
get_deprecated_arg_names() is not needed anymore.
senarvi Feb 15, 2021
b2e3e84
Fixed yapf formatting.
senarvi Feb 15, 2021
940947f
Fixed formatting.
senarvi Feb 15, 2021
6e3d5bf
Removed unused imports.
senarvi Feb 15, 2021
a5bed26
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Feb 15, 2021
b2ea497
Fixed some type hints.
senarvi Feb 16, 2021
6d8fa7d
Sorted imports.
senarvi Feb 16, 2021
e68df7a
Possible to limit the number of predictions per image
senarvi Feb 24, 2021
f895530
None instead of an empty list as default argument
senarvi Feb 24, 2021
58f1456
Fixed capitalization of YOLO class.
senarvi Feb 24, 2021
19b9df7
Merge branch 'origin/master' into yolo
senarvi Mar 4, 2021
4e6d4cf
No need to check for NaN values as Trainer has terminate_on_nan argum…
senarvi Mar 8, 2021
af3e0e6
YOLO test configuration moved to tests/data/yolo.cfg
senarvi Mar 8, 2021
c4ae5ec
Merge branch 'origin/master' into yolo
senarvi Mar 8, 2021
4012247
Use Optional[] as the default value for transforms is now None
senarvi Mar 8, 2021
c8b76a5
Refactoring and documentation improvements
senarvi Mar 23, 2021
71a4c3c
Fixed documentation formatting
senarvi Mar 24, 2021
da4eace
Merge branch 'origin/master' into yolo
senarvi Mar 24, 2021
70f14b0
Coordinate predictions are in image scale
senarvi Mar 31, 2021
3c1a0fb
Merge branch 'origin/master' into yolo
senarvi Mar 31, 2021
9587b46
Use default dtype for torch.arange() to fix export to TensorRT
senarvi Apr 1, 2021
2b6c552
Network input size can differ from the image size specified in the co…
senarvi Apr 10, 2021
e97b198
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Apr 10, 2021
dbe1d59
Merge branch 'origin/master' into yolo
senarvi Apr 10, 2021
b66dbd3
Merge branch 'master' into yolo
senarvi May 3, 2021
004d1ce
Use torch.true_divide() instead of /
senarvi May 3, 2021
ad1e48e
Use torch.true_divide() instead of /
senarvi May 5, 2021
a254517
Merge branch 'master' into yolo
senarvi May 13, 2021
1cde4f8
Merge branch 'origin/master' into yolo
senarvi Jun 1, 2021
e92c405
Merge branch 'master' into yolo
senarvi Jun 17, 2021
c237b37
Loss is normalized by batch size only once
senarvi Jun 23, 2021
9b010de
Fixed division by zero when there are no targets in a batch
senarvi Jun 23, 2021
dc7ae4c
Merge branch 'master' into yolo
senarvi Jun 23, 2021
f15282d
Always return all losses to avoid deadlock with DDP when there are no…
senarvi Jun 24, 2021
de52b75
Merge branch 'master' into yolo
senarvi Jun 24, 2021
f6d3476
Hit rates are always logged so don't prefix the names
senarvi Jul 1, 2021
8e12359
Merge branch 'master' into yolo
senarvi Jul 1, 2021
86a6b66
Fixed training loss
senarvi Jul 31, 2021
7fd38ca
Merge branch 'origin/master' into yolo
senarvi Jul 31, 2021
3286533
Truncate nms() inputs to avoid it crashing when too many boxes are de…
senarvi Aug 4, 2021
bb92076
Use sum() instead of count_nonzero() as it's available already before…
senarvi Aug 11, 2021
7d08350
Merge branch 'master' into yolo
senarvi Aug 11, 2021
b896112
Squared error loss takes the sum over the predicted attributes
senarvi Aug 17, 2021
55a1180
Swish and logistic activation functions
senarvi Aug 17, 2021
6fb82c1
Merge branch 'master' into yolo
senarvi Aug 17, 2021
7857bea
Added a comment
senarvi Aug 18, 2021
0804699
Fixed code formatting
senarvi Aug 18, 2021
7b32d64
Ran docformatter with correct config
senarvi Aug 19, 2021
7b316b8
Merge branch 'master' into yolo
senarvi Aug 19, 2021
191e865
Ran pyupgrade
senarvi Aug 19, 2021
66c8111
Reformatted
senarvi Aug 19, 2021
572ca4f
VOCDetectionDataModule constructor takes batch size and the transform…
senarvi Aug 30, 2021
198dd32
Merge branch 'master' into yolo
senarvi Aug 30, 2021
db278be
Use true_divide() for integer division
senarvi Aug 30, 2021
4731b24
Fixed doc and package build without Torchvision
senarvi Aug 30, 2021
f63df0b
Merge branch 'master' into yolo
senarvi Aug 31, 2021
3b70b77
Merge branch 'master' into yolo
senarvi Sep 8, 2021
022192a
YOLO moved to unreleased
senarvi Sep 10, 2021
a8e1ce3
Merge branch 'master' into yolo
senarvi Sep 10, 2021
bfc774c
Code formatting
senarvi Sep 10, 2021
8d4e43e
Merge branch 'yolo' of github.com:groke-technologies/pytorch-lightnin…
senarvi Sep 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#323](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/323))
- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))
- Added DCGAN module ([#403](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/403))
- Added YOLO module ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552))
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))
- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ PyTorch-Lightning-Bolts documentation

autoencoders
convolutional
object_detection
gans
reinforce_learn
self_supervised_models
Expand Down
20 changes: 20 additions & 0 deletions docs/source/object_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Object Detection
================
This package lists contributed object detection models.

--------------


Faster R-CNN
------------

.. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN
:noindex:

-------------

YOLO
----

.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.Yolo
:noindex:
24 changes: 17 additions & 7 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from pytorch_lightning import LightningDataModule
Expand Down Expand Up @@ -152,16 +152,20 @@ def prepare_data(self) -> None:
VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)

def train_dataloader(
self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None
self,
batch_size: int = 1,
transforms: List[Callable] = [],
senarvi marked this conversation as resolved.
Show resolved Hide resolved
image_transforms: Optional[Callable] = None
) -> DataLoader:
"""
VOCDetection train set uses the `train` subset

Args:
batch_size: size of batch
transforms: custom transforms
transforms: custom transforms for image and target
image_transforms: custom image-only transforms
"""
transforms = [_prepare_voc_instance]
transforms = [_prepare_voc_instance] + transforms
image_transforms = image_transforms or self.train_transforms or self._default_transforms()
transforms = Compose(transforms, image_transforms)
dataset = VOCDetection(self.data_dir, year=self.year, image_set="train", transforms=transforms)
Expand All @@ -176,15 +180,21 @@ def train_dataloader(
)
return loader

def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader:
def val_dataloader(
self,
batch_size: int = 1,
transforms: List[Callable] = [],
image_transforms: Optional[Callable] = None
senarvi marked this conversation as resolved.
Show resolved Hide resolved
) -> DataLoader:
"""
VOCDetection val set uses the `val` subset

Args:
batch_size: size of batch
transforms: custom transforms
transforms: custom transforms for image and target
image_transforms: custom image-only transforms
"""
transforms = [_prepare_voc_instance]
transforms = [_prepare_voc_instance] + transforms
image_transforms = image_transforms or self.train_transforms or self._default_transforms()
transforms = Compose(transforms, image_transforms)
dataset = VOCDetection(self.data_dir, year=self.year, image_set="val", transforms=transforms)
Expand Down
10 changes: 4 additions & 6 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401
from pl_bolts.models.detection import components
from pl_bolts.models.detection.faster_rcnn import FasterRCNN
from pl_bolts.models.detection.yolo import Yolo, YoloConfiguration

__all__ = [
"components",
"FasterRCNN",
]
__all__ = ["components", "FasterRCNN", "YoloConfiguration", "Yolo"]
4 changes: 4 additions & 0 deletions pl_bolts/models/detection/yolo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from pl_bolts.models.detection.yolo.yolo_config import YoloConfiguration
from pl_bolts.models.detection.yolo.yolo_module import Yolo

__all__ = ["YoloConfiguration", "Yolo"]
senarvi marked this conversation as resolved.
Show resolved Hide resolved
265 changes: 265 additions & 0 deletions pl_bolts/models/detection/yolo/yolo_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
import re
from typing import List, Tuple
from warnings import warn

import torch.nn as nn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.models.detection.yolo import yolo_layers


class YoloConfiguration:
senarvi marked this conversation as resolved.
Show resolved Hide resolved
"""
This class can be used to parse the configuration files of the Darknet YOLOv4 implementation.
The `get_network()` method returns a PyTorch module list that can be used to construct a YOLO
senarvi marked this conversation as resolved.
Show resolved Hide resolved
model.
"""

def __init__(self, path: str):
"""
Saves the variables from the first configuration section to attributes of this object, and
the rest of the sections to the `layer_configs` list.

Args:
path: Path to a configuration file
"""
with open(path, 'r') as config_file:
sections = self._read_file(config_file)

if len(sections) < 2:
raise MisconfigurationException("The model configuration file should include at least two sections.")

self.__dict__.update(sections[0])
self.global_config = sections[0]
self.layer_configs = sections[1:]

def get_network(self) -> nn.ModuleList:
"""
Iterates through the layers from the configuration and creates corresponding PyTorch
modules. Returns the network structure that can be used to create a YOLO model.

Returns:
modules: A `nn.ModuleList` that defines the YOLO network.
"""
result = nn.ModuleList()
num_inputs = [3] # Number of channels in the input of every layer up to the current layer
for layer_config in self.layer_configs:
config = {**self.global_config, **layer_config}
module, num_outputs = _create_layer(config, num_inputs)
result.append(module)
num_inputs.append(num_outputs)
return result

def _read_file(self, config_file):
"""
Reads a YOLOv4 network configuration file and returns a list of configuration sections.

Args:
config_file (iterable over lines): The configuration file to read.

Returns:
sections (List[dict]): A list of configuration sections.
"""
section_re = re.compile(r'\[([^]]+)\]')
list_variables = ('layers', 'anchors', 'mask', 'scales')
variable_types = {
'activation': str,
'anchors': int,
'angle': float,
'batch': int,
'batch_normalize': bool,
'beta_nms': float,
'burn_in': int,
'channels': int,
'classes': int,
'cls_normalizer': float,
'decay': float,
'exposure': float,
'filters': int,
'from': int,
'groups': int,
'group_id': int,
'height': int,
'hue': float,
'ignore_thresh': float,
'iou_loss': str,
'iou_normalizer': float,
'iou_thresh': float,
'jitter': float,
'layers': int,
'learning_rate': float,
'mask': int,
'max_batches': int,
'max_delta': float,
'momentum': float,
'mosaic': bool,
'nms_kind': str,
'num': int,
'obj_normalizer': float,
'pad': bool,
'policy': str,
'random': bool,
'resize': float,
'saturation': float,
'scales': float,
'scale_x_y': float,
'size': int,
'steps': str,
'stride': int,
'subdivisions': int,
'truth_thresh': float,
'width': int
}

section = None
sections = []

def convert(key, value):
"""Converts a value to the correct type based on key."""
if key not in variable_types:
warn('Unknown YOLO configuration variable: ' + key)
return key, value
if key in list_variables:
value = [variable_types[key](v) for v in value.split(',')]
else:
value = variable_types[key](value)
return key, value

for line in config_file:
line = line.strip()
if (not line) or (line[0] == '#'):
continue

section_match = section_re.match(line)
if section_match:
if section is not None:
sections.append(section)
section = {'type': section_match.group(1)}
else:
key, value = line.split('=')
key = key.rstrip()
value = value.lstrip()
key, value = convert(key, value)
section[key] = value
if section is not None:
sections.append(section)

return sections


def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]:
"""
Calls one of the `_create_<layertype>(config, num_inputs)` functions to create a PyTorch
module from the layer config.

Args:
config: Dictionary of configuration options for this layer.
num_inputs: Number of channels in the input of every layer up to this layer.

Returns:
module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the
number of channels in its output.
"""
create_func = {
'convolutional': _create_convolutional,
'maxpool': _create_maxpool,
'route': _create_route,
'shortcut': _create_shortcut,
'upsample': _create_upsample,
'yolo': _create_yolo
}
return create_func[config['type']](config, num_inputs)


def _create_convolutional(config, num_inputs):
module = nn.Sequential()

batch_normalize = config.get('batch_normalize', False)
padding = (config['size'] - 1) // 2 if config['pad'] else 0

conv = nn.Conv2d(
num_inputs[-1], config['filters'], config['size'], config['stride'], padding, bias=not batch_normalize
)
module.add_module('conv', conv)

if batch_normalize:
bn = nn.BatchNorm2d(config['filters'])
module.add_module('bn', bn)

if config['activation'] == 'leaky':
leakyrelu = nn.LeakyReLU(0.1, inplace=True)
module.add_module('leakyrelu', leakyrelu)
elif config['activation'] == 'mish':
mish = yolo_layers.Mish()
module.add_module('mish', mish)

return module, config['filters']


def _create_maxpool(config, num_inputs):
padding = (config['size'] - 1) // 2
module = nn.MaxPool2d(config['size'], config['stride'], padding)
return module, num_inputs[-1]


def _create_route(config, num_inputs):
num_chunks = config.get('groups', 1)
chunk_idx = config.get('group_id', 0)

# 0 is the first layer, -1 is the previous layer
last = len(num_inputs) - 1
source_layers = [layer if layer >= 0 else last + layer for layer in config['layers']]

module = yolo_layers.RouteLayer(source_layers, num_chunks, chunk_idx)

# The number of outputs of a source layer is the number of inputs of the next layer.
num_outputs = sum(num_inputs[layer + 1] // num_chunks for layer in source_layers)

return module, num_outputs


def _create_shortcut(config, num_inputs):
module = yolo_layers.ShortcutLayer(config['from'])
return module, num_inputs[-1]


def _create_upsample(config, num_inputs):
module = nn.Upsample(scale_factor=config["stride"], mode='nearest')
return module, num_inputs[-1]


def _create_yolo(config, num_inputs):
# The "anchors" list alternates width and height.
anchor_dims = config['anchors']
anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)]

xy_scale = config.get('scale_x_y', 1.0)
ignore_threshold = config.get('ignore_thresh', 1.0)
overlap_loss_multiplier = config.get('iou_normalizer', 1.0)
class_loss_multiplier = config.get('cls_normalizer', 1.0)
confidence_loss_multiplier = config.get('obj_normalizer', 1.0)

overlap_loss_name = config.get('iou_loss', 'mse')
if overlap_loss_name == 'mse':
overlap_loss_func = nn.MSELoss(reduction='none')
elif overlap_loss_name == 'giou':
overlap_loss_func = yolo_layers.GIoULoss()
else:
overlap_loss_func = yolo_layers.IoULoss()

module = yolo_layers.DetectionLayer(
num_classes=config['classes'],
image_width=config['width'],
image_height=config['height'],
anchor_dims=anchor_dims,
anchor_ids=config['mask'],
xy_scale=xy_scale,
ignore_threshold=ignore_threshold,
overlap_loss_func=overlap_loss_func,
image_space_loss=overlap_loss_name != 'mse',
overlap_loss_multiplier=overlap_loss_multiplier,
class_loss_multiplier=class_loss_multiplier,
confidence_loss_multiplier=confidence_loss_multiplier
)

return module, num_inputs[-1]
Loading