From 41b9850695b4eb1eb15e4c271a11df5ee58fff80 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 12 May 2021 18:58:04 +0100 Subject: [PATCH] [feat] Update Segmentation Task in preparation to adding resizing (#287) * cleanup segmentation * update * update * update * update * update * update --- flash/__init__.py | 1 + flash/data/data_pipeline.py | 3 + flash/data/data_source.py | 7 ++ flash/data/splits.py | 13 +++ flash/vision/segmentation/data.py | 108 +++++++++++++++--- flash/vision/segmentation/serialization.py | 13 +-- .../finetuning/semantic_segmentation.py | 9 +- .../predict/semantic_segmentation.py | 4 +- tests/data/test_split_dataset.py | 20 ++++ tests/vision/segmentation/test_data.py | 8 +- tests/vision/segmentation/test_model.py | 4 +- 11 files changed, 156 insertions(+), 34 deletions(-) diff --git a/flash/__init__.py b/flash/__init__.py index 3eef508374..1caafa591f 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -18,6 +18,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) +_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1" from flash.core.model import Task # noqa: E402 from flash.core.trainer import Trainer # noqa: E402 diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 8a0b739ace..80eb0ecbad 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -61,6 +61,9 @@ def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: else: return None + def __repr__(self) -> str: + return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})" + class DataPipeline: """ diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 7cbe294918..23955413ab 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -44,6 +44,12 @@ class LabelsState(ProcessState): labels: Optional[Sequence[str]] +@dataclass(unsafe_hash=True, frozen=True) +class ImageLabelsMap(ProcessState): + + labels_map: Optional[Dict[int, Tuple[int, int, int]]] + + class DefaultDataSources(LightningEnum): """The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in :class:`~flash.data.data_module.DataModule`.""" @@ -66,6 +72,7 @@ class DefaultDataKeys(LightningEnum): INPUT = "input" TARGET = "target" + METADATA = "metadata" # TODO: Create a FlashEnum class??? def __hash__(self) -> int: diff --git a/flash/data/splits.py b/flash/data/splits.py index d8f4e2aa7e..8c09ad2290 100644 --- a/flash/data/splits.py +++ b/flash/data/splits.py @@ -23,6 +23,8 @@ class SplitDataset(Dataset): """ + _INTERNAL_KEYS = ("dataset", "indices", "data") + def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices: bool = False) -> None: if not isinstance(indices, list): raise MisconfigurationException("indices should be a list") @@ -38,6 +40,17 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices self.dataset = dataset self.indices = indices + def __getattr__(self, key: str): + if key in self._INTERNAL_KEYS: + return getattr(self, key) + return getattr(self.dataset, key) + + def __setattr__(self, name: str, value: Any) -> None: + if name in self._INTERNAL_KEYS: + self.__dict__[name] = value + else: + setattr(self.dataset, name, value) + def __getitem__(self, index: int) -> Any: return self.dataset[self.indices[index]] diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index ae84349aad..c4f9d0ebf9 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -23,12 +24,15 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS +import flash +from flash.data.auto_dataset import BaseAutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.data_source import ( DefaultDataKeys, DefaultDataSources, + ImageLabelsMap, NumpyDataSource, PathsDataSource, TensorDataSource, @@ -56,7 +60,8 @@ class SemanticSegmentationPathsDataSource(PathsDataSource): def __init__(self): super().__init__(IMG_EXTENSIONS) - def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]: + def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], + dataset: BaseAutoDataset) -> Sequence[Mapping[str, Any]]: input_data, target_data = data if self.isdir(input_data) and self.isdir(target_data): @@ -98,7 +103,7 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - def predict_load_data(self, data: Union[str, List[str]]): return super().predict_load_data(data) - def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: + def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Tensor, torch.Size]]: # unpack data paths img_path = sample[DefaultDataKeys.INPUT] img_labels_path = sample[DefaultDataKeys.TARGET] @@ -108,7 +113,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW - return {DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float()} + return { + DefaultDataKeys.INPUT: img.float(), + DefaultDataKeys.TARGET: img_labels.float(), + DefaultDataKeys.METADATA: img.shape, + } def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()} @@ -123,6 +132,8 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), + num_classes: int = None, + labels_map: Dict[int, Tuple[int, int, int]] = None, ) -> None: """Preprocess pipeline for semantic segmentation tasks. @@ -134,6 +145,9 @@ def __init__( image_size: A tuple with the expected output image size. """ self.image_size = image_size + self.num_classes = num_classes + if num_classes: + labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes) super().__init__( train_transform=train_transform, @@ -149,10 +163,16 @@ def __init__( default_data_source=DefaultDataSources.FILES, ) + if labels_map: + self.set_state(ImageLabelsMap(labels_map)) + + self.labels_map = labels_map + def get_state_dict(self) -> Dict[str, Any]: return { - **self.transforms, - "image_size": self.image_size, + **self.transforms, "image_size": self.image_size, + "num_classes": self.num_classes, + "labels_map": self.labels_map } @classmethod @@ -182,16 +202,69 @@ class SemanticSegmentationData(DataModule): preprocess_cls = SemanticSegmentationPreprocess @staticmethod - def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - return SegmentationMatplotlibVisualization(*args, **kwargs) - - def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]): - self.data_fetcher.labels_map = labels_map + def configure_data_fetcher( + labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None + ) -> 'SegmentationMatplotlibVisualization': + return SegmentationMatplotlibVisualization(labels_map=labels_map) def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value + @classmethod + def from_data_source( + cls, + data_source: str, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + predict_data: Any = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + + if 'num_classes' not in preprocess_kwargs: + raise MisconfigurationException("`num_classes` should be provided during instantiation.") + + num_classes = preprocess_kwargs["num_classes"] + + labels_map = getattr(preprocess_kwargs, "labels_map", + None) or SegmentationLabels.create_random_labels_map(num_classes) + + data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map) + + if flash._IS_TESTING: + data_fetcher.block_viz_window = True + + dm = super(SemanticSegmentationData, cls).from_data_source( + data_source=data_source, + train_data=train_data, + val_data=val_data, + test_data=test_data, + predict_data=predict_data, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs + ) + + dm.train_dataset.num_classes = num_classes + return dm + @classmethod def from_folders( cls, @@ -211,7 +284,9 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **preprocess_kwargs: Any, + num_classes: Optional[int] = None, + labels_map: Dict[int, Tuple[int, int, int]] = None, + **preprocess_kwargs, ) -> 'DataModule': """Creates a :class:`~flash.vision.segmentation.data.SemanticSegmentationData` object from the given data folders and corresponding target folders. @@ -243,6 +318,8 @@ def from_folders( val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_classes: Number of classes within the segmentation mask. + labels_map: Mapping between a class_id and its corresponding color. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -271,6 +348,8 @@ def from_folders( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + num_classes=num_classes, + labels_map=labels_map, **preprocess_kwargs, ) @@ -279,11 +358,12 @@ class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. """ - def __init__(self): - super().__init__(self) + def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]): + super().__init__() + self.max_cols: int = 4 # maximum number of columns we accept self.block_viz_window: bool = True # parameter to allow user to block visualisation windows - self.labels_map: Dict[int, Tuple[int, int, int]] = {} + self.labels_map: Dict[int, Tuple[int, int, int]] = labels_map @staticmethod def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index 50ba5be9a9..5a8cb40f69 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -16,6 +16,8 @@ import torch +import flash +from flash.data.data_source import ImageLabelsMap from flash.data.process import Serializer from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE @@ -68,14 +70,11 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] def serialize(self, sample: torch.Tensor) -> torch.Tensor: assert len(sample.shape) == 3, sample.shape labels = torch.argmax(sample, dim=-3) # HxW - if self.visualize and os.getenv("FLASH_TESTING", "0") == "0": + + if self.visualize and not flash._IS_TESTING: if self.labels_map is None: - # create random colors map - num_classes = sample.shape[-3] - labels_map = self.create_random_labels_map(num_classes) - else: - labels_map = self.labels_map - labels_vis = self.labels_to_image(labels, labels_map) + self.labels_map = self.get_state(ImageLabelsMap).labels_map + labels_vis = self.labels_to_image(labels, self.labels_map) labels_vis = K.utils.tensor_to_image(labels_vis) plt.imshow(labels_vis) plt.show() diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 3676353ec8..2d5bfaaee2 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -31,17 +31,17 @@ batch_size=4, val_split=0.3, image_size=(200, 200), # (600, 800) + num_classes=21, ) # 2.2 Visualise the samples -labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) -datamodule.set_labels_map(labels_map) datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) # 3. Build the model model = SemanticSegmentation( backbone="torchvision/fcn_resnet50", - num_classes=21, + num_classes=datamodule.num_classes, + serializer=SegmentationLabels(visualize=True) ) # 4. Create the trainer. @@ -53,9 +53,6 @@ # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") -# 6. Predict what's on a few images! -model.serializer = SegmentationLabels(labels_map, visualize=True) - predictions = model.predict([ "data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png", diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py index f507f2a6a6..9209923be7 100644 --- a/flash_examples/predict/semantic_segmentation.py +++ b/flash_examples/predict/semantic_segmentation.py @@ -24,9 +24,7 @@ ) # 2. Load the model from a checkpoint -model = SemanticSegmentation.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" -) +model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt") model.serializer = SegmentationLabels(visualize=True) # 3. Predict what's on a few images and visualize! diff --git a/tests/data/test_split_dataset.py b/tests/data/test_split_dataset.py index e92a44e1b6..cc087cd167 100644 --- a/tests/data/test_split_dataset.py +++ b/tests/data/test_split_dataset.py @@ -37,3 +37,23 @@ def test_split_dataset(tmpdir): with pytest.raises(MisconfigurationException, match="[0, 99]"): SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True) + + class Dataset: + + def __init__(self): + self.data = [0, 1, 2] + self.name = "something" + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + split_dataset = SplitDataset(Dataset(), indices=[0]) + assert split_dataset.name == "something" + + assert split_dataset._INTERNAL_KEYS == ("dataset", "indices", "data") + + split_dataset.is_passed_down = True + assert split_dataset.dataset.is_passed_down diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index bd51f09d21..4d68bbc1d1 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -44,7 +44,7 @@ class TestSemanticSegmentationPreprocess: @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.") def test_smoke(self): - prep = SemanticSegmentationPreprocess() + prep = SemanticSegmentationPreprocess(num_classes=1) assert prep is not None @@ -89,6 +89,7 @@ def test_from_folders(self, tmpdir): test_target_folder=str(tmp_dir / "targets"), batch_size=2, num_workers=0, + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None @@ -143,6 +144,7 @@ def test_from_folders_warning(self, tmpdir): train_target_folder=str(tmp_dir / "targets"), batch_size=1, num_workers=0, + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None @@ -185,6 +187,7 @@ def test_from_files(self, tmpdir): test_targets=targets, batch_size=2, num_workers=0, + num_classes=num_classes ) assert dm is not None assert dm.train_dataloader() is not None @@ -238,6 +241,7 @@ def test_from_files_warning(self, tmpdir): train_targets=targets + [str(tmp_dir / "labels_img4.png")], batch_size=2, num_workers=0, + num_classes=num_classes ) def test_map_labels(self, tmpdir): @@ -275,6 +279,7 @@ def test_map_labels(self, tmpdir): val_targets=targets, batch_size=2, num_workers=0, + num_classes=num_classes ) assert dm is not None assert dm.train_dataloader() is not None @@ -284,7 +289,6 @@ def test_map_labels(self, tmpdir): dm.set_block_viz_window(False) assert dm.data_fetcher.block_viz_window is False - dm.set_labels_map(labels_map) dm.show_train_batch("load_sample") dm.show_train_batch("to_tensor_transform") diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index e248271e10..5ccc86d68f 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -86,7 +86,7 @@ def test_unfreeze(): def test_predict_tensor(): img = torch.rand(1, 3, 10, 20) model = SemanticSegmentation(2) - data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess()) + data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="tensors", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) assert out[0].shape == (196, 196) @@ -95,7 +95,7 @@ def test_predict_tensor(): def test_predict_numpy(): img = np.ones((1, 3, 10, 20)) model = SemanticSegmentation(2) - data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess()) + data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) assert out[0].shape == (196, 196)