[feat] add support for metadata + resize Semantic Segmentation Image (#290)

* cleanup segmentation

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
tchaton authored May 12, 2021
1 parent 41b9850 commit 74fbafd
Showing 8 changed files with 64 additions and 15 deletions.
2 changes: 1 addition & 1 deletion flash/core/model.py
@@ -30,7 +30,7 @@
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, DataPipelineState
-from flash.data.data_source import DataSource, DefaultDataSources
+from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping


13 changes: 12 additions & 1 deletion flash/data/batch.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

from flash.data.callback import ControlFlow
+from flash.data.data_source import DefaultDataKeys
from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

if TYPE_CHECKING:
@@ -137,6 +138,13 @@ def __init__(
        self._collate_context = CurrentFuncContext("collate", preprocess)
        self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

+    def _extract_metadata(
+        self,
+        samples: List[Dict[str, Any]],
+    ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
+        metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
+        return samples, metadata if any(m is not None for m in metadata) else None

    def forward(self, samples: Sequence[Any]) -> Any:
        # we create a new dict to prevent from potential memory leaks
        # assuming that the dictionary samples are stored in between and
@@ -158,7 +166,10 @@ def forward(self, samples: Sequence[Any]) -> Any:
            samples = type(_samples)(_samples)

        with self._collate_context:
+            samples, metadata = self._extract_metadata(samples)
            samples = self.collate_fn(samples)
+            if metadata:
+                samples[DefaultDataKeys.METADATA] = metadata
            self.callback.on_collate(samples, self.stage)

        with self._per_batch_transform_context:
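The change to `flash/data/batch.py` is the heart of the feature: metadata entries are popped out of each sample before collation (their per-sample values, such as image shapes, cannot be stacked by a default collate function) and re-attached to the collated batch afterwards. A minimal standalone sketch of that logic, using a plain string key in place of `DefaultDataKeys.METADATA`:

```python
from typing import Any, Dict, List, Mapping, Optional, Tuple

METADATA = "metadata"  # stand-in for DefaultDataKeys.METADATA

def extract_metadata(
    samples: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Optional[List[Any]]]:
    # Pop metadata from every sample; non-mapping samples contribute None.
    metadata = [s.pop(METADATA, None) if isinstance(s, Mapping) else None for s in samples]
    # Only report metadata if at least one sample actually carried some.
    return samples, metadata if any(m is not None for m in metadata) else None

samples = [
    {"input": [1.0, 2.0], METADATA: (3, 10, 20)},
    {"input": [3.0, 4.0], METADATA: (3, 5, 8)},
]
samples, metadata = extract_metadata(samples)
assert all(METADATA not in s for s in samples)  # safe to collate now
assert metadata == [(3, 10, 20), (3, 5, 8)]
```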
1 change: 1 addition & 0 deletions flash/data/data_source.py
@@ -71,6 +71,7 @@ class DefaultDataKeys(LightningEnum):
targets."""

INPUT = "input"
PREDS = "preds"
TARGET = "target"
METADATA = "metadata"

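With `PREDS` added, every pipeline stage can address a sample dict through a shared vocabulary. A hypothetical predict-time sample, sketched to show where each key comes from (the tensor shapes are illustrative only):

```python
import torch

from flash.data.data_source import DefaultDataKeys

img = torch.rand(3, 10, 20)
sample = {
    DefaultDataKeys.INPUT: img,           # set by the data source
    DefaultDataKeys.METADATA: img.shape,  # recorded by load_sample
}
# Later, SemanticSegmentation.predict_step attaches the raw model output:
sample[DefaultDataKeys.PREDS] = torch.rand(5, 10, 20)  # e.g. 5-class logits
```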
22 changes: 19 additions & 3 deletions flash/vision/segmentation/data.py
@@ -35,6 +35,7 @@
    ImageLabelsMap,
    NumpyDataSource,
    PathsDataSource,
+    SEQUENCE_DATA_TYPE,
    TensorDataSource,
)
from flash.data.process import Preprocess
@@ -51,7 +52,18 @@
class SemanticSegmentationNumpyDataSource(NumpyDataSource):

    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
-        sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
+        img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
+        sample[DefaultDataKeys.INPUT] = img
+        sample[DefaultDataKeys.METADATA] = img.shape
        return sample


+class SemanticSegmentationTensorDataSource(TensorDataSource):
+
+    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
+        img = sample[DefaultDataKeys.INPUT].float()
+        sample[DefaultDataKeys.INPUT] = img
+        sample[DefaultDataKeys.METADATA] = img.shape
+        return sample


@@ -120,7 +132,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
        }

    def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
-        return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()}
+        img = torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()
+        return {
+            DefaultDataKeys.INPUT: img,
+            DefaultDataKeys.METADATA: img.shape,
+        }


class SemanticSegmentationPreprocess(Preprocess):
@@ -157,7 +173,7 @@ def __init__(
            data_sources={
                DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
                DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
-                DefaultDataSources.TENSORS: TensorDataSource(),
+                DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(),
                DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
            },
            default_data_source=DefaultDataSources.FILES,
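Each segmentation data source now records the loaded image's shape under the `METADATA` key so the postprocess can undo the training-time resize. A minimal sketch of the pattern, with string keys standing in for `DefaultDataKeys`:

```python
import numpy as np
import torch

def load_numpy_sample(sample: dict) -> dict:
    # Convert to a float tensor and remember the original (C, H, W) shape.
    img = torch.from_numpy(sample["input"]).float()
    sample["input"] = img
    sample["metadata"] = img.shape
    return sample

sample = load_numpy_sample({"input": np.zeros((3, 10, 20), dtype=np.float32)})
assert tuple(sample["metadata"][-2:]) == (10, 20)  # H, W to restore later
```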
25 changes: 22 additions & 3 deletions flash/vision/segmentation/model.py
@@ -21,10 +21,23 @@
from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
from flash.data.data_source import DefaultDataKeys
-from flash.data.process import Serializer
+from flash.data.process import Postprocess, Serializer
+from flash.utils.imports import _KORNIA_AVAILABLE
from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.vision.segmentation.serialization import SegmentationLabels

+if _KORNIA_AVAILABLE:
+    import kornia as K
+
+
+class SemanticSegmentationPostprocess(Postprocess):
+
+    def per_sample_transform(self, sample: Any) -> Any:
+        resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA][-2:], interpolation='bilinear')
+        sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS]))
+        sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT]))
+        return super().per_sample_transform(sample)


class SemanticSegmentation(ClassificationTask):
"""Task that performs semantic segmentation on images.
@@ -53,6 +66,8 @@ class SemanticSegmentation(ClassificationTask):
        serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
    """

+    postprocess_cls = SemanticSegmentationPostprocess

    backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES

    def __init__(
@@ -67,6 +82,7 @@ def __init__(
        learning_rate: float = 1e-3,
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+        postprocess: Optional[Postprocess] = None,
    ) -> None:

        if metrics is None:
@@ -86,6 +102,7 @@ def __init__(
            metrics=metrics,
            learning_rate=learning_rate,
            serializer=serializer or SegmentationLabels(),
+            postprocess=postprocess or self.postprocess_cls()
        )

        self.save_hyperparameters()
@@ -109,8 +126,10 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:
        return super().test_step(batch, batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
-        batch = (batch[DefaultDataKeys.INPUT])
-        return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+        batch_input = (batch[DefaultDataKeys.INPUT])
+        preds = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx)
+        batch[DefaultDataKeys.PREDS] = preds
+        return batch

    def forward(self, x) -> torch.Tensor:
        # infer the image to the model
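`SemanticSegmentationPostprocess.per_sample_transform` uses the stored shape to resize predictions (and inputs) back to the original resolution with Kornia. A self-contained sketch of just the resize step, outside the `Postprocess` machinery (the shapes and helper name are assumptions for illustration):

```python
import torch
import kornia as K

def resize_back(preds: torch.Tensor, original_shape) -> torch.Tensor:
    # preds: (num_classes, H', W') logits at the model's working resolution;
    # original_shape: the METADATA entry, e.g. torch.Size([3, 10, 20]).
    resize = K.geometry.Resize(original_shape[-2:], interpolation="bilinear")
    # Bilinear interpolation wants a batch dimension, hence unsqueeze/squeeze.
    return resize(preds.unsqueeze(0)).squeeze(0)

preds = torch.randn(5, 196, 196)            # logits at the 196x196 working size
restored = resize_back(preds, (3, 10, 20))  # back to the input's 10x20
assert restored.shape == (5, 10, 20)
```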
9 changes: 5 additions & 4 deletions flash/vision/segmentation/serialization.py
@@ -17,7 +17,7 @@
import torch

import flash
-from flash.data.data_source import ImageLabelsMap
+from flash.data.data_source import DefaultDataKeys, ImageLabelsMap
from flash.data.process import Serializer
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

@@ -67,9 +67,10 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]
            labels_map[i] = torch.randint(0, 255, (3, ))
        return labels_map

-    def serialize(self, sample: torch.Tensor) -> torch.Tensor:
-        assert len(sample.shape) == 3, sample.shape
-        labels = torch.argmax(sample, dim=-3)  # HxW
+    def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor:
+        preds = sample[DefaultDataKeys.PREDS]
+        assert len(preds.shape) == 3, preds.shape
+        labels = torch.argmax(preds, dim=-3)  # HxW

        if self.visualize and not flash._IS_TESTING:
            if self.labels_map is None:
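`SegmentationLabels.serialize` previously received the prediction tensor directly; it now receives the whole sample dict and extracts the predictions itself. A minimal sketch of the new contract, with a string key standing in for `DefaultDataKeys.PREDS`:

```python
import torch

PREDS = "preds"  # stand-in for DefaultDataKeys.PREDS

def serialize(sample: dict) -> torch.Tensor:
    preds = sample[PREDS]               # (num_classes, H, W) logits
    assert len(preds.shape) == 3, preds.shape
    return torch.argmax(preds, dim=-3)  # (H, W) per-pixel class indices

logits = torch.zeros(4, 2, 3)
logits[3, 0, 1] = 1.0                   # peak for class 3 at pixel (0, 1)
labels = serialize({PREDS: logits})
assert labels[0, 1] == 3
```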
4 changes: 2 additions & 2 deletions tests/vision/segmentation/test_model.py
@@ -89,7 +89,7 @@ def test_predict_tensor():
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
    assert isinstance(out[0], torch.Tensor)
-    assert out[0].shape == (196, 196)
+    assert out[0].shape == (10, 20)


def test_predict_numpy():
@@ -98,4 +98,4 @@ def test_predict_numpy():
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], torch.Tensor)
-    assert out[0].shape == (196, 196)
+    assert out[0].shape == (10, 20)
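The updated assertions capture the user-visible effect: `predict` used to return masks at the preprocess working resolution (196 x 196), and the postprocess now resizes each mask back to the original input size. Spelled out end to end, essentially `test_predict_tensor` (the model construction is an assumption, since the tests' setup lines are not shown in this diff):

```python
import torch

from flash.data.data_pipeline import DataPipeline
from flash.vision.segmentation.data import SemanticSegmentationPreprocess
from flash.vision.segmentation.model import SemanticSegmentation

model = SemanticSegmentation(num_classes=1)  # hypothetical; uses the default backbone
img = torch.rand(3, 10, 20)                  # a 10x20 RGB image
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
assert out[0].shape == (10, 20)  # matches the input, not the 196x196 working size
```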
3 changes: 2 additions & 1 deletion tests/vision/segmentation/test_serialization.py
@@ -1,6 +1,7 @@
import pytest
import torch

+from flash.data.data_source import DefaultDataKeys
from flash.vision.segmentation.serialization import SegmentationLabels


@@ -30,7 +31,7 @@ def test_serialize(self):
        sample[1, 1, 2] = 1  # add peak in class 2
        sample[3, 0, 1] = 1  # add peak in class 4

-        classes = serial.serialize(sample)
+        classes = serial.serialize({DefaultDataKeys.PREDS: sample})
        assert classes[1, 2] == 1
        assert classes[0, 1] == 3

