Refactor: use PIL as default image format (#1319)
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
3 people authored Sep 5, 2022
1 parent f26e3df commit 95ae65f
Showing 8 changed files with 58 additions and 48 deletions.
1 change: 0 additions & 1 deletion docs/source/api/image.rst
@@ -91,7 +91,6 @@ ____________
     segmentation.input.SemanticSegmentationNumpyInput
     segmentation.input.SemanticSegmentationTensorInput
     segmentation.input.SemanticSegmentationFiftyOneInput
-    segmentation.input.SemanticSegmentationDeserializer
     segmentation.model.SemanticSegmentationOutputTransform
     segmentation.output.FiftyOneSegmentationLabelsOutput
     segmentation.output.SegmentationLabelsOutput
6 changes: 3 additions & 3 deletions docs/source/template/data.rst
@@ -174,12 +174,12 @@ In your :class:`~flash.core.data.io.input.Input` or :class:`~flash.core.data.io.
 Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms.
 You should use this approach if your postprocessing depends on the state of the input before the :class:`~flash.core.data.io.input_transform.InputTransform` transforms are applied.
 For example, if you want to resize the predictions to the original size of the inputs, you should add the original image size to the :attr:`~flash.core.data.io.input.DataKeys.METADATA`.
-Here's an example from the :class:`~flash.image.segmentation.SemanticSegmentationNumpyInput`:
+Here's an example from the :class:`~flash.image.data.ImageInput`:

-.. literalinclude:: ../../../flash/image/segmentation/input.py
+.. literalinclude:: ../../../flash/image/data.py
     :language: python
     :dedent: 4
-    :pyobject: SemanticSegmentationNumpyInput.load_sample
+    :pyobject: ImageInput.load_sample

 The :attr:`~flash.core.data.io.input.DataKeys.METADATA` can now be referenced in your :class:`~flash.core.data.io.output_transform.OutputTransform`.
 For example, here's the code for the ``per_sample_transform`` method of the :class:`~flash.image.segmentation.model.SemanticSegmentationOutputTransform`:
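For context, here is a minimal, hypothetical sketch of the pattern this doc passage describes: an `OutputTransform` that reads the recorded `"size"` metadata and resizes predictions back to the original input resolution. The class name and the `F.interpolate` call are illustrative assumptions, not the library's actual implementation:

```python
import torch.nn.functional as F

from flash.core.data.io.input import DataKeys
from flash.core.data.io.output_transform import OutputTransform


class ResizeToOriginalOutputTransform(OutputTransform):
    """Hypothetical: restore predictions to the size recorded by ``load_sample``."""

    def per_sample_transform(self, sample):
        # ``ImageInput.load_sample`` stored the original (height, width) here.
        h, w = sample[DataKeys.METADATA]["size"]
        preds = sample[DataKeys.PREDS]  # assumed shape: (num_classes, H', W')
        # ``interpolate`` expects a batch dimension, so add one and strip it again.
        sample[DataKeys.PREDS] = F.interpolate(preds[None].float(), size=(h, w), mode="bilinear")[0]
        return sample
```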
3 changes: 1 addition & 2 deletions flash/core/data/transforms.py
@@ -50,8 +50,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
             outputs = super().forward(inputs)
         except TypeError as e:
             raise Exception(
-                "Failed to apply transforms to multiple keys at the same time,"
-                " try using KorniaParallelTransforms."
+                "Failed to apply transforms to multiple keys at the same time, try using KorniaParallelTransforms."
             ) from e

         for i, key in enumerate(keys):
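The error message points users at `KorniaParallelTransforms`, which applies the same transform instance, with the same random parameters, to several keys of a sample. A hedged usage sketch mirroring the pattern used in `SemanticSegmentationInputTransform` further down in this commit (shapes and the crop size are illustrative):

```python
import torch
import kornia as K

from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys, KorniaParallelTransforms

# One random crop, applied with identical parameters to image and mask.
transform = ApplyToKeys(
    [DataKeys.INPUT, DataKeys.TARGET],
    KorniaParallelTransforms(K.augmentation.RandomCrop((128, 128))),
)

sample = {
    DataKeys.INPUT: torch.rand(3, 256, 256),
    DataKeys.TARGET: torch.randint(0, 21, (1, 256, 256)).float(),
}
sample = transform(sample)  # both keys are cropped to the same window
```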
12 changes: 11 additions & 1 deletion flash/image/data.py
@@ -36,8 +36,12 @@ def serve_load_sample(self, data: str) -> Dict:
         img = base64.b64decode(encoded_with_padding)
         buffer = BytesIO(img)
         img = Image.open(buffer, mode="r")
+        w, h = img.size
         return {
             DataKeys.INPUT: img,
+            DataKeys.METADATA: {
+                "size": (h, w),
+            },
         }

     @property
@@ -52,7 +56,13 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         w, h = sample[DataKeys.INPUT].size  # W x H
         if DataKeys.METADATA not in sample:
             sample[DataKeys.METADATA] = {}
-        sample[DataKeys.METADATA]["size"] = (h, w)
+        sample[DataKeys.METADATA].update(
+            {
+                "size": (h, w),
+                "height": h,
+                "width": w,
+            }
+        )
         return sample
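Taken together, `load_sample` now leaves the PIL image untouched and only records its dimensions. A quick sketch of the resulting behaviour, assuming an `ImageInput` can be constructed directly from a running stage (the blank image is a stand-in):

```python
from PIL import Image

from flash.core.data.io.input import DataKeys
from flash.core.utilities.stages import RunningStage
from flash.image.data import ImageInput

sample = {DataKeys.INPUT: Image.new("RGB", (640, 480))}  # PIL size is W x H
sample = ImageInput(RunningStage.PREDICTING).load_sample(sample)

# Flash records size as (height, width), plus the individual dimensions.
assert sample[DataKeys.METADATA]["size"] == (480, 640)
assert sample[DataKeys.METADATA]["height"] == 480
assert sample[DataKeys.METADATA]["width"] == 640
```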


51 changes: 14 additions & 37 deletions flash/image/segmentation/input.py
@@ -14,15 +14,15 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union

-import torch
+import numpy as np

 from flash.core.data.io.input import DataKeys, Input
 from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS
 from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE
 from flash.core.data.utilities.samples import to_samples
 from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import
-from flash.image.data import ImageDeserializer
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import
+from flash.image.data import ImageFilesInput, ImageNumpyInput, ImageTensorInput
 from flash.image.segmentation.output import SegmentationLabelsOutput

 if _FIFTYONE_AVAILABLE:
@@ -32,9 +32,6 @@
     fo = None
     SampleCollection = None

-if _TORCHVISION_AVAILABLE:
-    from torchvision.transforms.functional import to_tensor
-

 class SemanticSegmentationInput(Input):
     num_classes: int
@@ -50,15 +47,8 @@ def load_labels_map(
         if labels_map is not None:
             self.labels_map = labels_map

-    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        sample[DataKeys.INPUT] = sample[DataKeys.INPUT].float()
-        if DataKeys.TARGET in sample:
-            sample[DataKeys.TARGET] = sample[DataKeys.TARGET].float()
-        sample[DataKeys.METADATA] = {"size": sample[DataKeys.INPUT].shape[-2:]}
-        return sample
-

-class SemanticSegmentationTensorInput(SemanticSegmentationInput):
+class SemanticSegmentationTensorInput(SemanticSegmentationInput, ImageTensorInput):
     def load_data(
         self,
         tensor: Any,
@@ -69,8 +59,13 @@ def load_data(
         self.load_labels_map(num_classes, labels_map)
         return to_samples(tensor, masks)

+    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET] = sample[DataKeys.TARGET].numpy()
+        return super().load_sample(sample)
+

-class SemanticSegmentationNumpyInput(SemanticSegmentationInput):
+class SemanticSegmentationNumpyInput(SemanticSegmentationInput, ImageNumpyInput):
     def load_data(
         self,
         array: Any,
@@ -81,14 +76,8 @@
         self.load_labels_map(num_classes, labels_map)
         return to_samples(array, masks)

-    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        sample[DataKeys.INPUT] = torch.from_numpy(sample[DataKeys.INPUT])
-        if DataKeys.TARGET in sample:
-            sample[DataKeys.TARGET] = torch.from_numpy(sample[DataKeys.TARGET])
-        return super().load_sample(sample)
-

-class SemanticSegmentationFilesInput(SemanticSegmentationInput):
+class SemanticSegmentationFilesInput(SemanticSegmentationInput, ImageFilesInput):
     def load_data(
         self,
         files: Union[PATH_TYPE, List[PATH_TYPE]],
@@ -104,13 +93,9 @@ def load_data(
         return to_samples(files, mask_files)

     def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        filepath = sample[DataKeys.INPUT]
-        sample[DataKeys.INPUT] = to_tensor(load_image(filepath))
         if DataKeys.TARGET in sample:
-            sample[DataKeys.TARGET] = (to_tensor(load_image(sample[DataKeys.TARGET])) * 255).long()[0]
-        sample = super().load_sample(sample)
-        sample[DataKeys.METADATA]["filepath"] = filepath
-        return sample
+            sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET]))[:, :, 0]
+        return super().load_sample(sample)


 class SemanticSegmentationFolderInput(SemanticSegmentationFilesInput):
@@ -171,13 +156,5 @@ def load_sample(
         if not self.predicting:
             fo_dataset = fo.load_dataset(self._fo_dataset_name)
             fo_sample = fo_dataset[filepath]
-            sample[DataKeys.TARGET] = torch.from_numpy(fo_sample[self.label_field].mask).float()
+            sample[DataKeys.TARGET] = fo_sample[self.label_field].mask
         return sample
-
-
-class SemanticSegmentationDeserializer(ImageDeserializer):
-    def serve_load_sample(self, data: str) -> Dict[str, Any]:
-        result = super().serve_load_sample(data)
-        result[DataKeys.INPUT] = to_tensor(result[DataKeys.INPUT])
-        result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]}
-        return result
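The segmentation inputs now defer tensor conversion entirely: images stay PIL (via the shared `Image*Input` bases) and masks become 2D numpy arrays of class indices, keeping only the first channel of the decoded mask image. A standalone sketch of the equivalent target decoding, with a hypothetical `mask.png` (the explicit `convert("RGB")` is an assumption standing in for whatever `load_image` does):

```python
import numpy as np
from PIL import Image

# Equivalent of the new target loading: decode the mask and keep the first
# channel as an (H, W) array of integer class labels.
mask = np.array(Image.open("mask.png").convert("RGB"))[:, :, 0]
print(mask.shape, mask.dtype)  # e.g. (512, 512) uint8
```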
22 changes: 22 additions & 0 deletions flash/image/segmentation/input_transform.py
@@ -14,6 +14,8 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Tuple, Union

+import torch
+
 from flash.core.data.io.input import DataKeys
 from flash.core.data.io.input_transform import InputTransform
 from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms
@@ -33,6 +35,15 @@ def prepare_target(batch: Dict[str, Any]) -> Dict[str, Any]:
     return batch


+def target_as_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
+    if DataKeys.TARGET in sample:
+        target = sample[DataKeys.TARGET]
+        if target.ndim == 2:
+            target = target[:, :, None]
+        sample[DataKeys.TARGET] = torch.from_numpy(target.transpose((2, 0, 1))).contiguous().squeeze().float()
+    return sample
+
+
 def remove_extra_dimensions(batch: Dict[str, Any]):
     if isinstance(batch[DataKeys.INPUT], list):
         assert len(batch[DataKeys.INPUT]) == 1
@@ -51,6 +62,11 @@ class SemanticSegmentationInputTransform(InputTransform):
     def train_per_sample_transform(self) -> Callable:
         return T.Compose(
             [
+                ApplyToKeys(
+                    DataKeys.INPUT,
+                    T.ToTensor(),
+                ),
+                target_as_tensor,
                 ApplyToKeys(
                     [DataKeys.INPUT, DataKeys.TARGET],
                     KorniaParallelTransforms(
@@ -66,6 +82,11 @@ def train_per_sample_transform(self) -> Callable:
     def per_sample_transform(self) -> Callable:
         return T.Compose(
             [
+                ApplyToKeys(
+                    DataKeys.INPUT,
+                    T.ToTensor(),
+                ),
+                target_as_tensor,
                 ApplyToKeys(
                     [DataKeys.INPUT, DataKeys.TARGET],
                     KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
@@ -78,6 +99,7 @@ def per_sample_transform(self) -> Callable:
     def predict_per_sample_transform(self) -> Callable:
         return ApplyToKeys(
             DataKeys.INPUT,
+            T.ToTensor(),
             K.geometry.Resize(
                 self.image_size,
                 interpolation="nearest",
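`target_as_tensor` plays the role that `T.ToTensor()` plays for the input, but without rescaling to [0, 1], since mask values are class indices rather than intensities. A small sketch of how the two per-sample steps compose on a dummy sample (shapes are assumptions):

```python
import numpy as np
from PIL import Image
from torchvision import transforms as T

from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys
from flash.image.segmentation.input_transform import target_as_tensor

sample = {
    DataKeys.INPUT: Image.new("RGB", (256, 256)),           # PIL image from the input
    DataKeys.TARGET: np.zeros((256, 256), dtype=np.uint8),  # numpy mask of class indices
}

sample = ApplyToKeys(DataKeys.INPUT, T.ToTensor())(sample)  # image -> (3, 256, 256) float in [0, 1]
sample = target_as_tensor(sample)                           # mask  -> (256, 256) float tensor

assert sample[DataKeys.INPUT].shape == (3, 256, 256)
assert sample[DataKeys.TARGET].shape == (256, 256)
```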
4 changes: 2 additions & 2 deletions flash/image/segmentation/model.py
@@ -33,9 +33,9 @@
     OPTIMIZER_TYPE,
     OUTPUT_TRANSFORM_TYPE,
 )
+from flash.image.data import ImageDeserializer
 from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
 from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
-from flash.image.segmentation.input import SemanticSegmentationDeserializer
 from flash.image.segmentation.input_transform import SemanticSegmentationInputTransform
 from flash.image.segmentation.output import SEMANTIC_SEGMENTATION_OUTPUTS

@@ -187,7 +187,7 @@ def serve(
         host: str = "127.0.0.1",
         port: int = 8000,
         sanity_check: bool = True,
-        input_cls: Optional[Type[ServeInput]] = SemanticSegmentationDeserializer,
+        input_cls: Optional[Type[ServeInput]] = ImageDeserializer,
         transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform,
         transform_kwargs: Optional[Dict] = None,
         output: Optional[Union[str, Output]] = None,
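Serving is otherwise unchanged: only the default `input_cls` moves from the removed `SemanticSegmentationDeserializer` to the generic `ImageDeserializer`. A hedged usage sketch (the checkpoint path is a placeholder):

```python
from flash.image import SemanticSegmentation

# Hypothetical checkpoint; serve() now defaults to ImageDeserializer.
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.serve(host="127.0.0.1", port=8000)
```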
7 changes: 5 additions & 2 deletions flash/image/segmentation/viz.py
@@ -14,6 +14,7 @@
 from typing import Any, Dict, List, Tuple, Union

 import numpy as np
+import torch
 from torch import Tensor

 from flash.core.data.base_viz import BaseVisualization
@@ -85,8 +86,10 @@ def _show_images_and_labels(
             else:
                 raise TypeError(f"Unknown data type. Got: {type(data)}.")
             # convert images and labels to numpy and stack horizontally
-            image_vis: np.ndarray = self._to_numpy(image.byte())
-            label_tmp: Tensor = SegmentationLabelsOutput.labels_to_image(label.squeeze().byte(), self.labels_map)
+            image_vis: np.ndarray = self._to_numpy(image)
+            label_tmp: Tensor = SegmentationLabelsOutput.labels_to_image(
+                torch.as_tensor(label).squeeze(), self.labels_map
+            )
             label_vis: np.ndarray = self._to_numpy(label_tmp)
             img_vis = np.hstack((image_vis, label_vis))
             # send to visualiser
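Because labels may now reach the visualizer as numpy arrays rather than tensors, `torch.as_tensor` normalizes them before `labels_to_image`. A one-line illustration:

```python
import numpy as np
import torch

label = np.zeros((1, 96, 96), dtype=np.uint8)  # mask as produced by the new inputs
label_t = torch.as_tensor(label).squeeze()     # (96, 96) tensor, zero-copy when possible
```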
