This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Update Segmentation Task in preparation for adding resizing (#287)
* cleanup segmentation

* update

* update

* update

* update

* update

* update
tchaton authored May 12, 2021
1 parent e36a250 commit 41b9850
Showing 11 changed files with 156 additions and 34 deletions.
1 change: 1 addition & 0 deletions flash/__init__.py
@@ -18,6 +18,7 @@

_PACKAGE_ROOT = os.path.dirname(__file__)
PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"

from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402
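The new _IS_TESTING flag is evaluated once, at import time, so FLASH_TESTING has to be exported before flash is imported. A minimal sketch of the behaviour (assuming nothing else in the environment sets the variable):

import os

# FLASH_TESTING is read when flash/__init__.py runs, so set it before the import.
os.environ["FLASH_TESTING"] = "1"

import flash

assert flash._IS_TESTING  # later hunks in this diff use the flag to block matplotlib windows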
3 changes: 3 additions & 0 deletions flash/data/data_pipeline.py
@@ -61,6 +61,9 @@ def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]:
else:
return None

def __repr__(self) -> str:
return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})"


class DataPipeline:
"""
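The repr added above makes the shared pipeline state easier to inspect while debugging. A tiny sketch, assuming the enclosing class is DataPipelineState (the class in this module that owns _initialized and _state):

from flash.data.data_pipeline import DataPipelineState

state = DataPipelineState()
print(state)  # e.g. "DataPipelineState(initialized=False, state={})"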
7 changes: 7 additions & 0 deletions flash/data/data_source.py
@@ -44,6 +44,12 @@ class LabelsState(ProcessState):
labels: Optional[Sequence[str]]


@dataclass(unsafe_hash=True, frozen=True)
class ImageLabelsMap(ProcessState):

labels_map: Optional[Dict[int, Tuple[int, int, int]]]


class DefaultDataSources(LightningEnum):
"""The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in
:class:`~flash.data.data_module.DataModule`."""
@@ -66,6 +72,7 @@ class DefaultDataKeys(LightningEnum):

INPUT = "input"
TARGET = "target"
METADATA = "metadata"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
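The ImageLabelsMap process state added above is how the class-id to colour mapping gets shared between the preprocess and the serializer later in this diff. A rough sketch of the producer/consumer pattern; the preprocess and serializer objects below are stand-ins for instances already attached to the same data pipeline, not a runnable end-to-end snippet:

from flash.data.data_source import ImageLabelsMap

labels_map = {0: (0, 0, 0), 1: (255, 0, 0)}  # class_id -> RGB colour

# Producer side (e.g. inside a Preprocess): publish the mapping as shared state.
preprocess.set_state(ImageLabelsMap(labels_map))

# Consumer side (e.g. inside a Serializer on the same pipeline): read it back.
state = serializer.get_state(ImageLabelsMap)
labels_map = state.labels_map if state is not None else None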
13 changes: 13 additions & 0 deletions flash/data/splits.py
@@ -23,6 +23,8 @@ class SplitDataset(Dataset):
"""

_INTERNAL_KEYS = ("dataset", "indices", "data")

def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices: bool = False) -> None:
if not isinstance(indices, list):
raise MisconfigurationException("indices should be a list")
@@ -38,6 +40,17 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices: bool = False) -> None:
self.dataset = dataset
self.indices = indices

def __getattr__(self, key: str):
if key in self._INTERNAL_KEYS:
return getattr(self, key)
return getattr(self.dataset, key)

def __setattr__(self, name: str, value: Any) -> None:
if name in self._INTERNAL_KEYS:
self.__dict__[name] = value
else:
setattr(self.dataset, name, value)

def __getitem__(self, index: int) -> Any:
return self.dataset[self.indices[index]]

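With the new __getattr__/__setattr__ pair, any attribute that is not one of the _INTERNAL_KEYS is forwarded to the wrapped dataset, in both directions. A small illustration, mirroring the test added at the bottom of this diff (import path taken from the file being modified):

from flash.data.splits import SplitDataset

class ToyDataset:

    def __init__(self):
        self.values = [0, 1, 2]
        self.name = "something"

    def __getitem__(self, index):
        return self.values[index]

    def __len__(self):
        return len(self.values)

split = SplitDataset(ToyDataset(), indices=[0])
assert split.name == "something"   # reads fall through to the wrapped dataset
split.is_passed_down = True        # non-internal writes are forwarded as well
assert split.dataset.is_passed_down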
108 changes: 94 additions & 14 deletions flash/vision/segmentation/data.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
@@ -23,12 +24,15 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS

import flash
from flash.data.auto_dataset import BaseAutoDataset
from flash.data.base_viz import BaseVisualization # for viz
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.data_source import (
DefaultDataKeys,
DefaultDataSources,
ImageLabelsMap,
NumpyDataSource,
PathsDataSource,
TensorDataSource,
@@ -56,7 +60,8 @@ class SemanticSegmentationPathsDataSource(PathsDataSource):
def __init__(self):
super().__init__(IMG_EXTENSIONS)

def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]:
def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]],
dataset: BaseAutoDataset) -> Sequence[Mapping[str, Any]]:
input_data, target_data = data

if self.isdir(input_data) and self.isdir(target_data):
@@ -98,7 +103,7 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]:
def predict_load_data(self, data: Union[str, List[str]]):
return super().predict_load_data(data)

def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Tensor, torch.Size]]:
# unpack data paths
img_path = sample[DefaultDataKeys.INPUT]
img_labels_path = sample[DefaultDataKeys.TARGET]
@@ -108,7 +113,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW
img_labels = img_labels[0] # HxW

return {DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float()}
return {
DefaultDataKeys.INPUT: img.float(),
DefaultDataKeys.TARGET: img_labels.float(),
DefaultDataKeys.METADATA: img.shape,
}

def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()}
@@ -123,6 +132,8 @@ def __init__(
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
image_size: Tuple[int, int] = (196, 196),
num_classes: int = None,
labels_map: Dict[int, Tuple[int, int, int]] = None,
) -> None:
"""Preprocess pipeline for semantic segmentation tasks.
@@ -134,6 +145,9 @@ def __init__(
image_size: A tuple with the expected output image size.
"""
self.image_size = image_size
self.num_classes = num_classes
if num_classes:
labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes)

super().__init__(
train_transform=train_transform,
@@ -149,10 +163,16 @@ def __init__(
default_data_source=DefaultDataSources.FILES,
)

if labels_map:
self.set_state(ImageLabelsMap(labels_map))

self.labels_map = labels_map

def get_state_dict(self) -> Dict[str, Any]:
return {
**self.transforms,
"image_size": self.image_size,
**self.transforms, "image_size": self.image_size,
"num_classes": self.num_classes,
"labels_map": self.labels_map
}

@classmethod
@@ -182,16 +202,69 @@ class SemanticSegmentationData(DataModule):
preprocess_cls = SemanticSegmentationPreprocess

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return SegmentationMatplotlibVisualization(*args, **kwargs)

def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]):
self.data_fetcher.labels_map = labels_map
def configure_data_fetcher(
labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None
) -> 'SegmentationMatplotlibVisualization':
return SegmentationMatplotlibVisualization(labels_map=labels_map)

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value

@classmethod
def from_data_source(
cls,
data_source: str,
train_data: Any = None,
val_data: Any = None,
test_data: Any = None,
predict_data: Any = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':

if 'num_classes' not in preprocess_kwargs:
raise MisconfigurationException("`num_classes` should be provided during instantiation.")

num_classes = preprocess_kwargs["num_classes"]

labels_map = getattr(preprocess_kwargs, "labels_map",
None) or SegmentationLabels.create_random_labels_map(num_classes)

data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

if flash._IS_TESTING:
data_fetcher.block_viz_window = True

dm = super(SemanticSegmentationData, cls).from_data_source(
data_source=data_source,
train_data=train_data,
val_data=val_data,
test_data=test_data,
predict_data=predict_data,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**preprocess_kwargs
)

dm.train_dataset.num_classes = num_classes
return dm

@classmethod
def from_folders(
cls,
@@ -211,7 +284,9 @@ def from_folders(
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
num_classes: Optional[int] = None,
labels_map: Dict[int, Tuple[int, int, int]] = None,
**preprocess_kwargs,
) -> 'DataModule':
"""Creates a :class:`~flash.vision.segmentation.data.SemanticSegmentationData` object from the given data
folders and corresponding target folders.
@@ -243,6 +318,8 @@
val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`.
num_classes: Number of classes within the segmentation mask.
labels_map: Mapping between a class_id and its corresponding color.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand Down Expand Up @@ -271,6 +348,8 @@ def from_folders(
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
num_classes=num_classes,
labels_map=labels_map,
**preprocess_kwargs,
)

@@ -279,11 +358,12 @@ class SegmentationMatplotlibVisualization(BaseVisualization):
"""Process and show the image batch and its associated label using matplotlib.
"""

def __init__(self):
super().__init__(self)
def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]):
super().__init__()

self.max_cols: int = 4 # maximum number of columns we accept
self.block_viz_window: bool = True # parameter to allow user to block visualisation windows
self.labels_map: Dict[int, Tuple[int, int, int]] = {}
self.labels_map: Dict[int, Tuple[int, int, int]] = labels_map

@staticmethod
def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
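Taken together, the datamodule changes above make num_classes mandatory when building SemanticSegmentationData, with the colour map either supplied or generated randomly and shared through the ImageLabelsMap state. A hedged usage sketch, modelled on the finetuning example further down in this commit; the folder keyword names and paths are placeholders and may not match the real from_folders signature:

from flash.vision.segmentation.data import SemanticSegmentationData
from flash.vision.segmentation.serialization import SegmentationLabels

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",          # placeholder argument name and path
    train_target_folder="data/CameraSeg",   # placeholder argument name and path
    batch_size=4,
    val_split=0.3,
    image_size=(200, 200),
    num_classes=21,                          # now required by the datamodule
    labels_map=SegmentationLabels.create_random_labels_map(num_classes=21),  # optional; random if omitted
)

# The visualisation picks the labels_map up from the data fetcher automatically,
# so the old set_labels_map() call is gone.
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])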
13 changes: 6 additions & 7 deletions flash/vision/segmentation/serialization.py
@@ -16,6 +16,8 @@

import torch

import flash
from flash.data.data_source import ImageLabelsMap
from flash.data.process import Serializer
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

@@ -68,14 +70,11 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]:
def serialize(self, sample: torch.Tensor) -> torch.Tensor:
assert len(sample.shape) == 3, sample.shape
labels = torch.argmax(sample, dim=-3) # HxW
if self.visualize and os.getenv("FLASH_TESTING", "0") == "0":

if self.visualize and not flash._IS_TESTING:
if self.labels_map is None:
# create random colors map
num_classes = sample.shape[-3]
labels_map = self.create_random_labels_map(num_classes)
else:
labels_map = self.labels_map
labels_vis = self.labels_to_image(labels, labels_map)
self.labels_map = self.get_state(ImageLabelsMap).labels_map
labels_vis = self.labels_to_image(labels, self.labels_map)
labels_vis = K.utils.tensor_to_image(labels_vis)
plt.imshow(labels_vis)
plt.show()
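On the serializer side, SegmentationLabels.serialize now looks the colour map up from the shared ImageLabelsMap state instead of building one ad hoc. The colourisation helpers can still be exercised on their own; a rough sketch with arbitrary shapes and colours, calling labels_to_image on an instance to match how serialize uses it above:

import torch

from flash.vision.segmentation.serialization import SegmentationLabels

num_classes = 3
labels_map = SegmentationLabels.create_random_labels_map(num_classes)  # {class_id: (r, g, b)}

serializer = SegmentationLabels(visualize=False)   # the example scripts pass visualize=True for plots
sample = torch.rand(num_classes, 8, 8)             # fake (num_classes, H, W) prediction
labels = torch.argmax(sample, dim=-3)              # HxW tensor of class ids
rgb = serializer.labels_to_image(labels, labels_map)  # colour image built from labels_map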
9 changes: 3 additions & 6 deletions flash_examples/finetuning/semantic_segmentation.py
@@ -31,17 +31,17 @@
batch_size=4,
val_split=0.3,
image_size=(200, 200), # (600, 800)
num_classes=21,
)

# 2.2 Visualise the samples
labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
datamodule.set_labels_map(labels_map)
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
model = SemanticSegmentation(
backbone="torchvision/fcn_resnet50",
num_classes=21,
num_classes=datamodule.num_classes,
serializer=SegmentationLabels(visualize=True)
)

# 4. Create the trainer.
@@ -53,9 +53,6 @@
# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Predict what's on a few images!
model.serializer = SegmentationLabels(labels_map, visualize=True)

predictions = model.predict([
"data/CameraRGB/F61-1.png",
"data/CameraRGB/F62-1.png",
4 changes: 1 addition & 3 deletions flash_examples/predict/semantic_segmentation.py
@@ -24,9 +24,7 @@
)

# 2. Load the model from a checkpoint
model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.serializer = SegmentationLabels(visualize=True)

# 3. Predict what's on a few images and visualize!
20 changes: 20 additions & 0 deletions tests/data/test_split_dataset.py
@@ -37,3 +37,23 @@ def test_split_dataset(tmpdir):

with pytest.raises(MisconfigurationException, match="[0, 99]"):
SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True)

class Dataset:

def __init__(self):
self.data = [0, 1, 2]
self.name = "something"

def __getitem__(self, index):
return self.data[index]

def __len__(self):
return len(self.data)

split_dataset = SplitDataset(Dataset(), indices=[0])
assert split_dataset.name == "something"

assert split_dataset._INTERNAL_KEYS == ("dataset", "indices", "data")

split_dataset.is_passed_down = True
assert split_dataset.dataset.is_passed_down