Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
IceVision integration (#608)
Browse files Browse the repository at this point in the history
* Initial commit

* Add instance segmentation and keypoint detection tasks

* Updates

* Updates

* Updates

* Add docs

* Update API reference

* Fix some tests

* Small fix

* Drop failing JIT test

* Updates

* Updates

* Fix a test

* Initial credits support

* Credit -> provider

* Update available backbones

* Add adapter

* Fix a test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Updates

* Fixes

* Refactor

* Refactor

* Refactor

* minor changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 0.5.0dev

* pl

* imports

* Update adapter.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update adapter.py

* Updates

* Add transforms to and from icevision records

* Fix tests

* Try fix

* Update CHANGELOG.md

* Fix tests

* Fix a test

* Try fix

* Try fix

* Add some docs

* Add API reference

* Small updates

* pep fix

* Fixes

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Ananya Harsh Jha <[email protected]>
  • Loading branch information
4 people authored Aug 16, 2021
1 parent d094fee commit d9dc2f0
Show file tree
Hide file tree
Showing 54 changed files with 2,482 additions and 731 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ jobs:
run: |
sudo apt-get install libsndfile1
pip install matplotlib
pip install '.[image]' --pre --upgrade
pip install '.[audio,image]' --pre --upgrade
- name: Cache datasets
uses: actions/cache@v2
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
logs/cache/*
flash_examples/data
flash_examples/cli/*/data
flash_examples/checkpoints
timit/
urban8k_images/
__MACOSX
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Added integration with IceVision for the `ObjectDetector` ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

- Added keypoint detection task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand All @@ -48,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Changed arguments to `ObjectDetector`, use `head` instead of `model` and append `_fpn` to the backbone name instead of the `fpn` argument ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

### Fixed

- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493))
Expand Down
13 changes: 13 additions & 0 deletions docs/source/api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ flash.core
:local:
:backlinks: top

flash.core.adapter
__________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~flash.core.adapter.Adapter
~flash.core.adapter.AdapterTask

flash.core.classification
_________________________

Expand Down Expand Up @@ -56,6 +67,8 @@ ________________

~flash.core.model.BenchmarkConvergenceCI
~flash.core.model.CheckDependenciesMeta
~flash.core.model.ModuleWrapperBase
~flash.core.model.DatasetProcessor
~flash.core.model.Task

flash.core.registry
Expand Down
32 changes: 24 additions & 8 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ ______________
classification.transforms.default_transforms
classification.transforms.train_default_transforms

Detection
_________
Object Detection
________________

.. autosummary::
:toctree: generated/
Expand All @@ -42,21 +42,37 @@ _________
~detection.model.ObjectDetector
~detection.data.ObjectDetectionData

detection.data.COCODataSource
detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneDataSource
detection.data.ObjectDetectionPreprocess
detection.finetuning.ObjectDetectionFineTuning
detection.model.ObjectDetector
detection.serialization.DetectionLabels
detection.serialization.FiftyOneDetectionLabels

Keypoint Detection
__________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template:
:template: classtemplate.rst

~keypoint_detection.model.KeypointDetector
~keypoint_detection.data.KeypointDetectionData

keypoint_detection.data.KeypointDetectionPreprocess

Instance Segmentation
_____________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~instance_segmentation.model.InstanceSegmentation
~instance_segmentation.data.InstanceSegmentationData

detection.transforms.collate
detection.transforms.default_transforms
instance_segmentation.data.InstanceSegmentationPreprocess

Embedding
_________
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Lightning Flash
reference/image_classification_multi_label
reference/image_embedder
reference/object_detection
reference/keypoint_detection
reference/instance_segmentation
reference/semantic_segmentation
reference/style_transfer
reference/video_classification
Expand Down
31 changes: 31 additions & 0 deletions docs/source/reference/instance_segmentation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

.. _instance_segmentation:

#####################
Instance Segmentation
#####################

********
The Task
********

Instance segmentation is the task of segmenting objects images and determining their associated classes.

The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Example
*******

Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ from `IceData <https://github.com/airctic/icedata>`_.
Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`.
We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data.
We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/instance_segmentation.py
:language: python
:lines: 14-
31 changes: 31 additions & 0 deletions docs/source/reference/keypoint_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

.. _keypoint_detection:

##################
Keypoint Detection
##################

********
The Task
********

Keypoint detection is the task of identifying keypoints in images and their associated classes.

The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Example
*******

Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) <https://www.kaggle.com/kmader/biwi-kinect-head-pose-database>`_ from `IceData <https://github.com/airctic/icedata>`_.
Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`.
We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data.
We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/keypoint_detection.py
:language: python
:lines: 14-
2 changes: 2 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ The Task

Object detection is the task of identifying objects in images and their associated classes and bounding boxes.

The :class:`~flash.image.detection.model.ObjectDetector` and :class:`~flash.image.detection.data.ObjectDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Expand Down
2 changes: 1 addition & 1 deletion flash/__about__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.1dev"
__version__ = "0.5.0dev"
__author__ = "PyTorchLightning et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
Expand Down
162 changes: 162 additions & 0 deletions flash/core/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import Any, Callable, Optional

from torch import nn
from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task


class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module):
"""The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular
provider within a :class:`~flash.core.model.Task`."""

@classmethod
@abstractmethod
def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter":
"""Instantiate the adapter from the given :class:`~flash.core.model.Task`.
This includes resolution / creation of backbones / heads and any other provider specific options.
"""

def forward(self, x: Any) -> Any:
pass

def training_step(self, batch: Any, batch_idx: int) -> Any:
pass

def validation_step(self, batch: Any, batch_idx: int) -> None:
pass

def test_step(self, batch: Any, batch_idx: int) -> None:
pass

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
pass

def training_epoch_end(self, outputs) -> None:
pass

def validation_epoch_end(self, outputs) -> None:
pass

def test_epoch_end(self, outputs) -> None:
pass


class AdapterTask(Task):
"""The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter`
and forwards all of the hooks.
Args:
adapter: The :class:`~flash.core.adapter.Adapter` to wrap.
kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`.
"""

def __init__(self, adapter: Adapter, **kwargs):
super().__init__(**kwargs)

self.adapter = adapter

@property
def backbone(self) -> nn.Module:
return self.adapter.backbone

def forward(self, x: Any) -> Any:
return self.adapter.forward(x)

def training_step(self, batch: Any, batch_idx: int) -> Any:
return self.adapter.training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> None:
return self.adapter.validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> None:
return self.adapter.test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

def training_epoch_end(self, outputs) -> None:
return self.adapter.training_epoch_end(outputs)

def validation_epoch_end(self, outputs) -> None:
return self.adapter.validation_epoch_end(outputs)

def test_epoch_end(self, outputs) -> None:
return self.adapter.test_epoch_end(outputs)

def process_train_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_train_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_val_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_val_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_test_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_test_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_predict_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int = 1,
num_workers: int = 0,
pin_memory: bool = False,
collate_fn: Callable = lambda x: x,
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_predict_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)
2 changes: 1 addition & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _predict_dataloader(self) -> DataLoader:
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
return self.trainer.lightning_module.process_predict_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
Expand Down
6 changes: 4 additions & 2 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]:
def deserialize_processor(self) -> _DeserializeProcessor:
return self._create_collate_preprocessors(RunningStage.PREDICTING)[0]

def worker_preprocessor(self, running_stage: RunningStage, is_serving: bool = False) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage, is_serving=is_serving)[1]
def worker_preprocessor(
self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False
) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage, collate_fn=collate_fn, is_serving=is_serving)[1]

def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage)[2]
Expand Down
Empty file.
Loading

0 comments on commit d9dc2f0

Please sign in to comment.