Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

feat: Add Retinanet and backbones for detection #121

Merged
merged 20 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Any, Optional, Tuple

import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV
Expand Down Expand Up @@ -109,3 +110,23 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
return backbone, num_features

raise ValueError(f"{model_name} is not supported yet.")


def fetch_fasterrcnn_backbone_and_num_features(
backbone: str,
fpn: bool = True,
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
pretrained: bool = True,
trainable_backbone_layers: int = 3,
**kwargs: Any
) -> nn.Module:
if fpn:
if backbone in RESNET_MODELS:
backbone = resnet_fpn_backbone(
backbone, pretrained=pretrained, trainable_layers=trainable_backbone_layers, **kwargs
)
fpn_out_channels = 256
return backbone, fpn_out_channels
else:
rank_zero_warn(f"{backbone} backbone is not supported with `fpn=True`, `fpn` won't be added.")
backbone, num_features = backbone_and_num_features(backbone, pretrained)
return backbone, num_features
58 changes: 41 additions & 17 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torchvision
from torch import nn
from torch.optim import Optimizer
from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import box_iou

from flash.core import Task
from flash.vision.backbones import fetch_fasterrcnn_backbone_and_num_features
from flash.vision.detection.data import ObjectDetectionDataPipeline
from flash.vision.detection.finetuning import ObjectDetectionFineTuning

_models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn}


def _evaluate_iou(target, pred):
"""
Expand All @@ -37,14 +39,17 @@ def _evaluate_iou(target, pred):


class ObjectDetector(Task):
"""Image detection task
"""Object detection task

Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts

Args:
num_classes: the number of classes for detection, including background
model: either a string of :attr`_models` or a custom nn.Module.
Defaults to 'fasterrcnn_resnet50_fpn'.
backbone: Pretained backbone CNN architecture.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
loss: the function(s) to update the model with. Has no effect for torchvision detection models.
metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger.
optimizer: The optimizer to use for training. Can either be the actual class or the class name.
Expand All @@ -57,25 +62,44 @@ class ObjectDetector(Task):
def __init__(
self,
num_classes: int,
model: Union[str, nn.Module] = "fasterrcnn_resnet50_fpn",
backbone: Optional[str] = None,
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
fpn: bool = True,
pretrained: bool = True,
pretrained_backbone: bool = True,
trainable_backbone_layers: int = 3,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic could/should be within FinetuningCallback.

If the user requires model= fasterrcnn, then it should choose the FasterRCNNFinetuning Callback.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the models will have the same FineTuningCallback as they would have similar backbones, but different heads. But yes, could think of moving trainable_backbone_layers for FineTuningCallback OR we could offer some options for finetuning functionalities to the User and it would override the trainable_backbone_layers.

loss=None,
metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None,
optimizer: Type[Optimizer] = torch.optim.Adam,
pretrained: bool = True,
learning_rate=1e-3,
**kwargs,
learning_rate: float = 1e-3,
**kwargs: Any,
):

self.save_hyperparameters()

if model in _models:
model = _models[model](pretrained=pretrained)
if isinstance(model, torchvision.models.detection.FasterRCNN):
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
if backbone is None:
model = fasterrcnn_resnet50_fpn(
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head
else:
ValueError(f"{model} is not supported yet.")
backbone_model, num_features = fetch_fasterrcnn_backbone_and_num_features(
backbone,
fpn,
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
)
backbone_model.out_channels = num_features
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512), ),
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
aspect_ratios=((0.5, 1.0,
2.0), )) if not hasattr(backbone_model, "fpn") else None
model = torchvision_FasterRCNN(
backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator, **kwargs
)

super().__init__(
model=model,
Expand Down
5 changes: 3 additions & 2 deletions tests/vision/detection/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@


@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
def test_detection(tmpdir):
@pytest.mark.parametrize("backbone", [None, "resnet34", "mobilenet_v2", "simclr-imagenet"])
def test_detection(tmpdir, backbone):

train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)

data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
model = ObjectDetector(num_classes=data.num_classes)
model = ObjectDetector(backbone=backbone, num_classes=data.num_classes)

trainer = flash.Trainer(fast_dev_run=True)

Expand Down
2 changes: 1 addition & 1 deletion tests/vision/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_init():


def test_training(tmpdir):
model = ObjectDetector(num_classes=2, model="fasterrcnn_resnet50_fpn")
model = ObjectDetector(num_classes=2)
ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
dl = DataLoader(ds, collate_fn=collate_fn)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down
15 changes: 15 additions & 0 deletions tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from flash.vision.backbones import fetch_fasterrcnn_backbone_and_num_features


@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280),
("simclr-imagenet", 2048)])
def test_fetch_fasterrcnn_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = fetch_fasterrcnn_backbone_and_num_features(
backbone=backbone, pretrained=False, fpn=False
)

assert backbone_model
assert num_features == expected_num_features