Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add visualisation callback for image classification #228

Merged
merged 30 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7850c17
resolve bug
tchaton Apr 19, 2021
dec2e16
update
tchaton Apr 19, 2021
98b0564
draft version of the image classification callback
edgarriba Apr 19, 2021
18f6339
Merge branch 'master' into feat/viz_callback
edgarriba Apr 19, 2021
108ff3b
Merge branch 'master' into feat/viz_callback
edgarriba Apr 19, 2021
d2b423b
move callback under base class
edgarriba Apr 19, 2021
c2587fd
add crashing tests
edgarriba Apr 20, 2021
c2095be
add more tests
edgarriba Apr 20, 2021
b7d448d
more tests
edgarriba Apr 20, 2021
3dd9b19
more tests
edgarriba Apr 20, 2021
bd42635
fix issues with files directory loading and add better error message
edgarriba Apr 20, 2021
32ac9b3
fix visualisation test
edgarriba Apr 20, 2021
0ef5d2a
fix data outpye type
edgarriba Apr 20, 2021
d444841
fix tests with from_paths
edgarriba Apr 20, 2021
3a0b088
add matplotlib check import
edgarriba Apr 20, 2021
ce64c40
improve tests on from_folders
edgarriba Apr 20, 2021
be47408
add missing imports
edgarriba Apr 20, 2021
b2a19bf
fixed test_classification
edgarriba Apr 20, 2021
48ea3c2
implement all hooks
edgarriba Apr 20, 2021
76457f7
fix more tests
edgarriba Apr 20, 2021
159c002
fix more tests
edgarriba Apr 20, 2021
b339bc9
add matplotlib in reaquirements
edgarriba Apr 20, 2021
0123284
remove useless test
edgarriba Apr 20, 2021
a48f149
Merge branch 'master' into feat/viz_callback
edgarriba Apr 20, 2021
d125c5d
implement function filtering to visualize
edgarriba Apr 20, 2021
21efbca
fix comments
edgarriba Apr 21, 2021
9e8138f
Merge branch 'master' into feat/viz_callback
edgarriba Apr 21, 2021
afeedac
add setter method to block windows to show with matplotlib
edgarriba Apr 22, 2021
0761bfb
sync with master and fix conflicts
edgarriba Apr 22, 2021
f05c0d4
remove unused variable
edgarriba Apr 22, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Sequence, Set

from pytorch_lightning.trainer.states import RunningStage

Expand Down Expand Up @@ -101,7 +101,13 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
"""
Override this function when you want to visualize a composition.
"""
for func_name in _PREPROCESS_FUNCS:
# filter out the functions to visualise
func_name: str = self._fcn_white_list[running_stage]
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
func_names_list: Set[str] = list(set([func_name]) & set(_PREPROCESS_FUNCS))
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
if len(func_names_list) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this check the show_{}_batches in the DataModule.

raise ValueError(f"Invalid function name: {func_name}.")

for func_name in func_names_list:
hook_name = f"show_{func_name}"
if _is_overriden(hook_name, self, BaseVisualization):
getattr(self, hook_name)(batch[func_name], running_stage)
Expand Down
31 changes: 22 additions & 9 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
import platform
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -92,6 +92,9 @@ def __init__(
# this may also trigger data preloading
self.set_running_stages()

# buffer to store the functions to visualise
self._fcn_white_list: Dict[str, Set[str]] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable name isn't clean. Mind finding a name related to viz.


@property
def train_dataset(self) -> Optional[Dataset]:
"""This property returns the train dataset"""
Expand Down Expand Up @@ -164,25 +167,35 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
# list of functions to visualise
data_fetcher._fcn_white_list = self._fcn_white_list
data_fetcher._show(stage)
if reset:
self.viz.batches[stage] = {}

def show_train_batch(self, reset: bool = True) -> None:
def show_train_batch(self, name: str = 'load_sample', reset: bool = True) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List of hook names should also be supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can provide the following api:

def show_train_batch(self, name: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
    ...

"""This function is used to visualize a batch from the train dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this logic to _show_batch to reduce duplicated code and raise a MisConfigurationError is the provided names aren't in _Preprocess_funcs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self._fcn_white_list[stage_name] = name
self._show_batch(stage_name, reset=reset)

def show_val_batch(self, reset: bool = True) -> None:
def show_val_batch(self, name: str = 'load_sample', reset: bool = True) -> None:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
"""This function is used to visualize a batch from the validation dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.VALIDATING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING]
self._fcn_white_list[stage_name] = name
self._show_batch(stage_name, reset=reset)

def show_test_batch(self, reset: bool = True) -> None:
def show_test_batch(self, name: str = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the test dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TESTING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TESTING]
self._fcn_white_list[stage_name] = name
self._show_batch(stage_name, reset=reset)

def show_predict_batch(self, reset: bool = True) -> None:
def show_predict_batch(self, name: str = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the predict dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.PREDICTING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._fcn_white_list[stage_name] = name
self._show_batch(stage_name, reset=reset)

@staticmethod
def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, default: Optional[Any] = None) -> Any:
Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
128 changes: 115 additions & 13 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,37 @@
import pathlib
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset

from flash.core.utils import _is_overriden
from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization # for viz
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.data_pipeline import DataPipeline
from flash.data.process import Preprocess
from flash.utils.imports import _KORNIA_AVAILABLE
from flash.data.utils import _PREPROCESS_FUNCS
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

if _KORNIA_AVAILABLE:
import kornia.augmentation as K
import kornia.geometry.transform as T
import kornia as K
else:
from torchvision import transforms as T

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
else:
plt = None


class ImageClassificationPreprocess(Preprocess):

Expand Down Expand Up @@ -74,20 +84,37 @@ def _get_predicting_files(samples: Union[Sequence, str]) -> List[str]:
return files

@classmethod
def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> List[str]:
def _load_data_dir(cls, data: Any, dataset: Optional[AutoDataset] = None) -> List[Tuple[str, int]]:
# case where we pass a list of files
if isinstance(data, list):
# TODO: define num_classes elsewhere. This is a bad assumption since the list of
# labels might not contain the complete set of ids so that you can infer the total
# number of classes to train in your dataset.
dataset.num_classes = len(data)
out = []
out: List[Tuple[str, int]] = []
for p, label in data:
if os.path.isdir(p):
for f in os.listdir(p):
# TODO: there is an issue here when a path is provided along with labels.
# os.listdir cannot assure the same file order as the passed labels list.
files_list: List[str] = os.listdir(p)
if len(files_list) > 1:
raise ValueError(
f"The provided directory contains more than one file."
f"Directory: {p} -> Contains: {files_list}"
)
for f in files_list:
if has_file_allowed_extension(f, IMG_EXTENSIONS):
out.append([os.path.join(p, f), label])
elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS):
elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS):
out.append([p, label])
else:
raise TypeError(f"Unexpected file path type: {p}.")
return out
else:
classes, class_to_idx = cls._find_classes(data)
# TODO: define num_classes elsewhere. This is a bad assumption since the list of
# labels might not contain the complete set of ids so that you can infer the total
# number of classes to train in your dataset.
dataset.num_classes = len(classes)
return make_dataset(data, class_to_idx, IMG_EXTENSIONS, None)

Expand Down Expand Up @@ -231,6 +258,14 @@ def __init__(
if self._predict_ds:
self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes)

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return _MatplotlibVisualization(*args, **kwargs)

def show(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's find another way to do that. This is confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have a better naming here. An alternative but a bit more hacky, is to pass blocking flag across all the different functions until the matplotlib plt.show calls.

"""Method to block matplotlib windows."""
plt.show()

@staticmethod
def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[str, Union[nn.Module, Callable]]:
if transform and not isinstance(transform, Dict):
Expand All @@ -247,14 +282,16 @@ def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[

@staticmethod
def default_train_transforms():
image_size = ImageClassificationData.image_size
image_size: Tuple[int, int] = ImageClassificationData.image_size
if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
"post_tensor_transform": nn.Sequential(
K.augmentation.RandomResizedCrop(image_size), K.augmentation.RandomHorizontalFlip()
),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand All @@ -267,14 +304,14 @@ def default_train_transforms():

@staticmethod
def default_val_transforms():
image_size = ImageClassificationData.image_size
image_size: Tuple[int, int] = ImageClassificationData.image_size
if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
"post_tensor_transform": nn.Sequential(K.augmentation.RandomResizedCrop(image_size)),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand Down Expand Up @@ -521,3 +558,68 @@ def from_filepaths(
seed=seed,
**kwargs
)


class _MatplotlibVisualization(BaseVisualization):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this class public.

"""Process and show the image batch and its associated label using matplotlib.
"""
max_cols: int = 4 # maximum number of columns we accept

@staticmethod
def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
out: np.ndarray
if isinstance(img, Image.Image):
out = np.array(img)
elif isinstance(img, torch.Tensor):
out = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
else:
raise TypeError(f"Unknown image type. Got: {type(img)}.")
return out

def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
# define the image grid
cols: int = min(num_samples, self.max_cols)
rows: int = num_samples // cols

# create figure and set title
fig, axs = plt.subplots(rows, cols)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raise an exception if matplotlib isn't available here: _MATPLOTLIB_AVAILABLE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, check if that's appropriate

fig.suptitle(title)

for i, ax in enumerate(axs.ravel()):
# unpack images and labels
if isinstance(data, list):
_img, _label = data[i]
elif isinstance(data, tuple):
imgs, labels = data
_img, _label = imgs[i], labels[i]
else:
raise TypeError(f"Unknown data type. Got: {type(data)}.")
# convert images to numpy
_img: np.ndarray = self._to_numpy(_img)
if isinstance(_label, torch.Tensor):
_label = int(_label.item())
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
# show image and set label as subplot title
ax.imshow(_img)
ax.set_title(str(_label))
ax.axis('off')
plt.show(block=False)

def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
win_title: str = f"{running_stage} - show_load_sample"
self._show_images_and_labels(samples, len(samples), win_title)

def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
win_title: str = f"{running_stage} - show_pre_tensor_transform"
self._show_images_and_labels(samples, len(samples), win_title)

def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
win_title: str = f"{running_stage} - show_to_tensor_transform"
self._show_images_and_labels(samples, len(samples), win_title)

def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
win_title: str = f"{running_stage} - show_post_tensor_transform"
self._show_images_and_labels(samples, len(samples), win_title)

def show_per_batch_transform(self, batch: List[Any], running_stage):
win_title: str = f"{running_stage} - show_per_batch_transform"
self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ sentencepiece>=0.1.95
filelock # comes with 3rd-party dependency
pycocotools>=2.0.2 ; python_version >= "3.7"
kornia>=0.5.0
matplotlib # used by the visualisation callback
11 changes: 5 additions & 6 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flash.data.base_viz import BaseVisualization
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.utils import _STAGES_PREFIX
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.vision import ImageClassificationData


Expand Down Expand Up @@ -87,10 +87,7 @@ def test_base_viz(tmpdir):
(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a" / "a_1.png")
_rand_image().save(tmpdir / "a" / "a_2.png")

_rand_image().save(tmpdir / "b" / "a_1.png")
_rand_image().save(tmpdir / "b" / "a_2.png")

class CustomBaseVisualization(BaseVisualization):

Expand Down Expand Up @@ -147,8 +144,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:

for stage in _STAGES_PREFIX.values():

for _ in range(10):
getattr(dm, f"show_{stage}_batch")(reset=False)
for fcn_name in _PREPROCESS_FUNCS:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep the previous one, it was asserting we could iterate more than the dataset length and the iterator was being reset.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And let's add yours too.

fcn = getattr(dm, f"show_{stage}_batch")
fcn(fcn_name, reset=False)

is_predict = stage == "predict"

Expand Down Expand Up @@ -193,3 +191,4 @@ def test_data_loaders_num_workers_to_0(tmpdir):
assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter)
iterator = iter(datamodule.train_dataloader())
assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
assert datamodule.num_workers == 3
Loading