This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
IceVision integration #608
Merged
Merged
Changes from 14 commits
Commits
Show all changes
64 commits
Select commit
Hold shift + click to select a range
5802dcf
Initial commit
ethanwharris 37044c2
Merge branch 'master' into feature/icevision
ethanwharris 35c0465
Add instance segmentation and keypoint detection tasks
ethanwharris e79d2ff
Merge branch 'master' into feature/icevision
ethanwharris 21a236d
Updates
ethanwharris b9dfc48
Updates
ethanwharris 89385bd
Updates
ethanwharris addfe96
Add docs
ethanwharris 22b4152
Update API reference
ethanwharris 14dd36f
Fix some tests
ethanwharris 1b0642e
Small fix
ethanwharris 4a6c399
Drop failing JIT test
ethanwharris 9e30034
Updates
ethanwharris 00f391e
Updates
ethanwharris e6ee994
Fix a test
ethanwharris 19a30e1
Merge branch 'master' into feature/icevision
ethanwharris d548607
Initial credits support
ethanwharris 93cb652
Merge branch 'master' into feature/icevision
ethanwharris 7d9838b
Credit -> provider
ethanwharris 2e8a777
Update available backbones
ethanwharris a102d31
Add adapter
ethanwharris 8338ba5
Merge branch 'master' into feature/icevision
ethanwharris ad7722e
Fix a test
ethanwharris 3f34159
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris 0cc27da
Merge branch 'master' into feature/icevision
ethanwharris 22afaae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c50fd24
Merge branch 'master' into feature/icevision
ethanwharris e19b4c2
Updates
ethanwharris 4cf6332
Fixes
ethanwharris 858acdb
Refactor
ethanwharris 7c6fb2f
Refactor
ethanwharris 89f6978
Refactor
ethanwharris 53e171e
minor changes
ethanwharris a307f0b
Merge branch 'master' into feature/icevision
ethanwharris cb3a2f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8725028
0.5.0dev
Borda dba3145
Merge branch 'master' into feature/icevision
ethanwharris 335073a
pl
Borda 19143db
imports
Borda b72375e
Update adapter.py
ethanwharris 5a1cb64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 45d121e
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris 55377f1
Update adapter.py
ethanwharris 0b43313
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris 68648ab
Updates
ethanwharris 878c7b9
Merge branch 'master' into feature/icevision
ethanwharris 12a89dd
Add transforms to and from icevision records
ethanwharris 00b49dd
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ethanwharris cee3edf
Fix tests
ethanwharris 0b02c55
Try fix
ethanwharris 1824e5e
Update CHANGELOG.md
ethanwharris 6fb7ee3
Fix tests
ethanwharris 221b01c
Fix a test
ethanwharris 1ca9b6b
Try fix
ethanwharris d97dbdf
Try fix
ethanwharris ecff056
Merge branch 'master' into feature/icevision
ethanwharris 3b387f7
Add some docs
ethanwharris 16ed49c
Add API reference
ethanwharris 40b7c9b
Small updates
ethanwharris 69cbaf8
Merge branch 'master' into feature/icevision
ethanwharris cfc02bb
Merge branch 'master' into feature/icevision
ananyahjha93 ac7743b
pep fix
ananyahjha93 6c74155
Fixes
ethanwharris 75e1c31
Merge branch 'feature/icevision' of https://github.com/PyTorchLightni…
ananyahjha93 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
|
||
.. _instance_segmentation: | ||
|
||
##################### | ||
Instance Segmentation | ||
##################### | ||
|
||
******** | ||
The Task | ||
******** | ||
|
||
Instance segmentation is the task of segmenting objects images and determining their associated classes. | ||
|
||
The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision <https://airctic.com/>`_. | ||
|
||
------ | ||
|
||
******* | ||
Example | ||
******* | ||
|
||
Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ from `IceData <https://github.com/airctic/icedata>`_. | ||
Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`. | ||
We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data. | ||
We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference. | ||
Finally, we save the model. | ||
Here's the full example: | ||
|
||
.. literalinclude:: ../../../flash_examples/instance_segmentation.py | ||
:language: python | ||
:lines: 14- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
|
||
.. _keypoint_detection: | ||
|
||
################## | ||
Keypoint Detection | ||
################## | ||
|
||
******** | ||
The Task | ||
******** | ||
|
||
Keypoint detection is the task of identifying keypoints in images and their associated classes. | ||
|
||
The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision <https://airctic.com/>`_. | ||
|
||
------ | ||
|
||
******* | ||
Example | ||
******* | ||
|
||
Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) <https://www.kaggle.com/kmader/biwi-kinect-head-pose-database>`_ from `IceData <https://github.com/airctic/icedata>`_. | ||
Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`. | ||
We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data. | ||
We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference. | ||
Finally, we save the model. | ||
Here's the full example: | ||
|
||
.. literalinclude:: ../../../flash_examples/keypoint_detection.py | ||
:language: python | ||
:lines: 14- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from inspect import getmembers | ||
|
||
from torch import nn | ||
|
||
from flash.core.registry import FlashRegistry | ||
from flash.core.utilities.imports import _ICEVISION_AVAILABLE | ||
|
||
if _ICEVISION_AVAILABLE: | ||
from icevision.backbones import BackboneConfig | ||
|
||
OBJECT_DETECTION_HEADS = FlashRegistry("heads") | ||
|
||
|
||
def icevision_model_adapter(model_type): | ||
|
||
class IceVisionModelAdapter(model_type.lightning.ModelAdapter): | ||
|
||
def log(self, name, value, **kwargs): | ||
if "prog_bar" not in kwargs: | ||
kwargs["prog_bar"] = True | ||
return super().log(name.split("/")[-1], value, **kwargs) | ||
|
||
return IceVisionModelAdapter | ||
|
||
|
||
def load_icevision(adapter, model_type, backbone, num_classes, **kwargs): | ||
model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) | ||
|
||
backbone = nn.Module() | ||
params = model.param_groups()[0] | ||
for i, param in enumerate(params): | ||
backbone.register_parameter(f"backbone_{i}", param) | ||
tchaton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return model_type, model, adapter(model_type), backbone | ||
|
||
|
||
def load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): | ||
return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) | ||
|
||
|
||
def load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): | ||
kwargs["img_size"] = image_size | ||
return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) | ||
|
||
|
||
def get_backbones(model_type): | ||
ethanwharris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_BACKBONES = FlashRegistry("backbones") | ||
|
||
for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)): | ||
_BACKBONES( | ||
backbone_config, | ||
name=backbone_name, | ||
) | ||
return _BACKBONES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type | ||
|
||
import numpy as np | ||
|
||
from flash.core.data.data_source import DefaultDataKeys | ||
from flash.core.utilities.imports import _ICEVISION_AVAILABLE | ||
from flash.image.data import ImagePathsDataSource | ||
|
||
if _ICEVISION_AVAILABLE: | ||
from icevision.core import BaseRecord, ClassMapRecordComponent, ImageRecordComponent, tasks | ||
from icevision.data import SingleSplitSplitter | ||
from icevision.parsers import Parser | ||
|
||
|
||
class IceVisionPathsDataSource(ImagePathsDataSource): | ||
|
||
def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: | ||
return super().predict_load_data(data, dataset) | ||
|
||
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: | ||
return sample[DefaultDataKeys.INPUT].load() | ||
|
||
def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: | ||
sample = super().load_sample(sample) | ||
image = np.array(sample[DefaultDataKeys.INPUT]) | ||
record = BaseRecord([ImageRecordComponent()]) | ||
|
||
record.set_img(image) | ||
record.add_component(ClassMapRecordComponent(task=tasks.detection)) | ||
return record | ||
|
||
|
||
class IceVisionParserDataSource(IceVisionPathsDataSource): | ||
|
||
def __init__(self, parser: Optional[Type['Parser']] = None): | ||
super().__init__() | ||
self.parser = parser | ||
|
||
def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: | ||
root, ann_file = data | ||
|
||
if self.parser is not None: | ||
parser = self.parser(ann_file, root) | ||
dataset.num_classes = len(parser.class_map) | ||
records = parser.parse(data_splitter=SingleSplitSplitter()) | ||
return [{DefaultDataKeys.INPUT: record} for record in records[0]] | ||
else: | ||
raise ValueError("The parser type must be provided") | ||
|
||
|
||
class IceDataParserDataSource(IceVisionPathsDataSource): | ||
|
||
def __init__(self, parser: Optional[Callable] = None): | ||
super().__init__() | ||
self.parser = parser | ||
|
||
def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: | ||
root = data | ||
|
||
if self.parser is not None: | ||
parser = self.parser(root) | ||
dataset.num_classes = len(parser.class_map) | ||
records = parser.parse(data_splitter=SingleSplitSplitter()) | ||
return [{DefaultDataKeys.INPUT: record} for record in records[0]] | ||
else: | ||
raise ValueError("The parser must be provided") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could they be merged together ? |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can have these labels automatically generated by docs