Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support NP_EXTENSIONS for SemanticSegmentationFilesInput (#1369)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 14, 2022
1 parent fd8cc7f commit 95e19d3
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where the `processor_backbone` argument to `SpeechRecognition` was not used for decoding outputs ([#1362](https://github.com/PyTorchLightning/lightning-flash/pull/1362))

- Fixed a bug where `.npy` files could not be used with `SemanticSegmentationData` ([#1369](https://github.com/PyTorchLightning/lightning-flash/pull/1369))

## [0.7.4] - 2022-04-27

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/utilities/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _load_image_from_image(file, drop_alpha: bool = True):


def _load_image_from_numpy(file):
return Image.fromarray(np.load(file).astype("uint8"), "RGB")
return Image.fromarray(np.load(file).astype("uint8")).convert("RGB")


def _load_spectrogram_from_image(file):
Expand Down
8 changes: 4 additions & 4 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def from_files(
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> rand_mask= Image.fromarray(np.random.randint(0, 10, (64, 64), dtype="uint8"))
>>> rand_mask= np.random.randint(0, 10, (64, 64), dtype="uint8")
>>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [rand_mask.save(f"mask_{i}.png") for i in range(1, 4)]
>>> _ = [np.save(f"mask_{i}.npy", rand_mask) for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)]
.. doctest::
Expand All @@ -126,7 +126,7 @@ def from_files(
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_files(
... train_files=["image_1.png", "image_2.png", "image_3.png"],
... train_targets=["mask_1.png", "mask_2.png", "mask_3.png"],
... train_targets=["mask_1.npy", "mask_2.npy", "mask_3.npy"],
... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
... transform_kwargs=dict(image_size=(128, 128)),
... num_classes=10,
Expand All @@ -145,7 +145,7 @@ def from_files(
>>> import os
>>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"mask_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"mask_{i}.npy") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)]
"""

Expand Down
6 changes: 3 additions & 3 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image
from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS
from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE
from flash.core.data.utilities.samples import to_samples
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
Expand Down Expand Up @@ -98,9 +98,9 @@ def load_data(
) -> List[Dict[str, Any]]:
self.load_labels_map(num_classes, labels_map)
if mask_files is None:
files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS)
files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
else:
files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS)
files, mask_files = filter_valid_files(files, mask_files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return to_samples(files, mask_files)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 95e19d3

Please sign in to comment.