From 67b227fcc94b6889d6855e2bd5bb0658f06876a3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 17 Aug 2021 18:36:29 +0100 Subject: [PATCH] Add instance segmentation and keypoint detection to flash zero (#672) * Add instance segmentation and keypoint detection to flash zero * Add instance segmentation and keypoint detection to flash zero * Add docs * Uodate CHANGELOG.md * Fixes --- CHANGELOG.md | 2 + .../reference/instance_segmentation.rst | 19 ++++++ docs/source/reference/keypoint_detection.rst | 19 ++++++ flash/__main__.py | 2 + flash/core/utilities/imports.py | 2 + flash/image/instance_segmentation/cli.py | 66 +++++++++++++++++++ flash/image/keypoint_detection/cli.py | 66 +++++++++++++++++++ flash_examples/instance_segmentation.py | 1 - flash_examples/keypoint_detection.py | 3 +- tests/image/detection/test_model.py | 3 +- tests/image/instance_segmentation/__init__.py | 0 .../image/instance_segmentation/test_model.py | 29 ++++++++ tests/image/keypoint_detection/__init__.py | 0 tests/image/keypoint_detection/test_model.py | 29 ++++++++ 14 files changed, 236 insertions(+), 5 deletions(-) create mode 100644 flash/image/instance_segmentation/cli.py create mode 100644 flash/image/keypoint_detection/cli.py create mode 100644 tests/image/instance_segmentation/__init__.py create mode 100644 tests/image/instance_segmentation/test_model.py create mode 100644 tests/image/keypoint_detection/__init__.py create mode 100644 tests/image/keypoint_detection/test_model.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b5c9ec4dd5..d8d390f350 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667)) +- Added support for flash zero with the `InstanceSegmentation` and `KeypointDetector` tasks ([#672](https://github.com/PyTorchLightning/lightning-flash/pull/672)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst index 75408dc3fa..db864ad2bc 100644 --- a/docs/source/reference/instance_segmentation.rst +++ b/docs/source/reference/instance_segmentation.rst @@ -29,3 +29,22 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/instance_segmentation.py :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The instance segmentation task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash instance_segmentation + +To view configuration options and options for running the instance segmentation task with your own data, use: + +.. code-block:: bash + + flash instance_segmentation --help diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst index 76fd0dcdf5..2cc0fbef40 100644 --- a/docs/source/reference/keypoint_detection.rst +++ b/docs/source/reference/keypoint_detection.rst @@ -29,3 +29,22 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/keypoint_detection.py :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The keypoint detector can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash keypoint_detection + +To view configuration options and options for running the keypoint detector with your own data, use: + +.. code-block:: bash + + flash keypoint_detection --help diff --git a/flash/__main__.py b/flash/__main__.py index d967149d56..fba73c4fac 100644 --- a/flash/__main__.py +++ b/flash/__main__.py @@ -44,6 +44,8 @@ def wrapper(cli_args): "flash.graph.classification", "flash.image.classification", "flash.image.detection", + "flash.image.instance_segmentation", + "flash.image.keypoint_detection", "flash.image.segmentation", "flash.image.style_transfer", "flash.pointcloud.detection", diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 015c432c57..0c48ff6014 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool: _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") _DATASETS_AVAILABLE = _module_available("datasets") _ICEVISION_AVAILABLE = _module_available("icevision") +_ICEDATA_AVAILABLE = _module_available("icedata") _TORCH_ORT_AVAILABLE = _module_available("torch_ort") if Version: @@ -120,6 +121,7 @@ def _compare_version(package: str, op, version) -> bool: _PYSTICHE_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, _ICEVISION_AVAILABLE, + _ICEDATA_AVAILABLE, ] ) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE diff --git a/flash/image/instance_segmentation/cli.py b/flash/image/instance_segmentation/cli.py new file mode 100644 index 0000000000..3b0842c436 --- /dev/null +++ b/flash/image/instance_segmentation/cli.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Callable, Optional + +from flash.core.utilities.flash_cli import FlashCLI +from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras +from flash.image import InstanceSegmentation, InstanceSegmentationData + +if _ICEDATA_AVAILABLE: + import icedata + +__all__ = ["instance_segmentation"] + + +@requires_extras("image") +def from_pets( + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + parser: Optional[Callable] = None, + **preprocess_kwargs, +) -> InstanceSegmentationData: + """Downloads and loads the pets data set from icedata.""" + data_dir = icedata.pets.load_data() + + if parser is None: + parser = partial(icedata.pets.parser, mask=True) + + return InstanceSegmentationData.from_folders( + train_folder=data_dir, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + parser=parser, + **preprocess_kwargs, + ) + + +def instance_segmentation(): + """Segment object instances in images.""" + cli = FlashCLI( + InstanceSegmentation, + InstanceSegmentationData, + default_datamodule_builder=from_pets, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("instance_segmentation_model.pt") + + +if __name__ == "__main__": + instance_segmentation() diff --git a/flash/image/keypoint_detection/cli.py b/flash/image/keypoint_detection/cli.py new file mode 100644 index 0000000000..b97345679e --- /dev/null +++ b/flash/image/keypoint_detection/cli.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional + +from flash.core.utilities.flash_cli import FlashCLI +from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras +from flash.image import KeypointDetectionData, KeypointDetector + +if _ICEDATA_AVAILABLE: + import icedata + +__all__ = ["keypoint_detection"] + + +@requires_extras("image") +def from_biwi( + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + parser: Optional[Callable] = None, + **preprocess_kwargs, +) -> KeypointDetectionData: + """Downloads and loads the BIWI data set from icedata.""" + data_dir = icedata.biwi.load_data() + + if parser is None: + parser = icedata.biwi.parser + + return KeypointDetectionData.from_folders( + train_folder=data_dir, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + parser=parser, + **preprocess_kwargs, + ) + + +def keypoint_detection(): + """Detect keypoints in images.""" + cli = FlashCLI( + KeypointDetector, + KeypointDetectionData, + default_datamodule_builder=from_biwi, + default_arguments={ + "model.num_keypoints": 1, + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("keypoint_detection_model.pt") + + +if __name__ == "__main__": + keypoint_detection() diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py index 16e5699d14..3fdc4e8a4b 100644 --- a/flash_examples/instance_segmentation.py +++ b/flash_examples/instance_segmentation.py @@ -27,7 +27,6 @@ datamodule = InstanceSegmentationData.from_folders( train_folder=data_dir, val_split=0.1, - image_size=128, parser=partial(icedata.pets.parser, mask=True), ) diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py index 731f0a8125..b1fa29cc02 100644 --- a/flash_examples/keypoint_detection.py +++ b/flash_examples/keypoint_detection.py @@ -25,7 +25,6 @@ datamodule = KeypointDetectionData.from_folders( train_folder=data_dir, val_split=0.1, - image_size=128, parser=icedata.biwi.parser, ) @@ -52,4 +51,4 @@ print(predictions) # 5. Save the model! -trainer.save_checkpoint("object_detection_model.pt") +trainer.save_checkpoint("keypoint_detection_model.pt") diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index f3ed0dc445..f5fd1fba85 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -124,8 +124,7 @@ def test_load_from_checkpoint_dependency_error(): ObjectDetector.load_from_checkpoint("not_a_real_checkpoint.pt") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_cli(): cli_args = ["flash", "object_detection", "--trainer.fast_dev_run", "True"] with mock.patch("sys.argv", cli_args): diff --git a/tests/image/instance_segmentation/__init__.py b/tests/image/instance_segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/image/instance_segmentation/test_model.py b/tests/image/instance_segmentation/test_model.py new file mode 100644 index 0000000000..8f54742d24 --- /dev/null +++ b/tests/image/instance_segmentation/test_model.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest + +from flash.__main__ import main +from tests.helpers.utils import _IMAGE_TESTING + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "instance_segmentation", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/keypoint_detection/__init__.py b/tests/image/keypoint_detection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py new file mode 100644 index 0000000000..215ea9a71f --- /dev/null +++ b/tests/image/keypoint_detection/test_model.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest + +from flash.__main__ import main +from tests.helpers.utils import _IMAGE_TESTING + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "keypoint_detection", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass