Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Support NP_EXTENSIONS for SemanticSegmentationFilesInput #1369

Merged
merged 6 commits on Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where the `processor_backbone` argument to `SpeechRecognition` was not used for decoding outputs ([#1362](https://github.com/PyTorchLightning/lightning-flash/pull/1362))

+ - Fixed a bug where `.npy` files could not be used with `SemanticSegmentationData` ([#1369](https://github.com/PyTorchLightning/lightning-flash/pull/1369))

## [0.7.4] - 2022-04-27

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/utilities/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _load_image_from_image(file, drop_alpha: bool = True):


def _load_image_from_numpy(file):
- return Image.fromarray(np.load(file).astype("uint8"), "RGB")
+ return Image.fromarray(np.load(file).astype("uint8")).convert("RGB")


def _load_spectrogram_from_image(file):
Expand Down
8 changes: 4 additions & 4 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def from_files(

>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
- >>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8"))
+ >>> rand_mask= np.random.randint(0, 10, (64, 64), dtype="uint8")
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
- >>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)]
+ >>> _ = [np.save(f"mask_{i}.npy", rand_mask) for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]

.. doctest::
Expand All @@ -126,7 +126,7 @@ def from_files(
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
- ... train_targets=["mask_1.png", "mask_2.png", "mask_3.png"],
+ ... train_targets=["mask_1.npy", "mask_2.npy", "mask_3.npy"],
... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
... transform_kwargs=dict(image_size=(128, 128)),
... num_classes=10,
Expand All @@ -145,7 +145,7 @@ def from_files(

>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
- >>> _ = [os.remove(f"mask_{i}.png") for i in range(1, 4)]
+ >>> _ = [os.remove(f"mask_{i}.npy") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

Expand Down
6 changes: 3 additions & 3 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

from flash.core.data.io.input import DataKeys, Input
- from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image
+ from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS
from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE
from flash.core.data.utilities.samples import to_samples
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
Expand Down Expand Up @@ -98,9 +98,9 @@ def load_data(
) -> List[Dict[str, Any]]:
self.load_labels_map(num_classes, labels_map)
if mask_files is None:
- files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS)
+ files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
else:
- files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS)
+ files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return to_samples(files, mask_files)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand Down