Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add visualisation callback for image classification (#228)
Browse files Browse the repository at this point in the history
* resolve bug

* update

* draft version of the image classification callback

* move callback under base class

* add crashing tests

* add more tests

* more tests

* more tests

* fix issues with files directory loading and add better error message

* fix visualisation test

* fix data output type

* fix tests with from_paths

* add matplotlib check import

* improve tests on from_folders

* add missing imports

* fixed test_classification

* implement all hooks

* fix more tests

* fix more tests

* add matplotlib in requirements

* remove useless test

* implement function filtering to visualize

* fix comments

* add setter method to block windows to show with matplotlib

* remove unused variable

Co-authored-by: tchaton <[email protected]>
  • Loading branch information
edgarriba and tchaton authored Apr 22, 2021
1 parent 7bfa80d commit 1f9e151
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 79 deletions.
16 changes: 11 additions & 5 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Sequence, Set

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.utils import _is_overriden
from flash.data.callback import BaseDataFetcher
Expand Down Expand Up @@ -94,14 +95,19 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
"""

def _show(self, running_stage: RunningStage) -> None:
self.show(self.batches[running_stage], running_stage)
def _show(self, running_stage: RunningStage, func_names_list: List[str]) -> None:
self.show(self.batches[running_stage], running_stage, func_names_list)

def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str]) -> None:
"""
Override this function when you want to visualize a composition.
"""
for func_name in _PREPROCESS_FUNCS:
# filter out the functions to visualise
func_names_set: Set[str] = set(func_names_list) & set(_PREPROCESS_FUNCS)
if len(func_names_set) == 0:
raise MisconfigurationException(f"Invalid function names: {func_names_list}.")

for func_name in func_names_set:
hook_name = f"show_{func_name}"
if _is_overriden(hook_name, self, BaseVisualization):
getattr(self, hook_name)(batch[func_name], running_stage)
Expand Down
30 changes: 19 additions & 11 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
import platform
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -149,7 +149,7 @@ def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
setattr(self, iter_name, iterator)
return iterator

def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], reset: bool = True) -> None:
"""
This function is used to handle transforms profiling for batch visualization.
"""
Expand All @@ -158,6 +158,10 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
if not hasattr(self, iter_name):
self._reset_iterator(stage)

# list of functions to visualise
if isinstance(func_names, str):
func_names = [func_names]

iter_dataloader = getattr(self, iter_name)
with self.data_fetcher.enable():
try:
Expand All @@ -166,25 +170,29 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage)
data_fetcher._show(stage, func_names)
if reset:
self.viz.batches[stage] = {}

def show_train_batch(self, reset: bool = True) -> None:
def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
self._show_batch(stage_name, hooks_names, reset=reset)

def show_val_batch(self, reset: bool = True) -> None:
def show_val_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the validation dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.VALIDATING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING]
self._show_batch(stage_name, hooks_names, reset=reset)

def show_test_batch(self, reset: bool = True) -> None:
def show_test_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the test dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TESTING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TESTING]
self._show_batch(stage_name, hooks_names, reset=reset)

def show_predict_batch(self, reset: bool = True) -> None:
def show_predict_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the predict dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.PREDICTING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._show_batch(stage_name, hooks_names, reset=reset)

@staticmethod
def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any:
Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
124 changes: 114 additions & 10 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,37 @@
import pathlib
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset

from flash.core.classification import ClassificationState
from flash.core.utils import _is_overriden
from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization # for viz
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.process import Preprocess
from flash.utils.imports import _KORNIA_AVAILABLE
from flash.data.utils import _PREPROCESS_FUNCS
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

if _KORNIA_AVAILABLE:
import kornia.augmentation as K
import kornia.geometry.transform as T
import kornia as K
else:
from torchvision import transforms as T

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
plt = None


class ImageClassificationPreprocess(Preprocess):

Expand Down Expand Up @@ -93,9 +102,11 @@ def default_train_transforms(self):
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
"post_tensor_transform": nn.Sequential(
K.augmentation.RandomResizedCrop(image_size), K.augmentation.RandomHorizontalFlip()
),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand All @@ -112,9 +123,9 @@ def default_val_transforms(self):
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
"post_tensor_transform": nn.Sequential(K.augmentation.RandomResizedCrop(image_size)),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand Down Expand Up @@ -159,18 +170,34 @@ def _load_data_dir(
dataset: Optional[AutoDataset] = None,
) -> Tuple[Optional[List[str]], List[Tuple[str, int]]]:
if isinstance(data, list):
# TODO: define num_classes elsewhere. This is a bad assumption since the list of
# labels might not contain the complete set of ids so that you can infer the total
# number of classes to train in your dataset.
dataset.num_classes = len(data)
out = []
out: List[Tuple[str, int]] = []
for p, label in data:
if os.path.isdir(p):
for f in os.listdir(p):
# TODO: there is an issue here when a path is provided along with labels.
# os.listdir cannot assure the same file order as the passed labels list.
files_list: List[str] = os.listdir(p)
if len(files_list) > 1:
raise ValueError(
f"The provided directory contains more than one file."
f"Directory: {p} -> Contains: {files_list}"
)
for f in files_list:
if has_file_allowed_extension(f, IMG_EXTENSIONS):
out.append([os.path.join(p, f), label])
elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS):
elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS):
out.append([p, label])
else:
raise TypeError(f"Unexpected file path type: {p}.")
return None, out
else:
classes, class_to_idx = cls._find_classes(data)
# TODO: define num_classes elsewhere. This is a bad assumption since the list of
# labels might not contain the complete set of ids so that you can infer the total
# number of classes to train in your dataset.
dataset.num_classes = len(classes)
return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None)

Expand Down Expand Up @@ -318,6 +345,14 @@ def __init__(
if self._predict_ds:
self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes)

def set_block_viz_window(self, value: bool) -> None:
"""Setter method to switch on/off matplotlib to pop up windows."""
self.data_fetcher.block_viz_window = value

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return MatplotlibVisualization(*args, **kwargs)

@property
def num_classes(self) -> int:
if self._num_classes is None:
Expand Down Expand Up @@ -494,3 +529,72 @@ def from_filepaths(
seed=seed,
**kwargs
)


class MatplotlibVisualization(BaseVisualization):
    """Process and show the image batch and its associated label using matplotlib.

    Every ``show_*`` hook renders the data it receives as a grid of images, one
    subplot per sample, with the label as the subplot title and
    ``"<running_stage> - <hook name>"`` as the figure title.
    """
    max_cols: int = 4  # maximum number of columns we accept
    block_viz_window: bool = True  # parameter to allow user to block visualisation windows

    @staticmethod
    def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
        """Convert a PIL image or a channel-first tensor into an array ``imshow`` can draw.

        Args:
            img: a ``PIL.Image.Image`` or a ``torch.Tensor`` laid out as ``(C, H, W)``
                or ``(1, C, H, W)``.

        Returns:
            The image as a channel-last numpy array; a single-channel tensor is
            returned as a 2D ``(H, W)`` array.

        Raises:
            TypeError: if ``img`` is neither a PIL image nor a tensor.
        """
        if isinstance(img, Image.Image):
            return np.array(img)
        if isinstance(img, torch.Tensor):
            # Only strip a leading singleton *batch* axis. Squeezing a (1, H, W)
            # tensor would remove the channel axis and make ``permute`` fail.
            if img.dim() == 4:
                img = img.squeeze(0)
            out: np.ndarray = img.detach().permute(1, 2, 0).cpu().numpy()
            # imshow does not accept (H, W, 1); drop the singleton channel axis.
            if out.shape[-1] == 1:
                out = out[..., 0]
            return out
        raise TypeError(f"Unknown image type. Got: {type(img)}.")

    def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
        """Draw ``num_samples`` images with their labels on a grid of subplots.

        Args:
            data: either a list of ``(image, label)`` pairs, or a tuple
                ``(images, labels)`` of two indexable collections.
            num_samples: number of samples to draw; must be at least one.
            title: text used as the figure title.

        Raises:
            MisconfigurationException: if matplotlib is not installed.
            ValueError: if ``num_samples`` is smaller than one.
            TypeError: if ``data`` is neither a list nor a tuple.
        """
        # validate everything up front so no empty figure is left behind on failure
        if not _MATPLOTLIB_AVAILABLE:
            raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")
        if num_samples < 1:
            raise ValueError(f"Expected at least one sample to visualise. Got: {num_samples}.")

        is_list: bool = isinstance(data, list)
        if not is_list and not isinstance(data, tuple):
            raise TypeError(f"Unknown data type. Got: {type(data)}.")
        if not is_list:
            imgs, labels = data

        # define the image grid; round the number of rows up so that a trailing
        # partial row is still displayed instead of being silently dropped
        cols: int = max(1, min(num_samples, self.max_cols))
        rows: int = (num_samples + cols - 1) // cols

        # create figure and set title; squeeze=False always yields a 2D array of
        # axes, even for a 1x1 grid where a single Axes would otherwise be returned
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        fig.suptitle(title)

        for i, ax in enumerate(axs.ravel()):
            ax.axis('off')
            if i >= num_samples:
                # padding cell of the last row: keep it blank
                continue
            # unpack images and labels
            _img, _label = data[i] if is_list else (imgs[i], labels[i])
            # convert images to numpy
            image_np: np.ndarray = self._to_numpy(_img)
            if isinstance(_label, torch.Tensor):
                _label = _label.squeeze().tolist()
            # show image and set label as subplot title
            ax.imshow(image_np)
            ax.set_title(str(_label))
        plt.show(block=self.block_viz_window)

    def _show_samples(self, samples: List[Any], running_stage: RunningStage, hook_name: str) -> None:
        """Shared body of the per-sample hooks: title the window and draw all samples."""
        win_title: str = f"{running_stage} - {hook_name}"
        self._show_images_and_labels(samples, len(samples), win_title)

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        """Show the samples collected for the ``load_sample`` hook."""
        self._show_samples(samples, running_stage, "show_load_sample")

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        """Show the samples collected for the ``pre_tensor_transform`` hook."""
        self._show_samples(samples, running_stage, "show_pre_tensor_transform")

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        """Show the samples collected for the ``to_tensor_transform`` hook."""
        self._show_samples(samples, running_stage, "show_to_tensor_transform")

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        """Show the samples collected for the ``post_tensor_transform`` hook."""
        self._show_samples(samples, running_stage, "show_post_tensor_transform")

    def show_per_batch_transform(self, batch: List[Any], running_stage):
        """Show the first entry of ``batch`` collected for the ``per_batch_transform`` hook.

        The sample count is read from ``batch[0][0].shape[0]``.
        NOTE(review): presumably ``batch[0]`` is an ``(images, labels)`` tuple with
        the images stacked along dimension 0 -- confirm against the data fetcher.
        """
        win_title: str = f"{running_stage} - show_per_batch_transform"
        self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ sentencepiece>=0.1.95
filelock # comes with 3rd-party dependency
pycocotools>=2.0.2 ; python_version >= "3.7"
kornia>=0.5.0
matplotlib # used by the visualisation callback
8 changes: 5 additions & 3 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.vision import ImageClassificationData


Expand Down Expand Up @@ -146,8 +146,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
for stage in _STAGES_PREFIX.values():

for _ in range(10):
fcn = getattr(dm, f"show_{stage}_batch")
fcn(reset=False)
for fcn_name in _PREPROCESS_FUNCS:
fcn = getattr(dm, f"show_{stage}_batch")
fcn(fcn_name, reset=False)

is_predict = stage == "predict"

Expand Down Expand Up @@ -206,3 +207,4 @@ def test_data_loaders_num_workers_to_0(tmpdir):
assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter)
iterator = iter(datamodule.train_dataloader())
assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
assert datamodule.num_workers == 3
Loading

0 comments on commit 1f9e151

Please sign in to comment.