Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support for new icevision version (0.11.0) (#989)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 22, 2021
1 parent a0c97a3 commit f48fe2f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where using image classification with DDP spawn would trigger an infinite recursion ([#969](https://github.com/PyTorchLightning/lightning-flash/pull/969))

- Fixed a bug where Flash could not be used with IceVision 0.11.0 ([#989](https://github.com/PyTorchLightning/lightning-flash/pull/989))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
62 changes: 47 additions & 15 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import nn

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, requires

if _ICEVISION_AVAILABLE:
from icevision.core import tasks
Expand All @@ -31,12 +31,17 @@
ImageRecordComponent,
InstancesLabelsRecordComponent,
KeyPointsRecordComponent,
MasksRecordComponent,
RecordIDRecordComponent,
)
from icevision.data.prediction import Prediction
from icevision.tfms import A

if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0:
from icevision.core.mask import MaskFile
from icevision.core.record_components import InstanceMasksRecordComponent
elif _ICEVISION_AVAILABLE:
from icevision.core.record_components import MasksRecordComponent


def to_icevision_record(sample: Dict[str, Any]):
record = BaseRecord([])
Expand Down Expand Up @@ -65,11 +70,28 @@ def to_icevision_record(sample: Dict[str, Any]):
component.set_bboxes(bboxes)
record.add_component(component)

if "masks" in sample[DataKeys.TARGET]:
mask_array = MaskArray(sample[DataKeys.TARGET]["masks"])
component = MasksRecordComponent()
component.set_masks(mask_array)
record.add_component(component)
if _ICEVISION_GREATER_EQUAL_0_11_0:
mask_array = sample[DataKeys.TARGET].get("mask_array", None)
masks = sample[DataKeys.TARGET].get("masks", None)

if mask_array is not None or masks is not None:
component = InstanceMasksRecordComponent()

if masks is not None:
masks = [MaskFile(mask) for mask in masks]
component.set_masks(masks)

if mask_array is not None:
mask_array = MaskArray(mask_array)
component.set_mask_array(mask_array)

record.add_component(component)
else:
mask_array = sample[DataKeys.TARGET].get("mask_array", None)
if mask_array is not None:
component = MasksRecordComponent()
component.set_masks(mask_array)
record.add_component(component)

if "keypoints" in sample[DataKeys.TARGET]:
keypoints = []
Expand Down Expand Up @@ -118,16 +140,26 @@ def from_icevision_detection(record: "BaseRecord"):
for bbox in detection.bboxes
]

if hasattr(detection, "masks"):
masks = detection.masks

if isinstance(masks, EncodedRLEs):
masks = masks.to_mask(record.height, record.width)
mask_array = (
getattr(detection, "mask_array", None) if _ICEVISION_GREATER_EQUAL_0_11_0 else getattr(detection, "masks", None)
)
if mask_array is not None:
if isinstance(mask_array, EncodedRLEs):
mask_array = mask_array.to_mask(record.height, record.width)

if isinstance(masks, MaskArray):
result["masks"] = masks.data
if isinstance(mask_array, MaskArray):
result["mask_array"] = mask_array.data
else:
raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.")
raise RuntimeError("Mask arrays are expected to be a MaskArray or EncodedRLEs.")

masks = getattr(detection, "masks", None)
if masks is not None and _ICEVISION_GREATER_EQUAL_0_11_0:
result["masks"] = []
for mask in masks:
if isinstance(mask, MaskFile):
result["masks"].append(mask.filepath)
else:
raise RuntimeError("Masks are expected to be MaskFile objects.")

if hasattr(detection, "keypoints"):
keypoints = detection.keypoints
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class Image:
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")
_PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0")

_TEXT_AVAILABLE = all(
[
Expand Down

0 comments on commit f48fe2f

Please sign in to comment.