Commit message:

* add CustomCOCODataset for detection
* add steps for the detection task
* update steps for detector task
* add data pipeline
* add import error for coco api
* add ref to bolts
* add base finetuning
* add dataset
* add fine tuning script
* update fine tuning script
* add image detector data module
* handle images with no annotations
* add test step
* fix crowd coco assign
* add test for COCO dataloader
* update example format
* update test for COCO dataloader
* add pycoco to requirements
* add pycoco to requirements with python version
* skip test if coco not installed
* add test for data model integration
* skip integration test if coco not installed
* add model predict in integration tests
* update dummy image shapes
* add save hyperparameters
* fix labels assignment
* add singular data pipeline
* add updates for predict
* add cython to requirements
* fix failing tests
* use download_data with coco128 url
* update imports
* handle degenerated boxes
* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 7acd8b5 · commit c60fef8 · Showing 10 changed files with 504 additions and 12 deletions.
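At a glance, the workflow this commit enables looks roughly as follows. This is a condensed sketch of the finetuning example script added by the commit itself (shown in full as the last file in the diff below), so the names, URL, and paths come straight from that script:

    import flash
    from flash.core.data import download_data
    from flash.vision import ImageDetectionData, ImageDetector

    # COCO128 sample dataset (credit: https://www.kaggle.com/ultralytics/coco128)
    download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

    # Build a DataModule from a COCO-format image folder + annotation file
    datamodule = ImageDetectionData.from_coco(
        train_folder="data/coco128/images/train2017/",
        train_ann_file="data/coco128/annotations/instances_train2017.json",
        batch_size=2,
    )

    # Detector with the class count inferred from the annotations
    model = ImageDetector(num_classes=datamodule.num_classes)

    # Finetune and save
    flash.Trainer(max_epochs=2).finetune(model, datamodule)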
@@ -1,3 +1,3 @@
 from flash.vision.classification import ImageClassificationData, ImageClassifier
-from flash.vision.detection import ImageDetector
+from flash.vision.detection import ImageDetectionData, ImageDetector
 from flash.vision.embedding import ImageEmbedder
@@ -1 +1,2 @@
+from flash.vision.detection.data import ImageDetectionData
 from flash.vision.detection.model import ImageDetector
@@ -0,0 +1,201 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, List, Optional, Tuple

import torch
from PIL import Image
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch._six import container_abcs
from torch.utils.data._utils.collate import default_collate
from torchvision import transforms as T

from flash.core.data import TaskDataPipeline
from flash.core.data.datamodule import DataModule
from flash.core.data.utils import _contains_any_tensor
from flash.vision.classification.data import _pil_loader

_COCO_AVAILABLE = _module_available("pycocotools")
if _COCO_AVAILABLE:
    from pycocotools.coco import COCO


class CustomCOCODataset(torch.utils.data.Dataset):

    def __init__(
        self,
        root: str,
        ann_file: str,
        transforms: Optional[Callable] = None,
    ):
        if not _COCO_AVAILABLE:
            raise ImportError("Kindly install the COCO API `pycocotools` to use the Dataset")

        self.root = root
        self.transforms = transforms
        self.coco = COCO(ann_file)
        self.ids = list(sorted(self.coco.imgs.keys()))

    @property
    def num_classes(self):
        categories = self.coco.loadCats(self.coco.getCatIds())
        if not categories:
            raise ValueError("No Categories found")
        return categories[-1]["id"] + 1

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        coco = self.coco
        img_idx = self.ids[index]

        ann_ids = coco.getAnnIds(imgIds=img_idx)
        annotations = coco.loadAnns(ann_ids)

        image_path = coco.loadImgs(img_idx)[0]["file_name"]
        img = Image.open(os.path.join(self.root, image_path))

        boxes = []
        labels = []
        areas = []
        iscrowd = []

        for obj in annotations:
            xmin = obj["bbox"][0]
            ymin = obj["bbox"][1]
            xmax = xmin + obj["bbox"][2]
            ymax = ymin + obj["bbox"][3]

            bbox = [xmin, ymin, xmax, ymax]
            keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0])
            if keep:
                boxes.append(bbox)
                labels.append(obj["category_id"])
                areas.append(obj["area"])
                iscrowd.append(obj["iscrowd"])

        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
        target["image_id"] = torch.tensor([img_idx])
        target["area"] = torch.as_tensor(areas, dtype=torch.float32)
        target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.int64)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.ids)


def _coco_remove_images_without_annotations(dataset):
    # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py

    def _has_only_empty_bbox(anno: List):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    def _has_valid_annotation(anno: List):
        # if it's empty, there is no annotation
        if not anno:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False
        return True

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


_default_transform = T.ToTensor()


class ImageDetectorDataPipeline(TaskDataPipeline):

    def __init__(self, valid_transform: Optional[Callable] = _default_transform, loader: Callable = _pil_loader):
        self._valid_transform = valid_transform
        self._loader = loader

    def before_collate(self, samples: Any) -> Any:
        if _contains_any_tensor(samples):
            return samples

        if isinstance(samples, str):
            samples = [samples]

        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
            outputs = []
            for sample in samples:
                output = self._loader(sample)
                outputs.append(self._valid_transform(output))
            return outputs
        raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.")

    def collate(self, samples: Any) -> Any:
        if not isinstance(samples, Tensor):
            elem = samples[0]
            if isinstance(elem, container_abcs.Sequence):
                return tuple(zip(*samples))
            return default_collate(samples)
        return samples.unsqueeze(dim=0)


class ImageDetectionData(DataModule):

    @classmethod
    def from_coco(
        cls,
        train_folder: Optional[str] = None,
        train_ann_file: Optional[str] = None,
        train_transform: Optional[Callable] = _default_transform,
        valid_folder: Optional[str] = None,
        valid_ann_file: Optional[str] = None,
        valid_transform: Optional[Callable] = _default_transform,
        test_folder: Optional[str] = None,
        test_ann_file: Optional[str] = None,
        test_transform: Optional[Callable] = _default_transform,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        **kwargs
    ):
        train_ds = CustomCOCODataset(train_folder, train_ann_file, train_transform)
        num_classes = train_ds.num_classes
        train_ds = _coco_remove_images_without_annotations(train_ds)

        valid_ds = (
            CustomCOCODataset(valid_folder, valid_ann_file, valid_transform) if valid_folder is not None else None
        )

        test_ds = (CustomCOCODataset(test_folder, test_ann_file, test_transform) if test_folder is not None else None)

        datamodule = cls(
            train_ds=train_ds,
            valid_ds=valid_ds,
            test_ds=test_ds,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        datamodule.num_classes = num_classes
        datamodule.data_pipeline = ImageDetectorDataPipeline()
        return datamodule
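As a quick illustration of the data flow defined above: each CustomCOCODataset sample is an (image, target) pair in the torchvision detection format, and ImageDetectorDataPipeline.collate batches such pairs by zipping rather than stacking, because the number of boxes varies per image. A minimal, self-contained sketch with dummy tensors (not real COCO data):

    import torch

    # One sample, shaped like CustomCOCODataset.__getitem__ output (dummy values)
    sample = (
        torch.rand(3, 224, 224),  # image after T.ToTensor()
        {
            "boxes": torch.tensor([[10.0, 20.0, 50.0, 80.0]]),  # xyxy, float32
            "labels": torch.tensor([3]),                         # category ids, int64
            "image_id": torch.tensor([42]),
            "area": torch.tensor([2400.0]),
            "iscrowd": torch.tensor([0]),
        },
    )

    # For a list of such tuples, ImageDetectorDataPipeline.collate reduces to
    # tuple(zip(*samples)): a batch becomes (images, targets), each a tuple of
    # per-image entries, which is the input format torchvision detectors expect.
    images, targets = tuple(zip(*[sample, sample]))
    assert len(images) == 2 and len(targets) == 2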
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl

from flash.core.finetuning import FlashBaseFinetuning


class ImageDetectorFineTuning(FlashBaseFinetuning):
    """
    Freezes the backbone during Detector training.
    """

    def __init__(self, train_bn: bool = True):
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        model = pl_module.model
        self.freeze(module=model.backbone, train_bn=self.train_bn)
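In plain PyTorch terms, the callback above amounts to stopping gradient updates for the detector's backbone while the rest of the model keeps training; with train_bn=True, BatchNorm layers are expected to remain trainable. A rough, framework-free sketch of that effect (the ToyDetector module below is made up purely for illustration and is not part of this commit or of the Lightning freeze implementation):

    import torch.nn as nn

    class ToyDetector(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
            self.head = nn.Conv2d(8, 4, 1)

    model = ToyDetector()
    for module in model.backbone.modules():
        if isinstance(module, nn.BatchNorm2d):
            continue  # train_bn=True: leave BatchNorm trainable
        for p in module.parameters(recurse=False):
            p.requires_grad = False  # frozen backbone weights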
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.core.data import download_data
from flash.vision import ImageDetectionData, ImageDetector

# 1. Download the data
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

# 2. Load the Data
datamodule = ImageDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    batch_size=2
)

# 3. Build the model
model = ImageDetector(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)

# 5. Finetune the model
trainer.finetune(model, datamodule)

# 6. Save it!
trainer.save_checkpoint("image_detection_model.pt")
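A possible follow-up, not part of the example script in this commit: reload the saved checkpoint and run inference on a single image path. load_from_checkpoint is the standard LightningModule API, and the data pipeline's before_collate above already accepts a plain path string; the exact predict call shown here is an assumption rather than something this diff demonstrates.

    # Hypothetical continuation of the example above
    model = ImageDetector.load_from_checkpoint("image_detection_model.pt")
    # any image from the downloaded coco128 folder; predict signature is assumed
    predictions = model.predict("data/coco128/images/train2017/000000000009.jpg")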