From 95e19d3ad312821f92dba51fe1b103f618b73d1d Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Fri, 15 Jul 2022 02:24:05 +0530 Subject: [PATCH] Support `NP_EXTENSIONS` for `SemanticSegmentationFilesInput` (#1369) Co-authored-by: Ethan Harris Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 ++ flash/core/data/utilities/loading.py | 2 +- flash/image/segmentation/data.py | 8 ++++---- flash/image/segmentation/input.py | 6 +++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcc17a414a..eb1e945ef2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where the `processor_backbone` argument to `SpeechRecognition` was not used for decoding outputs ([#1362](https://github.com/PyTorchLightning/lightning-flash/pull/1362)) +- Fixed a bug where `.npy` files could not be used with `SemanticSegmentationData` ([#1369](https://github.com/PyTorchLightning/lightning-flash/pull/1369)) + ## [0.7.4] - 2022-04-27 ### Fixed diff --git a/flash/core/data/utilities/loading.py b/flash/core/data/utilities/loading.py index a466add7eb..b0bc0180e0 100644 --- a/flash/core/data/utilities/loading.py +++ b/flash/core/data/utilities/loading.py @@ -70,7 +70,7 @@ def _load_image_from_image(file, drop_alpha: bool = True): def _load_image_from_numpy(file): - return Image.fromarray(np.load(file).astype("uint8"), "RGB") + return Image.fromarray(np.load(file).astype("uint8")).convert("RGB") def _load_spectrogram_from_image(file): diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 66d07d46a0..12465c3f3d 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -115,9 +115,9 @@ def from_files( >>> from PIL import Image >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) - >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8")) + >>> rand_mask= np.random.randint(0, 10, (64, 64), dtype="uint8") >>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] - >>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)] + >>> _ = [np.save(f"mask_{i}.npy", rand_mask) for i in range(1, 4)] >>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] .. doctest:: @@ -126,7 +126,7 @@ def from_files( >>> from flash.image import SemanticSegmentation, SemanticSegmentationData >>> datamodule = SemanticSegmentationData.from_files( ... train_files=["image_1.png", "image_2.png", "image_3.png"], - ... train_targets=["mask_1.png", "mask_2.png", "mask_3.png"], + ... train_targets=["mask_1.npy", "mask_2.npy", "mask_3.npy"], ... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"], ... transform_kwargs=dict(image_size=(128, 128)), ... num_classes=10, @@ -145,7 +145,7 @@ def from_files( >>> import os >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] - >>> _ = [os.remove(f"mask_{i}.png") for i in range(1, 4)] + >>> _ = [os.remove(f"mask_{i}.npy") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index b72694fe66..20efdc4cf6 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -17,7 +17,7 @@ import torch from flash.core.data.io.input import DataKeys, Input -from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image +from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities @@ -98,9 +98,9 @@ def load_data( ) -> List[Dict[str, Any]]: self.load_labels_map(num_classes, labels_map) if mask_files is None: - files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS) + files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) else: - files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS) + files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) return to_samples(files, mask_files) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: