
simplify loading RGBA images #946

Merged
9 commits merged on Nov 8, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `Postprocess` to `OutputTransform` ([#942](https://github.com/PyTorchLightning/lightning-flash/pull/942))

- Changed loading of RGBA images to drop alpha channel by default ([#946](https://github.com/PyTorchLightning/lightning-flash/pull/946))

### Deprecated

- Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))
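For context on the changelog entry above, a minimal sketch of the intended behaviour using the `image_default_loader` helper added in this PR (the file path is hypothetical, and what `drop_alpha=False` yields ultimately depends on how torchvision's `default_loader` decodes the underlying file):

```python
from flash.core.data.utils import image_default_loader

# Hypothetical local file containing an RGBA PNG.
img = image_default_loader("rgba_example.png")
print(img.mode)  # "RGB" -- any alpha channel is dropped by default

# Per the helper's docstring, drop_alpha=False opts out of the RGBA -> RGB conversion.
img_keep = image_default_loader("rgba_example.png", drop_alpha=False)
```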
7 changes: 2 additions & 5 deletions flash/audio/classification/data.py
@@ -25,17 +25,14 @@
PathsDataSource,
)
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.data.utils import image_default_loader
from flash.image.classification.data import ImageClassificationData
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS, NP_EXTENSIONS

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader


def spectrogram_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = default_loader(filepath)
img = image_default_loader(filepath)
data = np.array(img)
else:
data = np.load(filepath)
22 changes: 22 additions & 0 deletions flash/core/data/utils.py
@@ -23,8 +23,17 @@
from torch import Tensor
from tqdm.auto import tqdm as tq

from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _PIL_AVAILABLE:
from PIL.Image import Image
else:
Image = object

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader

_STAGES_PREFIX = {
RunningStage.TRAINING: "train",
RunningStage.TESTING: "test",
@@ -207,3 +216,16 @@ def convert_to_modules(transforms: Optional[Dict[str, Callable]]):
transforms, Iterable, torch.nn.ModuleList, wrong_dtype=(torch.nn.ModuleList, torch.nn.ModuleDict)
)
return transforms


def image_default_loader(file_path: str, drop_alpha: bool = True) -> Image:
"""Default loader for images.

Args:
file_path: The image file to load.
drop_alpha: If ``True`` (default) then any alpha channels will be silently removed.
"""
img = default_loader(file_path)
if img.mode == "RGBA" and drop_alpha:
img = img.convert("RGB")
return img
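A self-contained sanity check of the helper defined above, as a hedged sketch (assumes PIL and torchvision are installed; the temporary file name is arbitrary):

```python
import os
import tempfile

from PIL import Image

from flash.core.data.utils import image_default_loader

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.png")
    # Write a tiny RGBA image to disk.
    Image.new("RGBA", (8, 8), (255, 0, 0, 128)).save(path)

    # With the new default, the loaded image carries no alpha channel.
    assert image_default_loader(path).mode == "RGB"
```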
8 changes: 3 additions & 5 deletions flash/core/integrations/labelstudio/data_source.py
@@ -8,12 +8,10 @@

from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.core.data.data_source import DataSource, DefaultDataKeys, has_len
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader

if _TEXT_AVAILABLE:
from transformers import AutoTokenizer

@@ -188,7 +186,7 @@ def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] =
"""Load 1 sample from dataset."""
p = sample["file_upload"]
# loading image
image = default_loader(p)
image = image_default_loader(p)
result = {DefaultDataKeys.INPUT: image, DefaultDataKeys.TARGET: self._get_labels_from_sample(sample["label"])}
return result

7 changes: 4 additions & 3 deletions flash/image/data.py
@@ -30,10 +30,11 @@
TensorDataSource,
)
from flash.core.data.process import Deserializer
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.transforms.functional import to_pil_image
else:
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
@@ -44,7 +45,7 @@

def image_loader(filepath: str):
if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
img = default_loader(filepath)
img = image_default_loader(filepath)
elif has_file_allowed_extension(filepath, NP_EXTENSIONS):
img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB")
else:
@@ -116,7 +117,7 @@ class ImageFiftyOneDataSource(FiftyOneDataSource):
@staticmethod
def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
img_path = sample[DefaultDataKeys.INPUT]
img = default_loader(img_path)
img = image_default_loader(img_path)
sample[DefaultDataKeys.INPUT] = img
w, h = img.size # WxH
sample[DefaultDataKeys.METADATA] = {
4 changes: 2 additions & 2 deletions flash/image/face_detection/data.py
@@ -21,13 +21,13 @@
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource
from flash.image.detection import ObjectDetectionData

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision.datasets.folder import default_loader

if _FASTFACE_AVAILABLE:
import fastface as ff
@@ -86,7 +86,7 @@ def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:

def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
filepath = sample[DefaultDataKeys.INPUT]
img = default_loader(filepath)
img = image_default_loader(filepath)
sample[DefaultDataKeys.INPUT] = img

w, h = img.size # WxH
11 changes: 6 additions & 5 deletions flash/image/segmentation/data.py
@@ -34,6 +34,7 @@
TensorDataSource,
)
from flash.core.data.process import Deserializer, Preprocess
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
@@ -63,7 +64,7 @@
if _TORCHVISION_AVAILABLE:
import torchvision
import torchvision.transforms.functional as FT
from torchvision.datasets.folder import default_loader, has_file_allowed_extension
from torchvision.datasets.folder import has_file_allowed_extension


class SemanticSegmentationNumpyDataSource(NumpyDataSource):
@@ -138,7 +139,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
img_labels_path = sample[DefaultDataKeys.TARGET]

# load images directly to torch tensors
img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW
img: torch.Tensor = FT.to_tensor(image_default_loader(img_path)) # CxHxW
img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW
img_labels = img_labels[0] # HxW

@@ -153,7 +154,7 @@
@staticmethod
def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
img_path = sample[DefaultDataKeys.INPUT]
img = FT.to_tensor(default_loader(img_path)).float()
img = FT.to_tensor(image_default_loader(img_path)).float()

sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = {
@@ -184,7 +185,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten
img_path = sample[DefaultDataKeys.INPUT]
fo_sample = _fo_dataset[img_path]

img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW
img: torch.Tensor = FT.to_tensor(image_default_loader(img_path)) # CxHxW
img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW

sample[DefaultDataKeys.INPUT] = img.float()
@@ -198,7 +199,7 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
@staticmethod
def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]:
img_path = sample[DefaultDataKeys.INPUT]
img = FT.to_tensor(default_loader(img_path)).float()
img = FT.to_tensor(image_default_loader(img_path)).float()

sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = {