Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Fix inference for instance segmentation #857

Merged
merged 9 commits into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord):
# load the data via IceVision Base Record
return self.load_sample(sample)
# load the data using numpy
filepath = sample[DefaultDataKeys.INPUT]
sample = super().load_sample(sample)
image = np.array(sample[DefaultDataKeys.INPUT])
Expand Down
11 changes: 9 additions & 2 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Postprocess, Preprocess
from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource
from flash.core.integrations.icevision.transforms import default_transforms
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
Expand Down Expand Up @@ -70,9 +70,16 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return default_transforms(self.image_size)


class InstanceSegmentationPostProcess(Postprocess):
@staticmethod
def uncollate(batch: Any) -> Any:
return batch[DefaultDataKeys.PREDS]


class InstanceSegmentationData(DataModule):

preprocess_cls = InstanceSegmentationPreprocess
postprocess_cls = InstanceSegmentationPostProcess

@classmethod
def from_coco(
Expand Down
16 changes: 16 additions & 0 deletions flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
from typing import Any, Dict, List, Mapping, Optional, Type, Union

import torch
from pytorch_lightning.utilities import rank_zero_info
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.adapter import AdapterTask
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.process import Serializer
from flash.core.data.serialization import Preds
from flash.core.registry import FlashRegistry
from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS
from flash.image.instance_segmentation.data import InstanceSegmentationPostProcess, InstanceSegmentationPreprocess


class InstanceSegmentation(AdapterTask):
Expand Down Expand Up @@ -94,3 +97,16 @@ def __init__(
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
"""This function is used only for debugging usage with CI."""
# todo

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
# todo: currently the data pipeline for icevision is not serializable, so we re-create the pipeline.
if "data_pipeline" not in checkpoint:
rank_zero_info(
"Assigned Segmentation Data Pipeline for data processing. This is because a data-pipeline stored in "
"the model due to pickling issues. "
"If you'd like to change this, extend the InstanceSegmentation Task and override `on_load_checkpoint`."
)
self.data_pipeline = DataPipeline(
preprocess=InstanceSegmentationPreprocess(), postprocess=InstanceSegmentationPostProcess()
)
4 changes: 2 additions & 2 deletions flash_examples/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/english_cocker_spaniel_1.jpg"),
str(data_dir / "images/scottish_terrier_1.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
]
)
print(predictions)
Expand Down
63 changes: 62 additions & 1 deletion tests/image/instance_segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from pathlib import Path

import flash
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE
from flash.image import InstanceSegmentation, InstanceSegmentationData

if _ICEDATA_AVAILABLE:
import icedata
if _ICEVISION_AVAILABLE:
import icevision

from unittest import mock

import pytest

from flash.__main__ import main
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
Expand All @@ -28,3 +40,52 @@ def test_cli():
main()
except SystemExit:
pass


# todo: this test takes around 25s because of the icedata download, can we speed it up?
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing")
def test_instance_segmentation_inference(tmpdir):
"""Test to ensure that inference runs with instance segmentation from input paths."""
# modify the root path to use 'data' where our CI caches datasets
icevision.utils.data_dir.data_dir = Path("data/icevision/")
icevision.utils.data_dir.data_dir.mkdir(exist_ok=True, parents=True)
data_dir = icedata.pets.load_data()

datamodule = InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=0.1,
parser=partial(icedata.pets.parser, mask=True),
)

model = InstanceSegmentation(
head="mask_rcnn",
backbone="resnet18_fpn",
num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, fast_dev_run=True)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
]
)
assert len(predictions) == 3

model_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(model_path)
InstanceSegmentation.load_from_checkpoint(model_path)

predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_15.jpg"),
]
)
assert len(predictions) == 3