
Onboard segmentation inputs to new object (#1015)
ethanwharris authored Dec 1, 2021
1 parent e363aa8 commit 78c60ed
Showing 11 changed files with 322 additions and 300 deletions.
4 changes: 3 additions & 1 deletion docs/source/api/image.rst
@@ -97,9 +97,11 @@ ____________
~segmentation.data.SemanticSegmentationInputTransform

segmentation.data.SegmentationMatplotlibVisualization
segmentation.data.SemanticSegmentationInput
segmentation.data.SemanticSegmentationFilesInput
segmentation.data.SemanticSegmentationFolderInput
segmentation.data.SemanticSegmentationNumpyInput
segmentation.data.SemanticSegmentationTensorInput
segmentation.data.SemanticSegmentationPathsInput
segmentation.data.SemanticSegmentationFiftyOneInput
segmentation.data.SemanticSegmentationDeserializer
segmentation.model.SemanticSegmentationOutputTransform
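For context, the classes added to this listing are the per-source Input objects behind SemanticSegmentationData. A hypothetical usage sketch of folder-based loading, which would be served by SemanticSegmentationFolderInput (the argument names and paths below follow the Flash documentation of this period and are assumptions, not part of this commit):

from flash.image import SemanticSegmentationData

# Hypothetical sketch: the folder paths and keyword arguments are placeholders
# and may differ between Flash versions.
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",          # input images
    train_target_folder="data/CameraSeg",   # per-pixel label masks
    num_classes=21,
    val_split=0.1,
    batch_size=4,
)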
16 changes: 8 additions & 8 deletions flash/core/data/data_module.py
@@ -110,6 +110,14 @@ def __init__(
if flash._IS_TESTING and torch.cuda.is_available():
batch_size = 16

self._train_ds = train_dataset
self._val_ds = val_dataset
self._test_ds = test_dataset
self._predict_ds = predict_dataset

if self._train_ds and (val_split is not None and not self._val_ds):
self._train_ds, self._val_ds = self._split_train_val(self._train_ds, val_split)

self._input: Input = input
self._input_tranform: Optional[InputTransform] = input_transform
self._output_transform: Optional[OutputTransform] = output_transform
@@ -119,14 +127,6 @@ def __init__(
# TODO: InputTransform can change
self.data_fetcher.attach_to_input_transform(self.input_transform)

self._train_ds = train_dataset
self._val_ds = val_dataset
self._test_ds = test_dataset
self._predict_ds = predict_dataset

if self._train_ds and (val_split is not None and not self._val_ds):
self._train_ds, self._val_ds = self._split_train_val(self._train_ds, val_split)

if self._train_ds:
self.train_dataloader = self._train_dataloader

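The moved block now assigns the datasets and performs the train/validation split before the input transform and data fetcher are set up. Conceptually, _split_train_val carves a val_split fraction off the training dataset; a minimal standalone sketch of that idea (not Flash's actual implementation):

from torch.utils.data import Dataset, random_split

def split_train_val_sketch(train_dataset: Dataset, val_split: float):
    # e.g. val_split=0.1 holds out 10% of the training samples for validation
    val_length = int(len(train_dataset) * val_split)
    train_length = len(train_dataset) - val_length
    return random_split(train_dataset, [train_length, val_length])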
11 changes: 10 additions & 1 deletion flash/core/data/utilities/paths.py
@@ -14,6 +14,8 @@
import os
from typing import Any, Callable, cast, List, Optional, Tuple, TypeVar, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException

PATH_TYPE = Union[str, bytes, os.PathLike]

T = TypeVar("T")
@@ -132,7 +134,7 @@ def list_valid_files(
def filter_valid_files(
files: Union[PATH_TYPE, List[PATH_TYPE]],
*additional_lists: List[Any],
valid_extensions: Optional[Tuple[str, ...]] = None
valid_extensions: Optional[Tuple[str, ...]] = None,
) -> Union[List[Any], Tuple[List[Any], ...]]:
"""Filter the given list of files and any additional lists to include only the entries that contain a file with
a valid extension.
@@ -148,6 +150,13 @@ def filter_valid_files(
if not isinstance(files, List):
files = [files]

additional_lists = tuple([a] if not isinstance(a, List) else a for a in additional_lists)

if not all(len(a) == len(files) for a in additional_lists):
raise MisconfigurationException(
f"The number of files ({len(files)}) and the number of items in any additional lists must be the same."
)

if valid_extensions is None:
return (files,) + additional_lists
filtered = list(
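The new guard normalizes any scalar extra argument into a list and requires every additional list to match the length of files, failing fast with a MisconfigurationException on a mismatch. A usage sketch (the return shape is assumed to be the filtered files plus each additional list, filtered in lockstep with the files):

from flash.core.data.utilities.paths import filter_valid_files

files = ["img_0.png", "notes.txt", "img_1.png"]
masks = ["mask_0.png", "mask_1.png"]  # deliberately one entry short

# With this change the mismatch raises a MisconfigurationException instead of
# silently mis-pairing files and masks:
# filter_valid_files(files, masks, valid_extensions=(".png",))

masks = ["mask_0.png", "mask_notes.png", "mask_1.png"]
files, masks = filter_valid_files(files, masks, valid_extensions=(".png",))
# "notes.txt" has no valid extension, so it and the matching mask entry are dropped.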
2 changes: 0 additions & 2 deletions flash/image/classification/data.py
@@ -100,7 +100,6 @@ class ImageClassificationTensorInput(ClassificationInput, ImageTensorInput):
def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)

return to_samples(tensor, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -114,7 +113,6 @@ class ImageClassificationNumpyInput(ClassificationInput, ImageNumpyInput):
def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)

return to_samples(array, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
8 changes: 7 additions & 1 deletion flash/image/data.py
@@ -78,7 +78,7 @@ def example_input(self) -> str:
class ImageInput(Input):
@requires("image")
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
w, h = sample[DataKeys.INPUT].size # WxH
w, h = sample[DataKeys.INPUT].size # W x H
if DataKeys.METADATA not in sample:
sample[DataKeys.METADATA] = {}
sample[DataKeys.METADATA]["size"] = (h, w)
@@ -99,13 +99,19 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:


class ImageTensorInput(ImageInput):
def load_data(self, tensor: Any) -> List[Dict[str, Any]]:
return to_samples(tensor)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
img = to_pil_image(sample[DataKeys.INPUT])
sample[DataKeys.INPUT] = img
return super().load_sample(sample)


class ImageNumpyInput(ImageInput):
def load_data(self, array: Any) -> List[Dict[str, Any]]:
return to_samples(array)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
img = to_pil_image(torch.from_numpy(sample[DataKeys.INPUT]))
sample[DataKeys.INPUT] = img
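The two new load_data hooks simply delegate to to_samples, which is expected to wrap each raw tensor or array in its own sample dict keyed by DataKeys.INPUT. A minimal standalone sketch of that behaviour (not Flash's actual implementation; plain string keys are used here for clarity):

from typing import Any, Dict, List, Optional

def to_samples_sketch(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
    # Wrap each input (and its target, when given) in a per-sample dict.
    if targets is None:
        return [{"input": x} for x in inputs]
    return [{"input": x, "target": t} for x, t in zip(inputs, targets)]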