Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

IceVision integration #608

Merged
merged 64 commits into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
5802dcf
Initial commit
ethanwharris Jul 16, 2021
37044c2
Merge branch 'master' into feature/icevision
ethanwharris Jul 19, 2021
35c0465
Add instance segmentation and keypoint detection tasks
ethanwharris Jul 20, 2021
e79d2ff
Merge branch 'master' into feature/icevision
ethanwharris Jul 20, 2021
21a236d
Updates
ethanwharris Jul 20, 2021
b9dfc48
Updates
ethanwharris Jul 20, 2021
89385bd
Updates
ethanwharris Jul 20, 2021
addfe96
Add docs
ethanwharris Jul 20, 2021
22b4152
Update API reference
ethanwharris Jul 20, 2021
14dd36f
Fix some tests
ethanwharris Jul 20, 2021
1b0642e
Small fix
ethanwharris Jul 21, 2021
4a6c399
Drop failing JIT test
ethanwharris Jul 21, 2021
9e30034
Updates
ethanwharris Jul 21, 2021
00f391e
Updates
ethanwharris Jul 21, 2021
e6ee994
Fix a test
ethanwharris Jul 21, 2021
19a30e1
Merge branch 'master' into feature/icevision
ethanwharris Jul 26, 2021
d548607
Initial credits support
ethanwharris Jul 26, 2021
93cb652
Merge branch 'master' into feature/icevision
ethanwharris Jul 27, 2021
7d9838b
Credit -> provider
ethanwharris Jul 27, 2021
2e8a777
Update available backbones
ethanwharris Jul 27, 2021
a102d31
Add adapter
ethanwharris Jul 29, 2021
8338ba5
Merge branch 'master' into feature/icevision
ethanwharris Jul 29, 2021
ad7722e
Fix a test
ethanwharris Jul 29, 2021
3f34159
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Jul 29, 2021
0cc27da
Merge branch 'master' into feature/icevision
ethanwharris Jul 29, 2021
22afaae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2021
c50fd24
Merge branch 'master' into feature/icevision
ethanwharris Aug 4, 2021
e19b4c2
Updates
ethanwharris Aug 4, 2021
4cf6332
Fixes
ethanwharris Aug 4, 2021
858acdb
Refactor
ethanwharris Aug 4, 2021
7c6fb2f
Refactor
ethanwharris Aug 4, 2021
89f6978
Refactor
ethanwharris Aug 4, 2021
53e171e
minor changes
ethanwharris Aug 6, 2021
a307f0b
Merge branch 'master' into feature/icevision
ethanwharris Aug 6, 2021
cb3a2f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2021
8725028
0.5.0dev
Borda Aug 6, 2021
dba3145
Merge branch 'master' into feature/icevision
ethanwharris Aug 9, 2021
335073a
pl
Borda Aug 9, 2021
19143db
imports
Borda Aug 9, 2021
b72375e
Update adapter.py
ethanwharris Aug 9, 2021
5a1cb64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2021
45d121e
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Aug 9, 2021
55377f1
Update adapter.py
ethanwharris Aug 9, 2021
0b43313
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Aug 10, 2021
68648ab
Updates
ethanwharris Aug 10, 2021
878c7b9
Merge branch 'master' into feature/icevision
ethanwharris Aug 10, 2021
12a89dd
Add transforms to and from icevision records
ethanwharris Aug 11, 2021
00b49dd
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Aug 11, 2021
cee3edf
Fix tests
ethanwharris Aug 11, 2021
0b02c55
Try fix
ethanwharris Aug 11, 2021
1824e5e
Update CHANGELOG.md
ethanwharris Aug 11, 2021
6fb7ee3
Fix tests
ethanwharris Aug 11, 2021
221b01c
Fix a test
ethanwharris Aug 11, 2021
1ca9b6b
Try fix
ethanwharris Aug 12, 2021
d97dbdf
Try fix
ethanwharris Aug 12, 2021
ecff056
Merge branch 'master' into feature/icevision
ethanwharris Aug 12, 2021
3b387f7
Add some docs
ethanwharris Aug 12, 2021
16ed49c
Add API reference
ethanwharris Aug 12, 2021
40b7c9b
Small updates
ethanwharris Aug 12, 2021
69cbaf8
Merge branch 'master' into feature/icevision
ethanwharris Aug 13, 2021
cfc02bb
Merge branch 'master' into feature/icevision
ananyahjha93 Aug 16, 2021
ac7743b
pep fix
ananyahjha93 Aug 16, 2021
6c74155
Fixes
ethanwharris Aug 16, 2021
75e1c31
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ananyahjha93 Aug 16, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
logs/cache/*
flash_examples/data
flash_examples/checkpoints
32 changes: 24 additions & 8 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ ______________
classification.transforms.default_transforms
classification.transforms.train_default_transforms

Detection
_________
Object Detection
________________

.. autosummary::
:toctree: generated/
Expand All @@ -42,21 +42,37 @@ _________
~detection.model.ObjectDetector
~detection.data.ObjectDetectionData

detection.data.COCODataSource
detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneDataSource
detection.data.ObjectDetectionPreprocess
detection.finetuning.ObjectDetectionFineTuning
detection.model.ObjectDetector
detection.serialization.DetectionLabels
detection.serialization.FiftyOneDetectionLabels

Keypoint Detection
__________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template:
:template: classtemplate.rst

~keypoint_detection.model.KeypointDetector
~keypoint_detection.data.KeypointDetectionData

keypoint_detection.data.KeypointDetectionPreprocess

Instance Segmentation
_____________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~instance_segmentation.model.InstanceSegmentation
~instance_segmentation.data.InstanceSegmentationData

detection.transforms.collate
detection.transforms.default_transforms
instance_segmentation.data.InstanceSegmentationPreprocess

Embedding
_________
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Lightning Flash
reference/image_classification_multi_label
reference/image_embedder
reference/object_detection
reference/keypoint_detection
reference/instance_segmentation
reference/semantic_segmentation
reference/style_transfer
reference/video_classification
Expand Down
31 changes: 31 additions & 0 deletions docs/source/reference/instance_segmentation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

.. _instance_segmentation:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have these labels automatically generated by docs


#####################
Instance Segmentation
#####################

********
The Task
********

Instance segmentation is the task of segmenting objects images and determining their associated classes.

The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Example
*******

Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ from `IceData <https://github.com/airctic/icedata>`_.
Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`.
We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data.
We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/instance_segmentation.py
:language: python
:lines: 14-
31 changes: 31 additions & 0 deletions docs/source/reference/keypoint_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

.. _keypoint_detection:

##################
Keypoint Detection
##################

********
The Task
********

Keypoint detection is the task of identifying keypoints in images and their associated classes.

The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Example
*******

Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) <https://www.kaggle.com/kmader/biwi-kinect-head-pose-database>`_ from `IceData <https://github.com/airctic/icedata>`_.
Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`.
We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data.
We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/keypoint_detection.py
:language: python
:lines: 14-
2 changes: 2 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ The Task

Object detection is the task of identifying objects in images and their associated classes and bounding boxes.

The :class:`~flash.image.detection.model.ObjectDetector` and :class:`~flash.image.detection.data.ObjectDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Expand Down
5 changes: 3 additions & 2 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,13 @@ def _predict_dataloader(self) -> DataLoader:
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
return self.trainer.lightning_module.process_predict_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn
collate_fn=collate_fn,
convert_to_dataloader=True,
)

return DataLoader(
Expand Down
5 changes: 5 additions & 0 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def _create_collate_preprocessors(
prefix: str = _STAGES_PREFIX[stage]

if collate_fn is not None:
preprocess._original_default_collate = preprocess._default_collate
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
preprocess._default_collate = collate_fn

func_names: Dict[str, str] = {
Expand Down Expand Up @@ -486,6 +487,10 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin
elif isinstance(stage, RunningStage):
stages = [stage]

self._preprocess_pipeline._default_collate = getattr(
self._preprocess_pipeline, "_original_default_collate", self._preprocess_pipeline._default_collate
)

for stage in stages:

device_collate = None
Expand Down
Empty file.
67 changes: 67 additions & 0 deletions flash/core/integrations/icevision/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from inspect import getmembers

from torch import nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _ICEVISION_AVAILABLE

if _ICEVISION_AVAILABLE:
from icevision.backbones import BackboneConfig

OBJECT_DETECTION_HEADS = FlashRegistry("heads")


def icevision_model_adapter(model_type):

class IceVisionModelAdapter(model_type.lightning.ModelAdapter):

def log(self, name, value, **kwargs):
if "prog_bar" not in kwargs:
kwargs["prog_bar"] = True
return super().log(name.split("/")[-1], value, **kwargs)

return IceVisionModelAdapter


def load_icevision(adapter, model_type, backbone, num_classes, **kwargs):
model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs)

backbone = nn.Module()
params = model.param_groups()[0]
for i, param in enumerate(params):
backbone.register_parameter(f"backbone_{i}", param)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

return model_type, model, adapter(model_type), backbone


def load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs):
return load_icevision(adapter, model_type, backbone, num_classes, **kwargs)


def load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs):
kwargs["img_size"] = image_size
return load_icevision(adapter, model_type, backbone, num_classes, **kwargs)


def get_backbones(model_type):
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
_BACKBONES = FlashRegistry("backbones")

for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)):
_BACKBONES(
backbone_config,
name=backbone_name,
)
return _BACKBONES
79 changes: 79 additions & 0 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import numpy as np

from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource

if _ICEVISION_AVAILABLE:
from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks
from icevision.data import SingleSplitSplitter
from icevision.parsers import Parser


class IceVisionPathsDataSource(ImagePathsDataSource):

def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
return super().predict_load_data(data, dataset)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample[DefaultDataKeys.INPUT].load()

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
sample = super().load_sample(sample)
image = np.array(sample[DefaultDataKeys.INPUT])
record = BaseRecord([ImageRecordComponent()])

record.set_img(image)
record.add_component(ClassMapRecordComponent(task=tasks.detection))
return record


class IceVisionParserDataSource(IceVisionPathsDataSource):

def __init__(self, parser: Optional[Type['Parser']] = None):
super().__init__()
self.parser = parser

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root, ann_file = data

if self.parser is not None:
parser = self.parser(ann_file, root)
dataset.num_classes = len(parser.class_map)
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
else:
raise ValueError("The parser type must be provided")


class IceDataParserDataSource(IceVisionPathsDataSource):

def __init__(self, parser: Optional[Callable] = None):
super().__init__()
self.parser = parser

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root = data

if self.parser is not None:
parser = self.parser(root)
dataset.num_classes = len(parser.class_map)
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
else:
raise ValueError("The parser must be provided")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could they be merged together ?

Loading