Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add visualisation callback for image classification #228

Merged
merged 30 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7850c17
resolve bug
tchaton Apr 19, 2021
dec2e16
update
tchaton Apr 19, 2021
98b0564
draft version of the image classification callback
edgarriba Apr 19, 2021
18f6339
Merge branch 'master' into feat/viz_callback
edgarriba Apr 19, 2021
108ff3b
Merge branch 'master' into feat/viz_callback
edgarriba Apr 19, 2021
d2b423b
move callback under base class
edgarriba Apr 19, 2021
c2587fd
add crashing tests
edgarriba Apr 20, 2021
c2095be
add more tests
edgarriba Apr 20, 2021
b7d448d
more tests
edgarriba Apr 20, 2021
3dd9b19
more tests
edgarriba Apr 20, 2021
bd42635
fix issues with files directory loading and add better error message
edgarriba Apr 20, 2021
32ac9b3
fix visualisation test
edgarriba Apr 20, 2021
0ef5d2a
fix data outpye type
edgarriba Apr 20, 2021
d444841
fix tests with from_paths
edgarriba Apr 20, 2021
3a0b088
add matplotlib check import
edgarriba Apr 20, 2021
ce64c40
improve tests on from_folders
edgarriba Apr 20, 2021
be47408
add missing imports
edgarriba Apr 20, 2021
b2a19bf
fixed test_classification
edgarriba Apr 20, 2021
48ea3c2
implement all hooks
edgarriba Apr 20, 2021
76457f7
fix more tests
edgarriba Apr 20, 2021
159c002
fix more tests
edgarriba Apr 20, 2021
b339bc9
add matplotlib in reaquirements
edgarriba Apr 20, 2021
0123284
remove useless test
edgarriba Apr 20, 2021
a48f149
Merge branch 'master' into feat/viz_callback
edgarriba Apr 20, 2021
d125c5d
implement function filtering to visualize
edgarriba Apr 20, 2021
21efbca
fix comments
edgarriba Apr 21, 2021
9e8138f
Merge branch 'master' into feat/viz_callback
edgarriba Apr 21, 2021
afeedac
add setter method to block windows to show with matplotlib
edgarriba Apr 22, 2021
0761bfb
sync with master and fix conflicts
edgarriba Apr 22, 2021
f05c0d4
remove unused variable
edgarriba Apr 22, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Sequence, Set

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.utils import _is_overriden
from flash.data.callback import BaseDataFetcher
Expand Down Expand Up @@ -94,14 +95,19 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):

"""

def _show(self, running_stage: RunningStage) -> None:
self.show(self.batches[running_stage], running_stage)
def _show(self, running_stage: RunningStage, func_names_list: List[str]) -> None:
self.show(self.batches[running_stage], running_stage, func_names_list)

def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str]) -> None:
"""
Override this function when you want to visualize a composition.
"""
for func_name in _PREPROCESS_FUNCS:
# filter out the functions to visualise
func_names_set: Set[str] = set(func_names_list) & set(_PREPROCESS_FUNCS)
if len(func_names_set) == 0:
raise MisconfigurationException(f"Invalid function names: {func_names_list}.")

for func_name in func_names_set:
hook_name = f"show_{func_name}"
if _is_overriden(hook_name, self, BaseVisualization):
getattr(self, hook_name)(batch[func_name], running_stage)
Expand Down
33 changes: 22 additions & 11 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
import platform
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -92,6 +92,9 @@ def __init__(
# this may also trigger data preloading
self.set_running_stages()

# buffer to store the functions to visualise
self._viz_func_white_list: Dict[str, Set[str]] = {}

@property
def train_dataset(self) -> Optional[Dataset]:
"""This property returns the train dataset"""
Expand Down Expand Up @@ -147,7 +150,7 @@ def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
setattr(self, iter_name, iterator)
return iterator

def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], reset: bool = True) -> None:
"""
This function is used to handle transforms profiling for batch visualization.
"""
Expand All @@ -156,6 +159,10 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
if not hasattr(self, iter_name):
self._reset_iterator(stage)

# list of functions to visualise
if isinstance(func_names, str):
func_names = [func_names]

iter_dataloader = getattr(self, iter_name)
with self.data_fetcher.enable():
try:
Expand All @@ -164,25 +171,29 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage)
data_fetcher._show(stage, func_names)
if reset:
self.viz.batches[stage] = {}

def show_train_batch(self, reset: bool = True) -> None:
def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this logic to `_show_batch` to reduce duplicated code, and raise a `MisconfigurationException` if the provided names aren't in `_PREPROCESS_FUNCS`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self._show_batch(stage_name, hooks_names, reset=reset)

def show_val_batch(self, reset: bool = True) -> None:
def show_val_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the validation dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.VALIDATING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING]
self._show_batch(stage_name, hooks_names, reset=reset)

def show_test_batch(self, reset: bool = True) -> None:
def show_test_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the test dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TESTING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TESTING]
self._show_batch(stage_name, hooks_names, reset=reset)

def show_predict_batch(self, reset: bool = True) -> None:
def show_predict_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the predict dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.PREDICTING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._show_batch(stage_name, hooks_names, reset=reset)

@staticmethod
def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any:
Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
132 changes: 119 additions & 13 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,37 @@
import pathlib
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset

from flash.core.utils import _is_overriden
from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization # for viz
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.data_pipeline import DataPipeline
from flash.data.process import Preprocess
from flash.utils.imports import _KORNIA_AVAILABLE
from flash.data.utils import _PREPROCESS_FUNCS
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

if _KORNIA_AVAILABLE:
import kornia.augmentation as K
import kornia.geometry.transform as T
import kornia as K
else:
from torchvision import transforms as T

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
plt = None


class ImageClassificationPreprocess(Preprocess):

Expand Down Expand Up @@ -74,20 +84,37 @@ def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]:
return files

@classmethod
def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> List[str]:
def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> List[Tuple[str, int]]:
# case where we pass a list of files
if isinstance(data, list):
# TODO: define num_classes elsewhere. This is a bad assumption since the list of
# labels might not contain the complete set of ids so that you can infer the total
# number of classes to train in your dataset.
dataset.num_classes = len(data)
out = []
out: List[Tuple[str, int]] = []
for p, label in data:
if os.path.isdir(p):
for f in os.listdir(p):
# TODO: there is an issue here when a path is provided along with labels.
# os.listdir cannot assure the same file order as the passed labels list.
files_list: List[str] = os.listdir(p)
if len(files_list) > 1:
raise ValueError(
f"The provided directory contains more than one file."
f"Directory: {p} -> Contains: {files_list}"
)
for f in files_list:
if has_file_allowed_extension(f, IMG_EXTENSIONS):
out.append([os.path.join(p, f), label])
elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS):
elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS):
out.append([p, label])
else:
raise TypeError(f"Unexpected file path type: {p}.")
return out
else:
classes, class_to_idx = cls._find_classes(data)
# TODO: define num_classes elsewhere. This is a bad assumption since the list of
# labels might not contain the complete set of ids so that you can infer the total
# number of classes to train in your dataset.
dataset.num_classes = len(classes)
return make_dataset(data, class_to_idx, IMG_EXTENSIONS, None)

Expand Down Expand Up @@ -231,6 +258,14 @@ def __init__(
if self._predict_ds:
self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes)

def set_block_viz_window(self, value: bool) -> None:
    """Store ``value`` as the ``block_viz_window`` flag of this module's data fetcher.

    Args:
        value: new value for ``self.data_fetcher.block_viz_window``.
    """
    fetcher = self.data_fetcher
    fetcher.block_viz_window = value

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """Build the data fetcher used by this data module.

    Every positional and keyword argument is forwarded unchanged to the
    ``MatplotlibVisualization`` constructor, and the new instance is returned.
    """
    fetcher = MatplotlibVisualization(*args, **kwargs)
    return fetcher

@staticmethod
def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[str, Union[nn.Module, Callable]]:
if transform and not isinstance(transform, Dict):
Expand All @@ -247,14 +282,16 @@ def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[

@staticmethod
def default_train_transforms():
image_size = ImageClassificationData.image_size
image_size: Tuple[int, int] = ImageClassificationData.image_size
if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
"post_tensor_transform": nn.Sequential(
K.augmentation.RandomResizedCrop(image_size), K.augmentation.RandomHorizontalFlip()
),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand All @@ -267,14 +304,14 @@ def default_train_transforms():

@staticmethod
def default_val_transforms():
image_size = ImageClassificationData.image_size
image_size: Tuple[int, int] = ImageClassificationData.image_size
if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
"post_tensor_transform": nn.Sequential(K.augmentation.RandomResizedCrop(image_size)),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand Down Expand Up @@ -521,3 +558,72 @@ def from_filepaths(
seed=seed,
**kwargs
)


class MatplotlibVisualization(BaseVisualization):
    """Process and show the image batch and its associated label using matplotlib.

    Every ``show_<hook>`` method draws the data it receives as a grid of
    sub-plots, one image per cell, with the label used as the cell title.
    """
    # maximum number of columns in the image grid
    max_cols: int = 4
    # forwarded to ``plt.show(block=...)``; set to False to keep windows non-blocking
    block_viz_window: bool = True

    @staticmethod
    def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
        """Convert a PIL image or an image tensor into a numpy array.

        Args:
            img: a ``PIL.Image.Image`` or a ``torch.Tensor``. Tensors are
                assumed to be channels-first, ``(C, H, W)`` or ``(1, C, H, W)``
                (TODO confirm against the transforms in use).

        Returns:
            The image as a numpy array; tensors come back with dim 0 moved last.

        Raises:
            TypeError: if ``img`` is neither a PIL image nor a tensor.
        """
        if isinstance(img, Image.Image):
            return np.array(img)
        if isinstance(img, torch.Tensor):
            # Only strip a leading singleton *batch* dimension. Squeezing dim 0
            # unconditionally would also remove the channel axis of a
            # single-channel (1, H, W) image and make the permute below fail.
            if img.dim() == 4:
                img = img.squeeze(0)
            # detach() so tensors that require grad can be converted as well
            return img.detach().permute(1, 2, 0).cpu().numpy()
        raise TypeError(f"Unknown image type. Got: {type(img)}.")

    def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) -> None:
        """Draw ``num_samples`` images with their labels in a single figure.

        Args:
            data: either a list of ``(image, label)`` pairs, or a tuple
                ``(images, labels)`` of two indexable collections.
            num_samples: number of images to draw.
            title: figure title.

        Raises:
            MisconfigurationException: if matplotlib is not installed.
            ValueError: if no grid can be built (no samples, or ``max_cols`` < 1).
            TypeError: if ``data`` is neither a list nor a tuple.
        """
        # fail on the missing dependency before doing anything else
        if not _MATPLOTLIB_AVAILABLE:
            raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")

        # define the image grid
        cols: int = min(num_samples, self.max_cols)
        if cols < 1:
            # would otherwise surface as an opaque ZeroDivisionError below
            raise ValueError(
                f"Cannot build an image grid for {num_samples} samples with max_cols={self.max_cols}."
            )
        # ceil division: a partially filled last row must not drop samples
        rows: int = (num_samples + cols - 1) // cols

        # create figure and set title; squeeze=False always yields a 2D array
        # of axes, so a 1x1 grid still supports .ravel()
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        fig.suptitle(title)

        for idx, ax in enumerate(axs.ravel()):
            if idx >= num_samples:
                # spare cell in the last row: leave it blank
                ax.axis('off')
                continue
            # unpack images and labels
            if isinstance(data, list):
                image, label = data[idx]
            elif isinstance(data, tuple):
                images, labels = data
                image, label = images[idx], labels[idx]
            else:
                raise TypeError(f"Unknown data type. Got: {type(data)}.")
            # convert images to numpy
            image_np: np.ndarray = self._to_numpy(image)
            if isinstance(label, torch.Tensor):
                label = label.squeeze().tolist()
            # show image and set label as subplot title
            ax.imshow(image_np)
            ax.set_title(str(label))
            ax.axis('off')
        plt.show(block=self.block_viz_window)

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage) -> None:
        """Show the samples recorded for the ``load_sample`` hook."""
        win_title: str = f"{running_stage} - show_load_sample"
        self._show_images_and_labels(samples, len(samples), win_title)

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage) -> None:
        """Show the samples recorded for the ``pre_tensor_transform`` hook."""
        win_title: str = f"{running_stage} - show_pre_tensor_transform"
        self._show_images_and_labels(samples, len(samples), win_title)

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage) -> None:
        """Show the samples recorded for the ``to_tensor_transform`` hook."""
        win_title: str = f"{running_stage} - show_to_tensor_transform"
        self._show_images_and_labels(samples, len(samples), win_title)

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage) -> None:
        """Show the samples recorded for the ``post_tensor_transform`` hook."""
        win_title: str = f"{running_stage} - show_post_tensor_transform"
        self._show_images_and_labels(samples, len(samples), win_title)

    def show_per_batch_transform(self, batch: List[Any], running_stage: RunningStage) -> None:
        """Show the first batch recorded for the ``per_batch_transform`` hook."""
        win_title: str = f"{running_stage} - show_per_batch_transform"
        # NOTE(review): assumes batch[0] is an (images, labels) pair whose images
        # carry the batch size on dim 0 -- confirm against the data fetcher.
        self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ sentencepiece>=0.1.95
filelock # comes with 3rd-party dependency
pycocotools>=2.0.2 ; python_version >= "3.7"
kornia>=0.5.0
matplotlib # used by the visualisation callback
10 changes: 5 additions & 5 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flash.data.base_viz import BaseVisualization
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.utils import _STAGES_PREFIX
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.vision import ImageClassificationData


Expand Down Expand Up @@ -87,10 +87,7 @@ def test_base_viz(tmpdir):
(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a" / "a_1.png")
_rand_image().save(tmpdir / "a" / "a_2.png")

_rand_image().save(tmpdir / "b" / "a_1.png")
_rand_image().save(tmpdir / "b" / "a_2.png")

class CustomBaseVisualization(BaseVisualization):

Expand Down Expand Up @@ -148,7 +145,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
for stage in _STAGES_PREFIX.values():

for _ in range(10):
getattr(dm, f"show_{stage}_batch")(reset=False)
for fcn_name in _PREPROCESS_FUNCS:
fcn = getattr(dm, f"show_{stage}_batch")
fcn(fcn_name, reset=False)

is_predict = stage == "predict"

Expand Down Expand Up @@ -193,3 +192,4 @@ def test_data_loaders_num_workers_to_0(tmpdir):
assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter)
iterator = iter(datamodule.train_dataloader())
assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
assert datamodule.num_workers == 3
Loading