Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix inference for instance segmentation (#857)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored Oct 13, 2021
1 parent 9e618ed commit e0fbf5c
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 5 deletions.
2 changes: 2 additions & 0 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord):
# load the data via IceVision Base Record
return self.load_sample(sample)
# load the data using numpy
filepath = sample[DefaultDataKeys.INPUT]
sample = super().load_sample(sample)
image = np.array(sample[DefaultDataKeys.INPUT])
Expand Down
11 changes: 9 additions & 2 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Postprocess, Preprocess
from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource
from flash.core.integrations.icevision.transforms import default_transforms
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
Expand Down Expand Up @@ -70,9 +70,16 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
return default_transforms(self.image_size)


class InstanceSegmentationPostProcess(Postprocess):
@staticmethod
def uncollate(batch: Any) -> Any:
return batch[DefaultDataKeys.PREDS]


class InstanceSegmentationData(DataModule):

preprocess_cls = InstanceSegmentationPreprocess
postprocess_cls = InstanceSegmentationPostProcess

@classmethod
def from_coco(
Expand Down
16 changes: 16 additions & 0 deletions flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
from typing import Any, Dict, List, Mapping, Optional, Type, Union

import torch
from pytorch_lightning.utilities import rank_zero_info
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.adapter import AdapterTask
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.process import Serializer
from flash.core.data.serialization import Preds
from flash.core.registry import FlashRegistry
from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS
from flash.image.instance_segmentation.data import InstanceSegmentationPostProcess, InstanceSegmentationPreprocess


class InstanceSegmentation(AdapterTask):
Expand Down Expand Up @@ -94,3 +97,16 @@ def __init__(
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
"""This function is used only for debugging usage with CI."""
# todo

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
# todo: currently the data pipeline for icevision is not serializable, so we re-create the pipeline.
if "data_pipeline" not in checkpoint:
rank_zero_info(
"Assigned Segmentation Data Pipeline for data processing. This is because a data-pipeline stored in "
"the model due to pickling issues. "
"If you'd like to change this, extend the InstanceSegmentation Task and override `on_load_checkpoint`."
)
self.data_pipeline = DataPipeline(
preprocess=InstanceSegmentationPreprocess(), postprocess=InstanceSegmentationPostProcess()
)
4 changes: 2 additions & 2 deletions flash_examples/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/english_cocker_spaniel_1.jpg"),
str(data_dir / "images/scottish_terrier_1.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
]
)
print(predictions)
Expand Down
63 changes: 62 additions & 1 deletion tests/image/instance_segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from pathlib import Path

import flash
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE
from flash.image import InstanceSegmentation, InstanceSegmentationData

if _ICEDATA_AVAILABLE:
import icedata
if _ICEVISION_AVAILABLE:
import icevision

from unittest import mock

import pytest

from flash.__main__ import main
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
Expand All @@ -28,3 +40,52 @@ def test_cli():
main()
except SystemExit:
pass


# todo: this test takes around 25s because of the icedata download, can we speed it up?
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing")
def test_instance_segmentation_inference(tmpdir):
"""Test to ensure that inference runs with instance segmentation from input paths."""
# modify the root path to use 'data' where our CI caches datasets
icevision.utils.data_dir.data_dir = Path("data/icevision/")
icevision.utils.data_dir.data_dir.mkdir(exist_ok=True, parents=True)
data_dir = icedata.pets.load_data()

datamodule = InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=0.1,
parser=partial(icedata.pets.parser, mask=True),
)

model = InstanceSegmentation(
head="mask_rcnn",
backbone="resnet18_fpn",
num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, fast_dev_run=True)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
]
)
assert len(predictions) == 3

model_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(model_path)
InstanceSegmentation.load_from_checkpoint(model_path)

predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_15.jpg"),
]
)
assert len(predictions) == 3

0 comments on commit e0fbf5c

Please sign in to comment.