Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge cbcc2b1 into 1bb7d72
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored Feb 9, 2021
2 parents 1bb7d72 + cbcc2b1 commit 13f5106
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 12 deletions.
2 changes: 1 addition & 1 deletion flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
Args:
attr_names: Name(s) of the module attributes of the model to be frozen.
train_bn: Wether to train Batch Norm layer
train_bn: Whether to train Batch Norm layer
"""

Expand Down
2 changes: 1 addition & 1 deletion flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ImageDetector
from flash.vision.detection import ImageDetectionData, ImageDetector
from flash.vision.embedding import ImageEmbedder
1 change: 1 addition & 0 deletions flash/vision/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from flash.vision.detection.data import ImageDetectionData
from flash.vision.detection.model import ImageDetector
201 changes: 201 additions & 0 deletions flash/vision/detection/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, List, Optional, Tuple

import torch
from PIL import Image
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch._six import container_abcs
from torch.utils.data._utils.collate import default_collate
from torchvision import transforms as T

from flash.core.data import TaskDataPipeline
from flash.core.data.datamodule import DataModule
from flash.core.data.utils import _contains_any_tensor
from flash.vision.classification.data import _pil_loader

_COCO_AVAILABLE = _module_available("pycocotools")
if _COCO_AVAILABLE:
from pycocotools.coco import COCO


class CustomCOCODataset(torch.utils.data.Dataset):
    """Object-detection dataset backed by a COCO-format annotation file.

    Each item is an ``(image, target)`` pair. ``target`` is a dict with the keys
    ``boxes`` (rows of ``[xmin, ymin, xmax, ymax]``), ``labels``, ``image_id``,
    ``area`` and ``iscrowd``, all built from the annotations of that image.

    Args:
        root: Directory the ``file_name`` entries of the annotation file are joined to.
        ann_file: Path of the COCO annotation file handed to ``pycocotools``.
        transforms: Optional callable applied to the loaded image only; the target
            dict is returned as built.

    Raises:
        ImportError: If ``pycocotools`` is not installed.
    """

    def __init__(
        self,
        root: str,
        ann_file: str,
        transforms: Optional[Callable] = None,
    ):
        if not _COCO_AVAILABLE:
            raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset")

        self.root = root
        self.transforms = transforms
        self.coco = COCO(ann_file)
        # sorted image ids give a deterministic index -> image mapping
        self.ids = sorted(self.coco.imgs.keys())

    @property
    def num_classes(self) -> int:
        """Largest category id found in the annotation file, plus one.

        Raises:
            ValueError: If the annotation file declares no categories.
        """
        categories = self.coco.loadCats(self.coco.getCatIds())
        if not categories:
            raise ValueError("No Categories found")
        # Do not rely on the categories being listed in ascending id order.
        return max(category["id"] for category in categories) + 1

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image at position ``index`` together with its detection target."""
        coco = self.coco
        img_idx = self.ids[index]

        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_idx))

        image_path = coco.loadImgs(img_idx)[0]["file_name"]
        # ``convert`` forces the pixel data to be read, so the file can be closed right away,
        # and normalises grayscale / palette / RGBA files to three channels.
        with Image.open(os.path.join(self.root, image_path)) as raw_image:
            img = raw_image.convert("RGB")

        boxes = []
        labels = []
        areas = []
        iscrowd = []

        for obj in annotations:
            # COCO stores boxes as [x, y, width, height]
            xmin, ymin, width, height = obj["bbox"][:4]
            xmax = xmin + width
            ymax = ymin + height

            # skip degenerate boxes (zero or negative width / height)
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(obj["category_id"])
                areas.append(obj["area"])
                iscrowd.append(obj["iscrowd"])

        target = {
            # reshape keeps the (N, 4) layout even when no box survived (N == 0)
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([img_idx]),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64),
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self) -> int:
        """Number of images listed in the annotation file."""
        return len(self.ids)


def _coco_remove_images_without_annotations(dataset):
# Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py

def _has_only_empty_bbox(anno: List):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

def _has_valid_annotation(anno: List):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
return True

ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if _has_valid_annotation(anno):
ids.append(ds_idx)

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


# Shared default transform (torchvision ``ToTensor``); used as the default value of the
# transform arguments of ``ImageDetectorDataPipeline`` and ``ImageDetectionData.from_coco``.
_default_transform = T.ToTensor()


class ImageDetectorDataPipeline(TaskDataPipeline):
    """Data pipeline that loads images from paths and groups samples into batches.

    Args:
        valid_transform: Callable applied to every image produced by ``loader``.
        loader: Callable that turns a path string into an image.
    """

    def __init__(self, valid_transform: Optional[Callable] = _default_transform, loader: Callable = _pil_loader):
        self._valid_transform = valid_transform
        self._loader = loader

    def before_collate(self, samples: Any) -> Any:
        """Turn a path, or a list/tuple of paths, into a list of transformed images.

        Input that already contains a tensor is returned unchanged.

        Raises:
            MisconfigurationException: If ``samples`` is neither tensor data nor path(s).
        """
        if _contains_any_tensor(samples):
            return samples

        # a lone path is handled as a one-element list
        paths = [samples] if isinstance(samples, str) else samples

        only_paths = isinstance(paths, (list, tuple)) and all(isinstance(entry, str) for entry in paths)
        if not only_paths:
            raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.")

        return [self._valid_transform(self._loader(path)) for path in paths]

    def collate(self, samples: Any) -> Any:
        """Assemble ``samples`` into a batch.

        A single tensor gains a leading batch dimension. Samples that are themselves
        sequences (e.g. ``(image, target)`` pairs) are transposed into a tuple of
        per-field tuples. Anything else goes through ``default_collate``.
        """
        if isinstance(samples, Tensor):
            return samples.unsqueeze(dim=0)

        first_sample = samples[0]
        if isinstance(first_sample, container_abcs.Sequence):
            return tuple(zip(*samples))
        return default_collate(samples)


class ImageDetectionData(DataModule):
    """Data module for image detection built from COCO-format datasets."""

    @classmethod
    def from_coco(
        cls,
        train_folder: Optional[str] = None,
        train_ann_file: Optional[str] = None,
        train_transform: Optional[Callable] = _default_transform,
        valid_folder: Optional[str] = None,
        valid_ann_file: Optional[str] = None,
        valid_transform: Optional[Callable] = _default_transform,
        test_folder: Optional[str] = None,
        test_ann_file: Optional[str] = None,
        test_transform: Optional[Callable] = _default_transform,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **kwargs
    ):
        """Create the data module from COCO image folders and annotation files.

        Args:
            train_folder: Folder holding the training images.
            train_ann_file: COCO annotation file of the training images.
            train_transform: Transform applied to each training image.
            valid_folder: Folder holding the validation images; no validation set when ``None``.
            valid_ann_file: COCO annotation file of the validation images.
            valid_transform: Transform applied to each validation image.
            test_folder: Folder holding the test images; no test set when ``None``.
            test_ann_file: COCO annotation file of the test images.
            test_transform: Transform applied to each test image.
            batch_size: Forwarded to the data module constructor.
            num_workers: Forwarded to the data module constructor.
            **kwargs: Accepted but not used by this method.

        Returns:
            The data module, with ``num_classes`` and ``data_pipeline`` set on it.
        """
        full_train_ds = CustomCOCODataset(train_folder, train_ann_file, train_transform)
        # read the class count from the full dataset, before it is wrapped by the filter below
        num_classes = full_train_ds.num_classes
        train_ds = _coco_remove_images_without_annotations(full_train_ds)

        valid_ds = None
        if valid_folder is not None:
            valid_ds = CustomCOCODataset(valid_folder, valid_ann_file, valid_transform)

        test_ds = None
        if test_folder is not None:
            test_ds = CustomCOCODataset(test_folder, test_ann_file, test_transform)

        datamodule = cls(
            train_ds=train_ds,
            valid_ds=valid_ds,
            test_ds=test_ds,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        datamodule.num_classes = num_classes
        datamodule.data_pipeline = ImageDetectorDataPipeline()
        return datamodule
29 changes: 29 additions & 0 deletions flash/vision/detection/finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from flash.core.finetuning import FlashBaseFinetuning


class ImageDetectorFineTuning(FlashBaseFinetuning):
    """
    Freezes the backbone during Detector training.

    Args:
        train_bn: Whether to train Batch Norm layer
    """

    def __init__(self, train_bn: bool = True):
        # Run the parent initialiser so that any state it sets up exists on this callback.
        super().__init__(train_bn=train_bn)
        # Kept explicit: ``freeze_before_training`` below reads this attribute directly.
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        """Freeze ``pl_module.model.backbone`` before training starts.

        The backbone is looked up on the wrapped ``model`` attribute rather than on
        ``pl_module`` itself.
        """
        model = pl_module.model
        self.freeze(module=model.backbone, train_bn=self.train_bn)
64 changes: 54 additions & 10 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torchvision
from torch import nn
from torch.optim import Optimizer
from torchvision.ops import box_iou

from flash.core.classification import ClassificationTask
from flash.core import Task
from flash.core.data import DataPipeline
from flash.vision.detection.data import ImageDetectorDataPipeline
from flash.vision.detection.finetuning import ImageDetectorFineTuning

# Maps the model names accepted as a string by ``ImageDetector`` to the torchvision
# constructor used to build them.
_models = {"fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn}


class ImageDetector(ClassificationTask):
def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
return box_iou(target["boxes"], pred["boxes"]).diag().mean()


class ImageDetector(Task):
"""Image detection task
Ref: Lightning Bolts https://github.com/PyTorchLightning/pytorch-lightning-bolts
Args:
num_classes: the number of classes for detection, including background
model: either a string of :attr`_models` or a custom nn.Module.
Expand All @@ -52,17 +68,16 @@ def __init__(
learning_rate=1e-3,
**kwargs,
):

self.save_hyperparameters()

if model in _models:
model = _models[model](pretrained=pretrained)
if isinstance(model, torchvision.models.detection.FasterRCNN):
in_features = model.roi_heads.box_predictor.cls_score.in_features
head = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.box_predictor = head

if loss is None:
# TODO: maybe better way of handling no loss,
loss = {}

super().__init__(
model=model,
loss_fn=loss,
Expand All @@ -81,7 +96,36 @@ def training_step(self, batch, batch_idx) -> Any:
# fasterrcnn takes both images and targets for training, returns loss_dict
loss_dict = self.model(images, targets)
loss = sum(loss_dict.values())
for k, v in loss_dict.items():
self.log("train_k", v)

self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
    """Run the model on one validation batch and return its mean IOU under ``"val_iou"``."""
    images, targets = batch
    # fasterrcnn takes only images for eval() mode
    predictions = self.model(images)
    per_image_iou = [_evaluate_iou(target, prediction) for target, prediction in zip(targets, predictions)]
    return {"val_iou": torch.stack(per_image_iou).mean()}

def validation_epoch_end(self, outs):
    """Average the per-batch ``"val_iou"`` values collected over the validation epoch.

    Returns:
        dict with the average under ``"avg_val_iou"`` and, nested under ``"log"``, as ``"val_iou"``.
    """
    batch_ious = [step_output["val_iou"] for step_output in outs]
    avg_iou = torch.stack(batch_ious).mean()
    # NOTE(review): the nested "log" dict presumably relies on Lightning's dict-return
    # logging convention -- confirm the pinned Lightning version still honours it.
    return {"avg_val_iou": avg_iou, "log": {"val_iou": avg_iou}}

def test_step(self, batch, batch_idx):
    """Run the model on one test batch and return its mean IOU under ``"test_iou"``."""
    images, targets = batch
    # fasterrcnn takes only images for eval() mode
    predictions = self.model(images)
    per_image_iou = [_evaluate_iou(target, prediction) for target, prediction in zip(targets, predictions)]
    return {"test_iou": torch.stack(per_image_iou).mean()}

def test_epoch_end(self, outs):
avg_iou = torch.stack([o["test_iou"] for o in outs]).mean()
logs = {"test_iou": avg_iou}
return {"avg_test_iou": avg_iou, "log": logs}

@staticmethod
def default_pipeline() -> ImageDetectorDataPipeline:
    """Return a new ``ImageDetectorDataPipeline`` built with its default arguments."""
    pipeline = ImageDetectorDataPipeline()
    return pipeline

def configure_finetune_callback(self):
    """Return the finetuning callbacks for this task: one ``ImageDetectorFineTuning`` with ``train_bn=True``."""
    backbone_freezer = ImageDetectorFineTuning(train_bn=True)
    return [backbone_freezer]
39 changes: 39 additions & 0 deletions flash_examples/finetuning/image_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.core.data import download_data
from flash.vision import ImageDetectionData, ImageDetector

# 1. Download the data and unpack it under data/
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
COCO128_URL = "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip"
download_data(COCO128_URL, "data/")

# 2. Load the data from the COCO-format image folder and annotation file
TRAIN_IMAGES = "data/coco128/images/train2017/"
TRAIN_ANNOTATIONS = "data/coco128/annotations/instances_train2017.json"
datamodule = ImageDetectionData.from_coco(
    train_folder=TRAIN_IMAGES,
    train_ann_file=TRAIN_ANNOTATIONS,
    batch_size=2,
)

# 3. Build the model, sized to the number of classes found in the annotations
model = ImageDetector(num_classes=datamodule.num_classes)

# 4. Create the trainer; max_epochs=2 means two passes over the data
trainer = flash.Trainer(max_epochs=2)

# 5. Finetune the model
trainer.finetune(model, datamodule)

# 6. Save the finetuned model as a checkpoint
trainer.save_checkpoint("image_detection_model.pt")
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ tqdm # comes with 3rd-party dependency
rouge-score>=0.0.4
sentencepiece>=0.1.95
pytorch-lightning-bolts==0.3.0
filelock # comes with 3rd-party dependency
pycocotools>=2.0.2 ; python_version >= "3.7"
Loading

0 comments on commit 13f5106

Please sign in to comment.