Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

IceVision integration #608

Merged
merged 64 commits into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
5802dcf
Initial commit
ethanwharris Jul 16, 2021
37044c2
Merge branch 'master' into feature/icevision
ethanwharris Jul 19, 2021
35c0465
Add instance segmentation and keypoint detection tasks
ethanwharris Jul 20, 2021
e79d2ff
Merge branch 'master' into feature/icevision
ethanwharris Jul 20, 2021
21a236d
Updates
ethanwharris Jul 20, 2021
b9dfc48
Updates
ethanwharris Jul 20, 2021
89385bd
Updates
ethanwharris Jul 20, 2021
addfe96
Add docs
ethanwharris Jul 20, 2021
22b4152
Update API reference
ethanwharris Jul 20, 2021
14dd36f
Fix some tests
ethanwharris Jul 20, 2021
1b0642e
Small fix
ethanwharris Jul 21, 2021
4a6c399
Drop failing JIT test
ethanwharris Jul 21, 2021
9e30034
Updates
ethanwharris Jul 21, 2021
00f391e
Updates
ethanwharris Jul 21, 2021
e6ee994
Fix a test
ethanwharris Jul 21, 2021
19a30e1
Merge branch 'master' into feature/icevision
ethanwharris Jul 26, 2021
d548607
Initial credits support
ethanwharris Jul 26, 2021
93cb652
Merge branch 'master' into feature/icevision
ethanwharris Jul 27, 2021
7d9838b
Credit -> provider
ethanwharris Jul 27, 2021
2e8a777
Update available backbones
ethanwharris Jul 27, 2021
a102d31
Add adapter
ethanwharris Jul 29, 2021
8338ba5
Merge branch 'master' into feature/icevision
ethanwharris Jul 29, 2021
ad7722e
Fix a test
ethanwharris Jul 29, 2021
3f34159
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Jul 29, 2021
0cc27da
Merge branch 'master' into feature/icevision
ethanwharris Jul 29, 2021
22afaae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2021
c50fd24
Merge branch 'master' into feature/icevision
ethanwharris Aug 4, 2021
e19b4c2
Updates
ethanwharris Aug 4, 2021
4cf6332
Fixes
ethanwharris Aug 4, 2021
858acdb
Refactor
ethanwharris Aug 4, 2021
7c6fb2f
Refactor
ethanwharris Aug 4, 2021
89f6978
Refactor
ethanwharris Aug 4, 2021
53e171e
minor changes
ethanwharris Aug 6, 2021
a307f0b
Merge branch 'master' into feature/icevision
ethanwharris Aug 6, 2021
cb3a2f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2021
8725028
0.5.0dev
Borda Aug 6, 2021
dba3145
Merge branch 'master' into feature/icevision
ethanwharris Aug 9, 2021
335073a
pl
Borda Aug 9, 2021
19143db
imports
Borda Aug 9, 2021
b72375e
Update adapter.py
ethanwharris Aug 9, 2021
5a1cb64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2021
45d121e
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Aug 9, 2021
55377f1
Update adapter.py
ethanwharris Aug 9, 2021
0b43313
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Aug 10, 2021
68648ab
Updates
ethanwharris Aug 10, 2021
878c7b9
Merge branch 'master' into feature/icevision
ethanwharris Aug 10, 2021
12a89dd
Add transforms to and from icevision records
ethanwharris Aug 11, 2021
00b49dd
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris Aug 11, 2021
cee3edf
Fix tests
ethanwharris Aug 11, 2021
0b02c55
Try fix
ethanwharris Aug 11, 2021
1824e5e
Update CHANGELOG.md
ethanwharris Aug 11, 2021
6fb7ee3
Fix tests
ethanwharris Aug 11, 2021
221b01c
Fix a test
ethanwharris Aug 11, 2021
1ca9b6b
Try fix
ethanwharris Aug 12, 2021
d97dbdf
Try fix
ethanwharris Aug 12, 2021
ecff056
Merge branch 'master' into feature/icevision
ethanwharris Aug 12, 2021
3b387f7
Add some docs
ethanwharris Aug 12, 2021
16ed49c
Add API reference
ethanwharris Aug 12, 2021
40b7c9b
Small updates
ethanwharris Aug 12, 2021
69cbaf8
Merge branch 'master' into feature/icevision
ethanwharris Aug 13, 2021
cfc02bb
Merge branch 'master' into feature/icevision
ananyahjha93 Aug 16, 2021
ac7743b
pep fix
ananyahjha93 Aug 16, 2021
6c74155
Fixes
ethanwharris Aug 16, 2021
75e1c31
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ananyahjha93 Aug 16, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ jobs:
run: |
sudo apt-get install libsndfile1
pip install matplotlib
pip install '.[image]' --pre --upgrade
pip install '.[audio,image]' --pre --upgrade

- name: Cache datasets
uses: actions/cache@v2
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
logs/cache/*
flash_examples/data
flash_examples/cli/*/data
flash_examples/checkpoints
timit/
urban8k_images/
__MACOSX
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Added integration with IceVision for the `ObjectDetector` ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

- Added keypoint detection task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand All @@ -48,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Changed arguments to `ObjectDetector`, use `head` instead of `model` and append `_fpn` to the backbone name instead of the `fpn` argument ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

### Fixed

- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493))
Expand Down
13 changes: 13 additions & 0 deletions docs/source/api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ flash.core
:local:
:backlinks: top

flash.core.adapter
__________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~flash.core.adapter.Adapter
~flash.core.adapter.AdapterTask

flash.core.classification
_________________________

Expand Down Expand Up @@ -56,6 +67,8 @@ ________________

~flash.core.model.BenchmarkConvergenceCI
~flash.core.model.CheckDependenciesMeta
~flash.core.model.ModuleWrapperBase
~flash.core.model.DatasetProcessor
~flash.core.model.Task

flash.core.registry
Expand Down
32 changes: 24 additions & 8 deletions docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ ______________
classification.transforms.default_transforms
classification.transforms.train_default_transforms

Detection
_________
Object Detection
________________

.. autosummary::
:toctree: generated/
Expand All @@ -42,21 +42,37 @@ _________
~detection.model.ObjectDetector
~detection.data.ObjectDetectionData

detection.data.COCODataSource
detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneDataSource
detection.data.ObjectDetectionPreprocess
detection.finetuning.ObjectDetectionFineTuning
detection.model.ObjectDetector
detection.serialization.DetectionLabels
detection.serialization.FiftyOneDetectionLabels

Keypoint Detection
__________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template:
:template: classtemplate.rst

~keypoint_detection.model.KeypointDetector
~keypoint_detection.data.KeypointDetectionData

keypoint_detection.data.KeypointDetectionPreprocess

Instance Segmentation
_____________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~instance_segmentation.model.InstanceSegmentation
~instance_segmentation.data.InstanceSegmentationData

detection.transforms.collate
detection.transforms.default_transforms
instance_segmentation.data.InstanceSegmentationPreprocess

Embedding
_________
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Lightning Flash
reference/image_classification_multi_label
reference/image_embedder
reference/object_detection
reference/keypoint_detection
reference/instance_segmentation
reference/semantic_segmentation
reference/style_transfer
reference/video_classification
Expand Down
31 changes: 31 additions & 0 deletions docs/source/reference/instance_segmentation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

.. _instance_segmentation:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have these labels automatically generated by docs


#####################
Instance Segmentation
#####################

********
The Task
********

Instance segmentation is the task of segmenting objects images and determining their associated classes.

The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Example
*******

Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ from `IceData <https://github.com/airctic/icedata>`_.
Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`.
We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data.
We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/instance_segmentation.py
:language: python
:lines: 14-
31 changes: 31 additions & 0 deletions docs/source/reference/keypoint_detection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

.. _keypoint_detection:

##################
Keypoint Detection
##################

********
The Task
********

Keypoint detection is the task of identifying keypoints in images and their associated classes.

The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Example
*******

Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) <https://www.kaggle.com/kmader/biwi-kinect-head-pose-database>`_ from `IceData <https://github.com/airctic/icedata>`_.
Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`.
We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data.
We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/keypoint_detection.py
:language: python
:lines: 14-
2 changes: 2 additions & 0 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ The Task

Object detection is the task of identifying objects in images and their associated classes and bounding boxes.

The :class:`~flash.image.detection.model.ObjectDetector` and :class:`~flash.image.detection.data.ObjectDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_.

------

*******
Expand Down
2 changes: 1 addition & 1 deletion flash/__about__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.1dev"
__version__ = "0.5.0dev"
__author__ = "PyTorchLightning et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
Expand Down
162 changes: 162 additions & 0 deletions flash/core/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import Any, Callable, Optional

from torch import nn
from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task


class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module):
"""The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular
provider within a :class:`~flash.core.model.Task`."""

@classmethod
@abstractmethod
def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter":
"""Instantiate the adapter from the given :class:`~flash.core.model.Task`.

This includes resolution / creation of backbones / heads and any other provider specific options.
"""

def forward(self, x: Any) -> Any:
pass

def training_step(self, batch: Any, batch_idx: int) -> Any:
pass

def validation_step(self, batch: Any, batch_idx: int) -> None:
pass

def test_step(self, batch: Any, batch_idx: int) -> None:
pass

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
pass

def training_epoch_end(self, outputs) -> None:
pass

def validation_epoch_end(self, outputs) -> None:
pass

def test_epoch_end(self, outputs) -> None:
pass


class AdapterTask(Task):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the AdapterTask use the mixins ?

"""The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter`
and forwards all of the hooks.

Args:
adapter: The :class:`~flash.core.adapter.Adapter` to wrap.
kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`.
"""

def __init__(self, adapter: Adapter, **kwargs):
super().__init__(**kwargs)

self.adapter = adapter

@property
def backbone(self) -> nn.Module:
return self.adapter.backbone

def forward(self, x: Any) -> Any:
return self.adapter.forward(x)

def training_step(self, batch: Any, batch_idx: int) -> Any:
return self.adapter.training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> None:
return self.adapter.validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> None:
return self.adapter.test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

def training_epoch_end(self, outputs) -> None:
return self.adapter.training_epoch_end(outputs)

def validation_epoch_end(self, outputs) -> None:
return self.adapter.validation_epoch_end(outputs)

def test_epoch_end(self, outputs) -> None:
return self.adapter.test_epoch_end(outputs)

def process_train_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_train_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_val_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_val_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_test_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int,
num_workers: int,
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_test_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_predict_dataset(
self,
dataset: BaseAutoDataset,
batch_size: int = 1,
num_workers: int = 0,
pin_memory: bool = False,
collate_fn: Callable = lambda x: x,
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_predict_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)
2 changes: 1 addition & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _predict_dataloader(self) -> DataLoader:
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
return self.trainer.lightning_module.process_predict_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
Expand Down
6 changes: 4 additions & 2 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]:
def deserialize_processor(self) -> _DeserializeProcessor:
return self._create_collate_preprocessors(RunningStage.PREDICTING)[0]

def worker_preprocessor(self, running_stage: RunningStage, is_serving: bool = False) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage, is_serving=is_serving)[1]
def worker_preprocessor(
self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False
) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage, collate_fn=collate_fn, is_serving=is_serving)[1]

def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage)[2]
Expand Down
Empty file.
Loading