-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add YOLO object detection model (#552)
* Add YOLO object detection model * Reading Darknet weights works also with truncated files. * Use torch.min() instead of torch.minimum() to avoid error with older PyTorch versions. * Generalized interface for custom losses * IoU loss functions take image space coordinates as input. * box_area() implementation copied from torchvision * IoU losses use torchvision * IoU losses take the diagnoal of torchvision iou ops instead of implementing their own elementwise ops. * YOLO written with all caps in class names * Generic way to specify optimizer and LR scheduler * Possible to limit the number of predictions per image * No need to check for NaN values as Trainer has terminate_on_nan argument. * YOLO test configuration moved to tests/data/yolo.cfg * Synchronize validation and test step logging calls * Log losses to progress bar * Fixed documentation formatting * Coordinate predictions are in image scale * Use default dtype for torch.arange() to fix export to TensorRT * Network input size can differ from the image size specified in the configuration * Image size is given to detection layer forward() instead of the constructor to allow variable image sizes. * Use default data type for torch.arange() to fix export to TensorRT. * Loss is normalized by batch size only once * Fixed division by zero when there are no targets in a batch * Always return all losses to avoid deadlock with DDP when there are no targets * Hit rates are always logged so don't prefix the names * The vector of overlap losses was accidentally transformed to a square matrix. * Some versions of Lightning don't work correctly when logging losses with sync_dist=True. * Truncate nms() inputs to avoid it crashing when too many boxes are detected * Use sum() instead of count_nonzero() as it's available already before PyTorch 1.7 * Squared error loss takes the sum over the predicted attributes * Swish and logistic activation functions * VOCDetectionDataModule constructor takes batch size and the transforms take image and target * Use true_divide() for integer division * Fixed doc and package build without Torchvision Co-authored-by: Akihiro Nitta <[email protected]>
- Loading branch information
1 parent
031e880
commit b9b35c4
Showing
12 changed files
with
1,568 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
Object Detection | ||
================ | ||
This package lists contributed object detection models. | ||
|
||
-------------- | ||
|
||
|
||
Faster R-CNN | ||
------------ | ||
|
||
.. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN | ||
:noindex: | ||
|
||
------------- | ||
|
||
YOLO | ||
---- | ||
|
||
.. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
from pl_bolts.models.detection import components | ||
from pl_bolts.models.detection.faster_rcnn import FasterRCNN | ||
from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration | ||
from pl_bolts.models.detection.yolo.yolo_module import YOLO | ||
|
||
__all__ = [ | ||
"components", | ||
"FasterRCNN", | ||
"YOLOConfiguration", | ||
"YOLO", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.