Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] add support for metadata + resize Semantic Segmentation Image #290

Merged
merged 21 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, DataPipelineState
from flash.data.data_source import DataSource, DefaultDataSources
from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping


Expand Down
13 changes: 12 additions & 1 deletion flash/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

from flash.data.callback import ControlFlow
from flash.data.data_source import DefaultDataKeys
from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

if TYPE_CHECKING:
Expand Down Expand Up @@ -137,6 +138,13 @@ def __init__(
self._collate_context = CurrentFuncContext("collate", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

def _extract_metadata(
self,
samples: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
return samples, metadata if any(m is not None for m in metadata) else None

def forward(self, samples: Sequence[Any]) -> Any:
# we create a new dict to prevent from potential memory leaks
# assuming that the dictionary samples are stored in between and
Expand All @@ -158,7 +166,10 @@ def forward(self, samples: Sequence[Any]) -> Any:
samples = type(_samples)(_samples)

with self._collate_context:
samples, metadata = self._extract_metadata(samples)
samples = self.collate_fn(samples)
if metadata:
samples[DefaultDataKeys.METADATA] = metadata
self.callback.on_collate(samples, self.stage)

with self._per_batch_transform_context:
Expand Down
1 change: 1 addition & 0 deletions flash/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class DefaultDataKeys(LightningEnum):
targets."""

INPUT = "input"
PREDS = "preds"
TARGET = "target"
METADATA = "metadata"

Expand Down
22 changes: 19 additions & 3 deletions flash/vision/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ImageLabelsMap,
NumpyDataSource,
PathsDataSource,
SEQUENCE_DATA_TYPE,
TensorDataSource,
)
from flash.data.process import Preprocess
Expand All @@ -51,7 +52,18 @@
class SemanticSegmentationNumpyDataSource(NumpyDataSource):

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = img.shape
return sample


class SemanticSegmentationTensorDataSource(TensorDataSource):

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
img = sample[DefaultDataKeys.INPUT].float()
sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = img.shape
return sample


Expand Down Expand Up @@ -120,7 +132,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
}

def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()}
img = torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()
return {
DefaultDataKeys.INPUT: img,
DefaultDataKeys.METADATA: img.shape,
}


class SemanticSegmentationPreprocess(Preprocess):
Expand Down Expand Up @@ -157,7 +173,7 @@ def __init__(
data_sources={
DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
DefaultDataSources.TENSORS: TensorDataSource(),
DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(),
DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
},
default_data_source=DefaultDataSources.FILES,
Expand Down
25 changes: 22 additions & 3 deletions flash/vision/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,23 @@
from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
from flash.data.data_source import DefaultDataKeys
from flash.data.process import Serializer
from flash.data.process import Postprocess, Serializer
from flash.utils.imports import _KORNIA_AVAILABLE
from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.vision.segmentation.serialization import SegmentationLabels

if _KORNIA_AVAILABLE:
import kornia as K


class SemanticSegmentationPostprocess(Postprocess):

def per_sample_transform(self, sample: Any) -> Any:
resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA][-2:], interpolation='bilinear')
sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS]))
sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT]))
return super().per_sample_transform(sample)


class SemanticSegmentation(ClassificationTask):
"""Task that performs semantic segmentation on images.
Expand Down Expand Up @@ -53,6 +66,8 @@ class SemanticSegmentation(ClassificationTask):
serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
"""

postprocess_cls = SemanticSegmentationPostprocess

backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES

def __init__(
Expand All @@ -67,6 +82,7 @@ def __init__(
learning_rate: float = 1e-3,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
postprocess: Optional[Postprocess] = None,
) -> None:

if metrics is None:
Expand All @@ -86,6 +102,7 @@ def __init__(
metrics=metrics,
learning_rate=learning_rate,
serializer=serializer or SegmentationLabels(),
postprocess=postprocess or self.postprocess_cls()
)

self.save_hyperparameters()
Expand All @@ -109,8 +126,10 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:
return super().test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch = (batch[DefaultDataKeys.INPUT])
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
batch_input = (batch[DefaultDataKeys.INPUT])
preds = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx)
batch[DefaultDataKeys.PREDS] = preds
return batch

def forward(self, x) -> torch.Tensor:
# infer the image to the model
Expand Down
9 changes: 5 additions & 4 deletions flash/vision/segmentation/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

import flash
from flash.data.data_source import ImageLabelsMap
from flash.data.data_source import DefaultDataKeys, ImageLabelsMap
from flash.data.process import Serializer
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

Expand Down Expand Up @@ -67,9 +67,10 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]
labels_map[i] = torch.randint(0, 255, (3, ))
return labels_map

def serialize(self, sample: torch.Tensor) -> torch.Tensor:
assert len(sample.shape) == 3, sample.shape
labels = torch.argmax(sample, dim=-3) # HxW
def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor:
preds = sample[DefaultDataKeys.PREDS]
assert len(preds.shape) == 3, preds.shape
labels = torch.argmax(preds, dim=-3) # HxW

if self.visualize and not flash._IS_TESTING:
if self.labels_map is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/vision/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_predict_tensor():
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
assert isinstance(out[0], torch.Tensor)
assert out[0].shape == (196, 196)
assert out[0].shape == (10, 20)


def test_predict_numpy():
Expand All @@ -98,4 +98,4 @@ def test_predict_numpy():
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
assert isinstance(out[0], torch.Tensor)
assert out[0].shape == (196, 196)
assert out[0].shape == (10, 20)
3 changes: 2 additions & 1 deletion tests/vision/segmentation/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from flash.data.data_source import DefaultDataKeys
from flash.vision.segmentation.serialization import SegmentationLabels


Expand Down Expand Up @@ -30,7 +31,7 @@ def test_serialize(self):
sample[1, 1, 2] = 1 # add peak in class 2
sample[3, 0, 1] = 1 # add peak in class 4

classes = serial.serialize(sample)
classes = serial.serialize({DefaultDataKeys.PREDS: sample})
assert classes[1, 2] == 1
assert classes[0, 1] == 3

Expand Down