Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add prediction unwrapping and serialization for icevision tasks #727

Merged
merged 9 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `QuestionAnswering` task for extractive question answering ([#607](https://github.com/PyTorchLightning/lightning-flash/pull/607))

- Added automatic unwrapping of IceVision prediction objects ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

- Added support for the `ObjectDetector` with FiftyOne ([#727](https://github.com/PyTorchLightning/lightning-flash/pull/727))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ ________________
detection.data.FiftyOneParser
detection.data.ObjectDetectionFiftyOneDataSource
detection.data.ObjectDetectionPreprocess
detection.serialization.DetectionLabels
detection.serialization.FiftyOneDetectionLabels

Keypoint Detection
Expand Down
6 changes: 3 additions & 3 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,13 @@ def default_uncollate(batch: Any):
return batch
return list(torch.unbind(batch, 0))

if isinstance(batch, Mapping):
if isinstance(batch, dict):
return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())]

if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple
return [batch_type(*default_uncollate(sample)) for sample in zip(*batch)]
return [batch_type(*sample) for sample in zip(*batch)]

if isinstance(batch, Sequence) and not isinstance(batch, str):
return [default_uncollate(sample) for sample in batch]
return [sample for sample in batch]

return batch
24 changes: 24 additions & 0 deletions flash/core/data/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import Serializer


class Preds(Serializer):
    """A :class:`~flash.core.data.process.Serializer` that extracts the predictions.

    When the sample is a mapping containing ``DefaultDataKeys.PREDS``, that entry is
    returned; any other sample is passed through unchanged.
    """

    def serialize(self, sample: Any) -> Union[int, List[int]]:
        if not isinstance(sample, dict):
            return sample
        return sample.get(DefaultDataKeys.PREDS, sample)
22 changes: 13 additions & 9 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from flash.core.adapter import Adapter
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_source import DefaultDataKeys
from flash.core.integrations.icevision.transforms import to_icevision_record
from flash.core.integrations.icevision.transforms import from_icevision_predictions, to_icevision_record
from flash.core.model import Task
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.url_error import catch_url_error
Expand Down Expand Up @@ -81,9 +81,12 @@ def from_task(
@staticmethod
def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None):
metadata = metadata or [None] * len(samples)
return collate_fn(
[to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)]
)
return {
DefaultDataKeys.INPUT: collate_fn(
[to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)]
),
DefaultDataKeys.METADATA: metadata,
}

def process_train_dataset(
self,
Expand Down Expand Up @@ -178,19 +181,20 @@ def process_predict_dataset(
return data_loader

def training_step(self, batch, batch_idx) -> Any:
return self.icevision_adapter.training_step(batch, batch_idx)
return self.icevision_adapter.training_step(batch[DefaultDataKeys.INPUT], batch_idx)

def validation_step(self, batch, batch_idx):
return self.icevision_adapter.validation_step(batch, batch_idx)
return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx)

def test_step(self, batch, batch_idx):
return self.icevision_adapter.validation_step(batch, batch_idx)
return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)
batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT])
return batch

def forward(self, batch: Any) -> Any:
return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False)
return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False))

def training_epoch_end(self, outputs) -> None:
return self.icevision_adapter.training_epoch_end(outputs)
Expand Down
47 changes: 23 additions & 24 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import numpy as np

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.integrations.icevision.transforms import from_icevision_record
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource

if _ICEVISION_AVAILABLE:
from icevision.core.record import BaseRecord
from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent, tasks
from icevision.core.record_components import ClassMapRecordComponent, FilepathRecordComponent, tasks
from icevision.data.data_splitter import SingleSplitSplitter
from icevision.parsers.parser import Parser

Expand All @@ -36,10 +37,14 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return from_icevision_record(record)

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord):
return self.load_sample(sample)
filepath = sample[DefaultDataKeys.INPUT]
sample = super().load_sample(sample)
image = np.array(sample[DefaultDataKeys.INPUT])
record = BaseRecord([ImageRecordComponent()])

record = BaseRecord([FilepathRecordComponent()])
record.filepath = filepath
record.set_img(image)
record.add_component(ClassMapRecordComponent(task=tasks.detection))
return from_icevision_record(record)
Expand All @@ -51,29 +56,23 @@ def __init__(self, parser: Optional[Type["Parser"]] = None):
self.parser = parser

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root, ann_file = data

if self.parser is not None:
parser = self.parser(ann_file, root)
dataset.num_classes = len(parser.class_map)
if inspect.isclass(self.parser) and issubclass(self.parser, Parser):
root, ann_file = data
parser = self.parser(ann_file, root)
elif isinstance(self.parser, Callable):
parser = self.parser(data)
else:
raise ValueError("The parser must be a callable or an IceVision Parser type.")
dataset.num_classes = parser.class_map.num_classes
self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)]))
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
else:
raise ValueError("The parser type must be provided")


class IceDataParserDataSource(IceVisionPathsDataSource):
def __init__(self, parser: Optional[Callable] = None):
super().__init__()
self.parser = parser
raise ValueError("The parser argument must be provided.")

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
root = data

if self.parser is not None:
parser = self.parser(root)
dataset.num_classes = len(parser.class_map)
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
else:
raise ValueError("The parser must be provided")
def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
    """Load samples for prediction.

    First tries the plain paths-based loading from the parent class; when that
    yields no samples, falls back to the annotated ``load_data`` path instead.
    """
    samples = super().predict_load_data(data, dataset)
    if len(samples) == 0:
        samples = self.load_data(data, dataset)
    return samples
87 changes: 56 additions & 31 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, List, Tuple

from torch import nn

Expand All @@ -34,6 +34,7 @@
MasksRecordComponent,
RecordIDRecordComponent,
)
from icevision.data.prediction import Prediction
from icevision.tfms import A


Expand Down Expand Up @@ -101,51 +102,39 @@ def to_icevision_record(sample: Dict[str, Any]):
return record


def from_icevision_record(record: "BaseRecord"):
sample = {
DefaultDataKeys.METADATA: {
"image_id": record.record_id,
}
}

if record.img is not None:
sample[DefaultDataKeys.INPUT] = record.img
filepath = getattr(record, "filepath", None)
if filepath is not None:
sample[DefaultDataKeys.METADATA]["filepath"] = filepath
elif record.filepath is not None:
sample[DefaultDataKeys.INPUT] = record.filepath
def from_icevision_detection(record: "BaseRecord"):
detection = record.detection

sample[DefaultDataKeys.TARGET] = {}
result = {}

if hasattr(record.detection, "bboxes"):
sample[DefaultDataKeys.TARGET]["bboxes"] = []
for bbox in record.detection.bboxes:
if hasattr(detection, "bboxes"):
result["bboxes"] = []
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
for bbox in detection.bboxes:
bbox_list = list(bbox.xywh)
bbox_dict = {
"xmin": bbox_list[0],
"ymin": bbox_list[1],
"width": bbox_list[2],
"height": bbox_list[3],
}
sample[DefaultDataKeys.TARGET]["bboxes"].append(bbox_dict)
result["bboxes"].append(bbox_dict)

if hasattr(record.detection, "masks"):
masks = record.detection.masks
if hasattr(detection, "masks"):
masks = detection.masks

if isinstance(masks, EncodedRLEs):
masks = masks.to_mask(record.height, record.width)

if isinstance(masks, MaskArray):
sample[DefaultDataKeys.TARGET]["masks"] = masks.data
result["masks"] = masks.data
else:
raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.")

if hasattr(record.detection, "keypoints"):
keypoints = record.detection.keypoints
if hasattr(detection, "keypoints"):
keypoints = detection.keypoints

sample[DefaultDataKeys.TARGET]["keypoints"] = []
sample[DefaultDataKeys.TARGET]["keypoints_metadata"] = []
result["keypoints"] = []
result["keypoints_metadata"] = []

for keypoint in keypoints:
keypoints_list = []
Expand All @@ -157,20 +146,56 @@ def from_icevision_record(record: "BaseRecord"):
"visible": v,
}
)
sample[DefaultDataKeys.TARGET]["keypoints"].append(keypoints_list)
result["keypoints"].append(keypoints_list)

# TODO: Unpack keypoints_metadata
sample[DefaultDataKeys.TARGET]["keypoints_metadata"].append(keypoint.metadata)
result["keypoints_metadata"].append(keypoint.metadata)

if getattr(detection, "label_ids", None) is not None:
result["labels"] = list(detection.label_ids)

if getattr(record.detection, "label_ids", None) is not None:
sample[DefaultDataKeys.TARGET]["labels"] = list(record.detection.label_ids)
if getattr(detection, "scores", None) is not None:
result["scores"] = list(detection.scores)

return result


def from_icevision_record(record: "BaseRecord"):
    """Convert an IceVision ``BaseRecord`` into a Flash sample dictionary.

    The result maps ``DefaultDataKeys.INPUT`` to the in-memory image (or, failing
    that, the filepath), ``DefaultDataKeys.TARGET`` to the unpacked detection, and
    ``DefaultDataKeys.METADATA`` to the size, record id, filepath, and class map
    where available.
    """
    sample = {
        DefaultDataKeys.METADATA: {
            "size": (record.height, record.width),
        }
    }

    if getattr(record, "record_id", None) is not None:
        # Fix: the guard checks ``record_id``, so report that same attribute
        # (the original read ``record.image_id``, which the guard never verified).
        sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id

    # Look the filepath up once with a guard; records without a filepath
    # component must not raise here.
    filepath = getattr(record, "filepath", None)
    if filepath is not None:
        sample[DefaultDataKeys.METADATA]["filepath"] = filepath

    if record.img is not None:
        sample[DefaultDataKeys.INPUT] = record.img
    elif filepath is not None:
        # No in-memory image available: fall back to the filepath as the input.
        sample[DefaultDataKeys.INPUT] = filepath

    sample[DefaultDataKeys.TARGET] = from_icevision_detection(record)

    if getattr(record.detection, "class_map", None) is not None:
        sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map

    return sample


def from_icevision_predictions(predictions: List["Prediction"]):
    """Unpack a list of IceVision ``Prediction`` objects into Flash detection dicts."""
    return [from_icevision_detection(prediction.pred) for prediction in predictions]


class IceVisionTransformAdapter(nn.Module):
def __init__(self, transform):
super().__init__()
Expand Down
8 changes: 2 additions & 6 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource
from flash.core.data.process import Preprocess
from flash.core.integrations.icevision.data import (
IceDataParserDataSource,
IceVisionParserDataSource,
IceVisionPathsDataSource,
)
from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource
from flash.core.integrations.icevision.transforms import default_transforms
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires

Expand Down Expand Up @@ -160,7 +156,7 @@ def __init__(
"via": IceVisionParserDataSource(parser=VIABBoxParser),
"voc": IceVisionParserDataSource(parser=VOCBBoxParser),
DefaultDataSources.FILES: IceVisionPathsDataSource(),
DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser),
DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser),
DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs),
},
default_data_source=DefaultDataSources.FILES,
Expand Down
5 changes: 3 additions & 2 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from flash.core.adapter import AdapterTask
from flash.core.data.process import Serializer
from flash.core.data.serialization import Preds
from flash.core.registry import FlashRegistry
from flash.image.detection.backbones import OBJECT_DETECTION_HEADS

Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(
head: Optional[str] = "retinanet",
pretrained: bool = True,
optimizer: Type[Optimizer] = torch.optim.Adam,
learning_rate: float = 5e-4,
learning_rate: float = 5e-3,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs: Any,
):
Expand All @@ -77,7 +78,7 @@ def __init__(
adapter,
learning_rate=learning_rate,
optimizer=optimizer,
serializer=serializer,
serializer=serializer or Preds(),
)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
Expand Down
Loading