Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Refactor: use PIL as default image format #1319

Merged
merged 10 commits into from
Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
from typing import Any, Dict, Mapping, Sequence, Union

import torch
from torch import nn
from torch import nn, Tensor

from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.collate import default_collate
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision.transforms.functional import to_tensor


class ApplyToKeys(nn.Sequential):
Expand All @@ -37,8 +42,10 @@ def __init__(self, keys: Union[str, Sequence[str]], *args):
self.keys = keys

def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
if not isinstance(x[DataKeys.INPUT], Tensor):
x[DataKeys.INPUT] = to_tensor(x[DataKeys.INPUT])
keys = list(filter(lambda key: key in x, self.keys))
inputs = [x[key] for key in keys]
inputs = [x[key].float() for key in keys]

result = {}
result.update(x)
Expand All @@ -48,11 +55,10 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
elif len(inputs) > 1:
try:
outputs = super().forward(inputs)
except TypeError as e:
except TypeError as ex:
raise Exception(
"Failed to apply transforms to multiple keys at the same time,"
" try using KorniaParallelTransforms."
) from e
"Failed to apply transforms to multiple keys at the same time, try using KorniaParallelTransforms."
) from ex

for i, key in enumerate(keys):
result[key] = outputs[i]
Expand Down
10 changes: 6 additions & 4 deletions flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@
from flash.core.utilities.stages import RunningStage

if _PIL_AVAILABLE:
from PIL import Image as PILImage
from PIL.Image import Image
else:
Image = object
PILImage, Image = None, object

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader
else:
default_loader = None

_STAGES_PREFIX = {
RunningStage.TRAINING: "train",
Expand Down Expand Up @@ -70,8 +73,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
# __author__ = "github.com/ruxi"
# __license__ = "MIT"

Examples
________
Examples:

.. doctest::

Expand Down Expand Up @@ -162,7 +164,7 @@ def image_default_loader(file_path: str, drop_alpha: bool = True) -> Image:
file_path: The image file to load.
drop_alpha: If ``True`` (default) then any alpha channels will be silently removed.
"""
img = default_loader(file_path)
img = default_loader(file_path) if _TORCHVISION_AVAILABLE else PILImage.open(file_path)
if img.mode == "RGBA" and drop_alpha:
img = img.convert("RGB")
return img
2 changes: 2 additions & 0 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def from_icevision_record(record: "BaseRecord"):
sample = {
DataKeys.METADATA: {
"size": (record.height, record.width),
"height": record.height,
"width": record.width,
}
}

Expand Down
9 changes: 8 additions & 1 deletion flash/image/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def image_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = image_default_loader(filepath)
elif has_file_allowed_extension(filepath, NP_EXTENSIONS):
# Todo: reconsider if we may allow also image as float
img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB")
else:
raise ValueError(
Expand Down Expand Up @@ -71,7 +72,13 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
w, h = sample[DataKeys.INPUT].size # W x H
if DataKeys.METADATA not in sample:
sample[DataKeys.METADATA] = {}
sample[DataKeys.METADATA]["size"] = (h, w)
sample[DataKeys.METADATA].update(
{
"size": (h, w),
"height": h,
"width": w,
}
)
return sample


Expand Down
15 changes: 9 additions & 6 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ def load_labels_map(
self.labels_map = labels_map

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
sample[DataKeys.INPUT] = sample[DataKeys.INPUT].float()
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = sample[DataKeys.TARGET].float()
sample[DataKeys.METADATA] = {"size": sample[DataKeys.INPUT].shape[-2:]}
sample[DataKeys.METADATA] = {
"size": sample[DataKeys.INPUT].size,
"height": sample[DataKeys.INPUT].height,
"width": sample[DataKeys.INPUT].width,
}
return sample


Expand Down Expand Up @@ -104,9 +105,11 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
filepath = sample[DataKeys.INPUT]
sample[DataKeys.INPUT] = to_tensor(image_loader(filepath))
sample[DataKeys.INPUT] = image_loader(filepath)
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = (to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0]
im_segm = image_loader(sample[DataKeys.TARGET])
sample[DataKeys.TARGET] = (to_tensor(im_segm) * 255).long()[0]
assert sample[DataKeys.INPUT].size[::-1] == sample[DataKeys.TARGET].size()
sample = super().load_sample(sample)
sample[DataKeys.METADATA]["filepath"] = filepath
return sample
Expand Down