diff --git a/CHANGELOG.md b/CHANGELOG.md index b57b450719..9ae95d42b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `Postprocess` to `OutputTransform` ([#942](https://github.com/PyTorchLightning/lightning-flash/pull/942)) +- Changed loading of RGBA images to drop alpha channel by default ([#946](https://github.com/PyTorchLightning/lightning-flash/pull/946)) + ### Deprecated - Deprecated `flash.core.data.process.Serializer` in favour of `flash.core.data.io.output.Output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index 2f5e91d8ea..9c5bd805c1 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -25,17 +25,14 @@ PathsDataSource, ) from flash.core.data.process import Deserializer, Preprocess -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.data.utils import image_default_loader from flash.image.classification.data import ImageClassificationData from flash.image.data import ImageDeserializer, IMG_EXTENSIONS, NP_EXTENSIONS -if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader - def spectrogram_loader(filepath: str): if has_file_allowed_extension(filepath, IMG_EXTENSIONS): - img = default_loader(filepath) + img = image_default_loader(filepath) data = np.array(img) else: data = np.load(filepath) diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 46342a2eb3..7ef3c70eb3 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -23,8 +23,17 @@ from torch import Tensor from tqdm.auto import tqdm as tq +from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from flash.core.utilities.stages import RunningStage +if _PIL_AVAILABLE: + from PIL.Image import Image +else: + Image = object + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets.folder import default_loader + _STAGES_PREFIX = { RunningStage.TRAINING: "train", RunningStage.TESTING: "test", @@ -207,3 +216,16 @@ def convert_to_modules(transforms: Optional[Dict[str, Callable]]): transforms, Iterable, torch.nn.ModuleList, wrong_dtype=(torch.nn.ModuleList, torch.nn.ModuleDict) ) return transforms + + +def image_default_loader(file_path: str, drop_alpha: bool = True) -> Image: + """Default loader for images. + + Args: + file_path: The image file to load. + drop_alpha: If ``True`` (default) then any alpha channels will be silently removed. + """ + img = default_loader(file_path) + if img.mode == "RGBA" and drop_alpha: + img = img.convert("RGB") + return img diff --git a/flash/core/integrations/labelstudio/data_source.py b/flash/core/integrations/labelstudio/data_source.py index b54bacdf1a..1d6b69c8cb 100644 --- a/flash/core/integrations/labelstudio/data_source.py +++ b/flash/core/integrations/labelstudio/data_source.py @@ -8,12 +8,10 @@ from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset from flash.core.data.data_source import DataSource, DefaultDataKeys, has_len -from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.data.utils import image_default_loader +from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE from flash.core.utilities.stages import RunningStage -if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader - if _TEXT_AVAILABLE: from transformers import AutoTokenizer @@ -188,7 +186,7 @@ def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = """Load 1 sample from dataset.""" p = sample["file_upload"] # loading image - image = default_loader(p) + image = image_default_loader(p) result = {DefaultDataKeys.INPUT: image, DefaultDataKeys.TARGET: self._get_labels_from_sample(sample["label"])} return result diff --git a/flash/image/data.py b/flash/image/data.py index 5d0eb9cbe5..d8fa784ce0 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -30,10 +30,11 @@ TensorDataSource, ) from flash.core.data.process import Deserializer +from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS + from torchvision.datasets.folder import IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image else: IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") @@ -44,7 +45,7 @@ def image_loader(filepath: str): if has_file_allowed_extension(filepath, IMG_EXTENSIONS): - img = default_loader(filepath) + img = image_default_loader(filepath) elif has_file_allowed_extension(filepath, NP_EXTENSIONS): img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB") else: @@ -116,7 +117,7 @@ class ImageFiftyOneDataSource(FiftyOneDataSource): @staticmethod def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = default_loader(img_path) + img = image_default_loader(img_path) sample[DefaultDataKeys.INPUT] = img w, h = img.size # WxH sample[DefaultDataKeys.METADATA] = { diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index c6802766a4..663862e975 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -21,13 +21,13 @@ from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys +from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.data import ImagePathsDataSource from flash.image.detection import ObjectDetectionData if _TORCHVISION_AVAILABLE: import torchvision - from torchvision.datasets.folder import default_loader if _FASTFACE_AVAILABLE: import fastface as ff @@ -86,7 +86,7 @@ def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) + img = image_default_loader(filepath) sample[DefaultDataKeys.INPUT] = img w, h = img.size # WxH diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 7819dd4589..0e55db2ffc 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -34,6 +34,7 @@ TensorDataSource, ) from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, @@ -63,7 +64,7 @@ if _TORCHVISION_AVAILABLE: import torchvision import torchvision.transforms.functional as FT - from torchvision.datasets.folder import default_loader, has_file_allowed_extension + from torchvision.datasets.folder import has_file_allowed_extension class SemanticSegmentationNumpyDataSource(NumpyDataSource): @@ -138,7 +139,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten img_labels_path = sample[DefaultDataKeys.TARGET] # load images directly to torch tensors - img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW + img: torch.Tensor = FT.to_tensor(image_default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW @@ -153,7 +154,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = FT.to_tensor(default_loader(img_path)).float() + img = FT.to_tensor(image_default_loader(img_path)).float() sample[DefaultDataKeys.INPUT] = img sample[DefaultDataKeys.METADATA] = { @@ -184,7 +185,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten img_path = sample[DefaultDataKeys.INPUT] fo_sample = _fo_dataset[img_path] - img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW + img: torch.Tensor = FT.to_tensor(image_default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW sample[DefaultDataKeys.INPUT] = img.float() @@ -198,7 +199,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = FT.to_tensor(default_loader(img_path)).float() + img = FT.to_tensor(image_default_loader(img_path)).float() sample[DefaultDataKeys.INPUT] = img sample[DefaultDataKeys.METADATA] = {