Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Support for new icevision version (0.11.0) #989

Merged
merged 4 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where using image classification with DDP spawn would trigger an infinite recursion ([#969](https://github.com/PyTorchLightning/lightning-flash/pull/969))

- Fixed a bug where Flash could not be used with IceVision 0.11.0 ([#989](https://github.com/PyTorchLightning/lightning-flash/pull/989))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
62 changes: 47 additions & 15 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import nn

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, requires

if _ICEVISION_AVAILABLE:
from icevision.core import tasks
Expand All @@ -31,12 +31,17 @@
ImageRecordComponent,
InstancesLabelsRecordComponent,
KeyPointsRecordComponent,
MasksRecordComponent,
RecordIDRecordComponent,
)
from icevision.data.prediction import Prediction
from icevision.tfms import A

if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0:
from icevision.core.mask import MaskFile
from icevision.core.record_components import InstanceMasksRecordComponent
elif _ICEVISION_AVAILABLE:
from icevision.core.record_components import MasksRecordComponent


def to_icevision_record(sample: Dict[str, Any]):
record = BaseRecord([])
Expand Down Expand Up @@ -65,11 +70,28 @@ def to_icevision_record(sample: Dict[str, Any]):
component.set_bboxes(bboxes)
record.add_component(component)

if "masks" in sample[DataKeys.TARGET]:
mask_array = MaskArray(sample[DataKeys.TARGET]["masks"])
component = MasksRecordComponent()
component.set_masks(mask_array)
record.add_component(component)
if _ICEVISION_GREATER_EQUAL_0_11_0:
mask_array = sample[DataKeys.TARGET].get("mask_array", None)
masks = sample[DataKeys.TARGET].get("masks", None)

if mask_array is not None or masks is not None:
component = InstanceMasksRecordComponent()

if masks is not None:
masks = [MaskFile(mask) for mask in masks]
component.set_masks(masks)

if mask_array is not None:
mask_array = MaskArray(mask_array)
component.set_mask_array(mask_array)

record.add_component(component)
else:
mask_array = sample[DataKeys.TARGET].get("mask_array", None)
if mask_array is not None:
component = MasksRecordComponent()
component.set_masks(mask_array)
record.add_component(component)

if "keypoints" in sample[DataKeys.TARGET]:
keypoints = []
Expand Down Expand Up @@ -118,16 +140,26 @@ def from_icevision_detection(record: "BaseRecord"):
for bbox in detection.bboxes
]

if hasattr(detection, "masks"):
masks = detection.masks

if isinstance(masks, EncodedRLEs):
masks = masks.to_mask(record.height, record.width)
mask_array = (
getattr(detection, "mask_array", None) if _ICEVISION_GREATER_EQUAL_0_11_0 else getattr(detection, "masks", None)
)
if mask_array is not None:
if isinstance(mask_array, EncodedRLEs):
mask_array = mask_array.to_mask(record.height, record.width)

if isinstance(masks, MaskArray):
result["masks"] = masks.data
if isinstance(mask_array, MaskArray):
result["mask_array"] = mask_array.data
else:
raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.")
raise RuntimeError("Mask arrays are expected to be a MaskArray or EncodedRLEs.")

masks = getattr(detection, "masks", None)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
if masks is not None and _ICEVISION_GREATER_EQUAL_0_11_0:
result["masks"] = []
for mask in masks:
if isinstance(mask, MaskFile):
result["masks"].append(mask.filepath)
else:
raise RuntimeError("Masks are expected to be MaskFile objects.")

if hasattr(detection, "keypoints"):
keypoints = detection.keypoints
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class Image:
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")
_PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0")

_TEXT_AVAILABLE = all(
[
Expand Down