Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add prediction unwrapping and serialization for icevision tasks #727

Merged
merged 9 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `QuestionAnswering` task for extractive question answering ([#607](https://github.com/PyTorchLightning/lightning-flash/pull/607))

- Added automatic unwrapping of IceVision prediction objects ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))

- Added support for the `ObjectDetector` with FiftyOne ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ ________________
detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneDataSource
detection.data.ObjectDetectionPreprocess
detection.serialization.DetectionLabels
detection.serialization.FiftyOneDetectionLabels

Keypoint Detection
Expand Down
6 changes: 3 additions & 3 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,13 @@ def default_uncollate(batch: Any):
return batch
return list(torch.unbind(batch, 0))

if isinstance(batch, Mapping):
if isinstance(batch, dict):
return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())]

if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple
return [batch_type(*default_uncollate(sample)) for sample in zip(*batch)]
return [batch_type(*sample) for sample in zip(*batch)]

if isinstance(batch, Sequence) and not isinstance(batch, str):
return [default_uncollate(sample) for sample in batch]
return [sample for sample in batch]

return batch
24 changes: 24 additions & 0 deletions flash/core/data/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import Serializer


class Preds(Serializer):
    """A :class:`~flash.core.data.process.Serializer` which returns the "preds" from the model outputs."""

    def serialize(self, sample: Any) -> Union[int, List[int]]:
        # Dict outputs carry the predictions under ``DefaultDataKeys.PREDS``;
        # anything else (or a dict without that key) is passed through unchanged.
        if isinstance(sample, dict):
            return sample.get(DefaultDataKeys.PREDS, sample)
        return sample
22 changes: 13 additions & 9 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from flash.core.adapter import Adapter
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_source import DefaultDataKeys
from flash.core.integrations.icevision.transforms import to_icevision_record
from flash.core.integrations.icevision.transforms import from_icevision_predictions, to_icevision_record
from flash.core.model import Task
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.url_error import catch_url_error
Expand Down Expand Up @@ -81,9 +81,12 @@ def from_task(
@staticmethod
def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None):
metadata = metadata or [None] * len(samples)
return collate_fn(
[to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)]
)
return {
DefaultDataKeys.INPUT: collate_fn(
[to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)]
),
DefaultDataKeys.METADATA: metadata,
}

def process_train_dataset(
self,
Expand Down Expand Up @@ -178,19 +181,20 @@ def process_predict_dataset(
return data_loader

def training_step(self, batch, batch_idx) -> Any:
return self.icevision_adapter.training_step(batch, batch_idx)
return self.icevision_adapter.training_step(batch[DefaultDataKeys.INPUT], batch_idx)

def validation_step(self, batch, batch_idx):
return self.icevision_adapter.validation_step(batch, batch_idx)
return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx)

def test_step(self, batch, batch_idx):
return self.icevision_adapter.validation_step(batch, batch_idx)
return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)
batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT])
return batch

def forward(self, batch: Any) -> Any:
return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False)
return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False))

def training_epoch_end(self, outputs) -> None:
return self.icevision_adapter.training_epoch_end(outputs)
Expand Down
47 changes: 23 additions & 24 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import numpy as np

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.integrations.icevision.transforms import from_icevision_record
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource

if _ICEVISION_AVAILABLE:
from icevision.core.record import BaseRecord
from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent, tasks
from icevision.core.record_components import ClassMapRecordComponent, FilepathRecordComponent, tasks
from icevision.data.data_splitter import SingleSplitSplitter
from icevision.parsers.parser import Parser

Expand All @@ -36,10 +37,14 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return from_icevision_record(record)

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord):
return self.load_sample(sample)
filepath = sample[DefaultDataKeys.INPUT]
sample = super().load_sample(sample)
image = np.array(sample[DefaultDataKeys.INPUT])
record = BaseRecord([ImageRecordComponent()])

record = BaseRecord([FilepathRecordComponent()])
record.filepath = filepath
record.set_img(image)
record.add_component(ClassMapRecordComponent(task=tasks.detection))
return from_icevision_record(record)
Expand All @@ -51,29 +56,23 @@ def __init__(self, parser: Optional[Type["Parser"]] = None):
self.parser = parser

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root, ann_file = data

if self.parser is not None:
parser = self.parser(ann_file, root)
dataset.num_classes = len(parser.class_map)
if inspect.isclass(self.parser) and issubclass(self.parser, Parser):
root, ann_file = data
parser = self.parser(ann_file, root)
elif isinstance(self.parser, Callable):
parser = self.parser(data)
else:
raise ValueError("The parser must be a callable or an IceVision Parser type.")
dataset.num_classes = parser.class_map.num_classes
self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)]))
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
else:
raise ValueError("The parser type must be provided")


class IceDataParserDataSource(IceVisionPathsDataSource):
def __init__(self, parser: Optional[Callable] = None):
super().__init__()
self.parser = parser
raise ValueError("The parser argument must be provided.")

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root = data

if self.parser is not None:
parser = self.parser(root)
dataset.num_classes = len(parser.class_map)
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
else:
raise ValueError("The parser must be provided")
def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
result = super().predict_load_data(data, dataset)
if len(result) == 0:
result = self.load_data(data, dataset)
return result
98 changes: 61 additions & 37 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, List, Tuple

from torch import nn

Expand All @@ -34,6 +34,7 @@
MasksRecordComponent,
RecordIDRecordComponent,
)
from icevision.data.prediction import Prediction
from icevision.tfms import A


Expand Down Expand Up @@ -101,51 +102,38 @@ def to_icevision_record(sample: Dict[str, Any]):
return record


def from_icevision_record(record: "BaseRecord"):
sample = {
DefaultDataKeys.METADATA: {
"image_id": record.record_id,
}
}
def from_icevision_detection(record: "BaseRecord"):
detection = record.detection

if record.img is not None:
sample[DefaultDataKeys.INPUT] = record.img
filepath = getattr(record, "filepath", None)
if filepath is not None:
sample[DefaultDataKeys.METADATA]["filepath"] = filepath
elif record.filepath is not None:
sample[DefaultDataKeys.INPUT] = record.filepath
result = {}

sample[DefaultDataKeys.TARGET] = {}

if hasattr(record.detection, "bboxes"):
sample[DefaultDataKeys.TARGET]["bboxes"] = []
for bbox in record.detection.bboxes:
bbox_list = list(bbox.xywh)
bbox_dict = {
"xmin": bbox_list[0],
"ymin": bbox_list[1],
"width": bbox_list[2],
"height": bbox_list[3],
if hasattr(detection, "bboxes"):
result["bboxes"] = [
{
"xmin": bbox.xmin,
"ymin": bbox.ymin,
"width": bbox.width,
"height": bbox.height,
}
sample[DefaultDataKeys.TARGET]["bboxes"].append(bbox_dict)
for bbox in detection.bboxes
]

if hasattr(record.detection, "masks"):
masks = record.detection.masks
if hasattr(detection, "masks"):
masks = detection.masks

if isinstance(masks, EncodedRLEs):
masks = masks.to_mask(record.height, record.width)

if isinstance(masks, MaskArray):
sample[DefaultDataKeys.TARGET]["masks"] = masks.data
result["masks"] = masks.data
else:
raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.")

if hasattr(record.detection, "keypoints"):
keypoints = record.detection.keypoints
if hasattr(detection, "keypoints"):
keypoints = detection.keypoints

sample[DefaultDataKeys.TARGET]["keypoints"] = []
sample[DefaultDataKeys.TARGET]["keypoints_metadata"] = []
result["keypoints"] = []
result["keypoints_metadata"] = []

for keypoint in keypoints:
keypoints_list = []
Expand All @@ -157,20 +145,56 @@ def from_icevision_record(record: "BaseRecord"):
"visible": v,
}
)
sample[DefaultDataKeys.TARGET]["keypoints"].append(keypoints_list)
result["keypoints"].append(keypoints_list)

# TODO: Unpack keypoints_metadata
sample[DefaultDataKeys.TARGET]["keypoints_metadata"].append(keypoint.metadata)
result["keypoints_metadata"].append(keypoint.metadata)

if getattr(detection, "label_ids", None) is not None:
result["labels"] = list(detection.label_ids)

if getattr(record.detection, "label_ids", None) is not None:
sample[DefaultDataKeys.TARGET]["labels"] = list(record.detection.label_ids)
if getattr(detection, "scores", None) is not None:
result["scores"] = list(detection.scores)

return result


def from_icevision_record(record: "BaseRecord"):
    """Convert an IceVision ``BaseRecord`` into a Flash sample dict.

    The returned dict contains:
      - ``DefaultDataKeys.INPUT``: the in-memory image if loaded, otherwise the filepath.
      - ``DefaultDataKeys.TARGET``: the detection fields extracted by ``from_icevision_detection``.
      - ``DefaultDataKeys.METADATA``: the image ``size`` plus, when present on the record,
        ``image_id``, ``filepath``, and the detection ``class_map``.

    Args:
        record: The IceVision record to convert.

    Returns:
        The Flash sample dict.
    """
    sample = {
        DefaultDataKeys.METADATA: {
            "size": (record.height, record.width),
        }
    }

    if getattr(record, "record_id", None) is not None:
        sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id

    if getattr(record, "filepath", None) is not None:
        sample[DefaultDataKeys.METADATA]["filepath"] = record.filepath

    if record.img is not None:
        sample[DefaultDataKeys.INPUT] = record.img
    # Use ``getattr`` since not all records include a filepath component; the
    # metadata "filepath" entry was already populated above when available.
    elif getattr(record, "filepath", None) is not None:
        sample[DefaultDataKeys.INPUT] = record.filepath

    sample[DefaultDataKeys.TARGET] = from_icevision_detection(record)

    if getattr(record.detection, "class_map", None) is not None:
        sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map

    return sample


def from_icevision_predictions(predictions: List["Prediction"]):
    """Convert a list of IceVision ``Prediction`` objects into detection dicts.

    Args:
        predictions: The IceVision predictions to convert.

    Returns:
        A list with one detection dict (as produced by ``from_icevision_detection``)
        per prediction, in the same order.
    """
    return [from_icevision_detection(prediction.pred) for prediction in predictions]


class IceVisionTransformAdapter(nn.Module):
def __init__(self, transform):
super().__init__()
Expand Down
8 changes: 2 additions & 6 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource
from flash.core.data.process import Preprocess
from flash.core.integrations.icevision.data import (
IceDataParserDataSource,
IceVisionParserDataSource,
IceVisionPathsDataSource,
)
from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource
from flash.core.integrations.icevision.transforms import default_transforms
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires

Expand Down Expand Up @@ -160,7 +156,7 @@ def __init__(
"via": IceVisionParserDataSource(parser=VIABBoxParser),
"voc": IceVisionParserDataSource(parser=VOCBBoxParser),
DefaultDataSources.FILES: IceVisionPathsDataSource(),
DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser),
DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser),
DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs),
},
default_data_source=DefaultDataSources.FILES,
Expand Down
5 changes: 3 additions & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from flash.core.adapter import AdapterTask
from flash.core.data.process import Serializer
from flash.core.data.serialization import Preds
from flash.core.registry import FlashRegistry
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS

Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(
head: Optional[str] = "retinanet",
pretrained: bool = True,
optimizer: Type[Optimizer] = torch.optim.Adam,
learning_rate: float = 5e-4,
learning_rate: float = 5e-3,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs: Any,
):
Expand All @@ -77,7 +78,7 @@ def __init__(
adapter,
learning_rate=learning_rate,
optimizer=optimizer,
serializer=serializer,
serializer=serializer or Preds(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
Loading