Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add instance segmentation and keypoint detection to flash zero #672

Merged
merged 5 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667))

- Added support for flash zero with the `InstanceSegmentation` and `KeypointDetector` tasks ([#672](https://github.com/PyTorchLightning/lightning-flash/pull/672))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
19 changes: 19 additions & 0 deletions docs/source/reference/instance_segmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,22 @@ Here's the full example:
.. literalinclude:: ../../../flash_examples/instance_segmentation.py
:language: python
:lines: 14-

------

**********
Flash Zero
**********

The instance segmentation task can be used directly from the command line with zero code using :ref:`flash_zero`.
You can run the above example with:

.. code-block:: bash

flash instance_segmentation

To view configuration options and options for running the instance segmentation task with your own data, use:

.. code-block:: bash

flash instance_segmentation --help
19 changes: 19 additions & 0 deletions docs/source/reference/keypoint_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,22 @@ Here's the full example:
.. literalinclude:: ../../../flash_examples/keypoint_detection.py
:language: python
:lines: 14-

------

**********
Flash Zero
**********

The keypoint detector can be used directly from the command line with zero code using :ref:`flash_zero`.
You can run the above example with:

.. code-block:: bash

flash keypoint_detection

To view configuration options and options for running the keypoint detector with your own data, use:

.. code-block:: bash

flash keypoint_detection --help
2 changes: 2 additions & 0 deletions flash/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def wrapper(cli_args):
"flash.graph.classification",
"flash.image.classification",
"flash.image.detection",
"flash.image.instance_segmentation",
"flash.image.keypoint_detection",
"flash.image.segmentation",
"flash.image.style_transfer",
"flash.pointcloud.detection",
Expand Down
2 changes: 2 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool:
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
_ICEVISION_AVAILABLE = _module_available("icevision")
_ICEDATA_AVAILABLE = _module_available("icedata")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

if Version:
Expand All @@ -120,6 +121,7 @@ def _compare_version(package: str, op, version) -> bool:
_PYSTICHE_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
_ICEVISION_AVAILABLE,
_ICEDATA_AVAILABLE,
]
)
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
Expand Down
66 changes: 66 additions & 0 deletions flash/image/instance_segmentation/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
from flash.image import InstanceSegmentation, InstanceSegmentationData

if _ICEDATA_AVAILABLE:
import icedata

__all__ = ["instance_segmentation"]


@requires_extras("image")
def from_pets(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
parser: Optional[Callable] = None,
**preprocess_kwargs,
) -> InstanceSegmentationData:
"""Downloads and loads the pets data set from icedata."""
data_dir = icedata.pets.load_data()

if parser is None:
parser = partial(icedata.pets.parser, mask=True)

return InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
parser=parser,
**preprocess_kwargs,
)


def instance_segmentation():
"""Segment object instances in images."""
cli = FlashCLI(
InstanceSegmentation,
InstanceSegmentationData,
default_datamodule_builder=from_pets,
default_arguments={
"trainer.max_epochs": 3,
},
)

cli.trainer.save_checkpoint("instance_segmentation_model.pt")


if __name__ == "__main__":
instance_segmentation()
66 changes: 66 additions & 0 deletions flash/image/keypoint_detection/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
from flash.image import KeypointDetectionData, KeypointDetector

if _ICEDATA_AVAILABLE:
import icedata

__all__ = ["keypoint_detection"]


@requires_extras("image")
def from_biwi(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
parser: Optional[Callable] = None,
**preprocess_kwargs,
) -> KeypointDetectionData:
"""Downloads and loads the BIWI data set from icedata."""
data_dir = icedata.biwi.load_data()

if parser is None:
parser = icedata.biwi.parser

return KeypointDetectionData.from_folders(
train_folder=data_dir,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
parser=parser,
**preprocess_kwargs,
)


def keypoint_detection():
"""Detect keypoints in images."""
cli = FlashCLI(
KeypointDetector,
KeypointDetectionData,
default_datamodule_builder=from_biwi,
default_arguments={
"model.num_keypoints": 1,
"trainer.max_epochs": 3,
},
)

cli.trainer.save_checkpoint("keypoint_detection_model.pt")


if __name__ == "__main__":
keypoint_detection()
1 change: 0 additions & 1 deletion flash_examples/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
datamodule = InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=0.1,
image_size=128,
parser=partial(icedata.pets.parser, mask=True),
)

Expand Down
3 changes: 1 addition & 2 deletions flash_examples/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
datamodule = KeypointDetectionData.from_folders(
train_folder=data_dir,
val_split=0.1,
image_size=128,
parser=icedata.biwi.parser,
)

Expand All @@ -52,4 +51,4 @@
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")
trainer.save_checkpoint("keypoint_detection_model.pt")
3 changes: 1 addition & 2 deletions tests/image/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def test_load_from_checkpoint_dependency_error():
ObjectDetector.load_from_checkpoint("not_a_real_checkpoint.pt")


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed.")
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_cli():
cli_args = ["flash", "object_detection", "--trainer.fast_dev_run", "True"]
with mock.patch("sys.argv", cli_args):
Expand Down
Empty file.
29 changes: 29 additions & 0 deletions tests/image/instance_segmentation/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest

from flash.__main__ import main
from tests.helpers.utils import _IMAGE_TESTING


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_cli():
cli_args = ["flash", "instance_segmentation", "--trainer.fast_dev_run", "True"]
with mock.patch("sys.argv", cli_args):
try:
main()
except SystemExit:
pass
Empty file.
29 changes: 29 additions & 0 deletions tests/image/keypoint_detection/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest

from flash.__main__ import main
from tests.helpers.utils import _IMAGE_TESTING


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_cli():
cli_args = ["flash", "keypoint_detection", "--trainer.fast_dev_run", "True"]
with mock.patch("sys.argv", cli_args):
try:
main()
except SystemExit:
pass