This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 211
Add visualisation callback for image classification #228
Merged
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
7850c17
resolve bug
tchaton dec2e16
update
tchaton 98b0564
draft version of the image classification callback
edgarriba 18f6339
Merge branch 'master' into feat/viz_callback
edgarriba 108ff3b
Merge branch 'master' into feat/viz_callback
edgarriba d2b423b
move callback under base class
edgarriba c2587fd
add crashing tests
edgarriba c2095be
add more tests
edgarriba b7d448d
more tests
edgarriba 3dd9b19
more tests
edgarriba bd42635
fix issues with files directory loading and add better error message
edgarriba 32ac9b3
fix visualisation test
edgarriba 0ef5d2a
fix data outpye type
edgarriba d444841
fix tests with from_paths
edgarriba 3a0b088
add matplotlib check import
edgarriba ce64c40
improve tests on from_folders
edgarriba be47408
add missing imports
edgarriba b2a19bf
fixed test_classification
edgarriba 48ea3c2
implement all hooks
edgarriba 76457f7
fix more tests
edgarriba 159c002
fix more tests
edgarriba b339bc9
add matplotlib in reaquirements
edgarriba 0123284
remove useless test
edgarriba a48f149
Merge branch 'master' into feat/viz_callback
edgarriba d125c5d
implement function filtering to visualize
edgarriba 21efbca
fix comments
edgarriba 9e8138f
Merge branch 'master' into feat/viz_callback
edgarriba afeedac
add setter method to block windows to show with matplotlib
edgarriba 0761bfb
sync with master and fix conflicts
edgarriba f05c0d4
remove unused variable
edgarriba File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,28 +15,37 @@ | |||||
import pathlib | ||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union | ||||||
|
||||||
import numpy as np | ||||||
import torch | ||||||
import torchvision | ||||||
from PIL import Image | ||||||
from pytorch_lightning.trainer.states import RunningStage | ||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||||||
from torch import nn | ||||||
from torch.utils.data import Dataset | ||||||
from torch.utils.data._utils.collate import default_collate | ||||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset | ||||||
|
||||||
from flash.core.classification import ClassificationState | ||||||
from flash.core.utils import _is_overriden | ||||||
from flash.data.auto_dataset import AutoDataset | ||||||
from flash.data.base_viz import BaseVisualization # for viz | ||||||
from flash.data.callback import BaseDataFetcher | ||||||
from flash.data.data_module import DataModule | ||||||
from flash.data.process import Preprocess | ||||||
from flash.utils.imports import _KORNIA_AVAILABLE | ||||||
from flash.data.utils import _PREPROCESS_FUNCS | ||||||
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE | ||||||
|
||||||
if _KORNIA_AVAILABLE: | ||||||
import kornia.augmentation as K | ||||||
import kornia.geometry.transform as T | ||||||
import kornia as K | ||||||
else: | ||||||
from torchvision import transforms as T | ||||||
|
||||||
if _MATPLOTLIB_AVAILABLE: | ||||||
import matplotlib.pyplot as plt | ||||||
else: | ||||||
plt = None | ||||||
|
||||||
|
||||||
class ImageClassificationPreprocess(Preprocess): | ||||||
|
||||||
|
@@ -93,9 +102,11 @@ def default_train_transforms(self): | |||||
# Better approach as all transforms are applied on tensor directly | ||||||
return { | ||||||
"to_tensor_transform": torchvision.transforms.ToTensor(), | ||||||
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()), | ||||||
"post_tensor_transform": nn.Sequential( | ||||||
K.augmentation.RandomResizedCrop(image_size), K.augmentation.RandomHorizontalFlip() | ||||||
), | ||||||
"per_batch_transform_on_device": nn.Sequential( | ||||||
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), | ||||||
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), | ||||||
) | ||||||
} | ||||||
else: | ||||||
|
@@ -112,9 +123,9 @@ def default_val_transforms(self): | |||||
# Better approach as all transforms are applied on tensor directly | ||||||
return { | ||||||
"to_tensor_transform": torchvision.transforms.ToTensor(), | ||||||
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)), | ||||||
"post_tensor_transform": nn.Sequential(K.augmentation.RandomResizedCrop(image_size)), | ||||||
"per_batch_transform_on_device": nn.Sequential( | ||||||
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), | ||||||
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), | ||||||
) | ||||||
} | ||||||
else: | ||||||
|
@@ -159,18 +170,34 @@ def _load_data_dir( | |||||
dataset: Optional[AutoDataset] = None, | ||||||
) -> Tuple[Optional[List[str]], List[Tuple[str, int]]]: | ||||||
if isinstance(data, list): | ||||||
# TODO: define num_classes elsewhere. This is a bad assumption since the list of | ||||||
# labels might not contain the complete set of ids so that you can infer the total | ||||||
# number of classes to train in your dataset. | ||||||
dataset.num_classes = len(data) | ||||||
out = [] | ||||||
out: List[Tuple[str, int]] = [] | ||||||
for p, label in data: | ||||||
if os.path.isdir(p): | ||||||
for f in os.listdir(p): | ||||||
# TODO: there is an issue here when a path is provided along with labels. | ||||||
# os.listdir cannot assure the same file order as the passed labels list. | ||||||
files_list: List[str] = os.listdir(p) | ||||||
if len(files_list) > 1: | ||||||
raise ValueError( | ||||||
f"The provided directory contains more than one file." | ||||||
f"Directory: {p} -> Contains: {files_list}" | ||||||
) | ||||||
for f in files_list: | ||||||
if has_file_allowed_extension(f, IMG_EXTENSIONS): | ||||||
out.append([os.path.join(p, f), label]) | ||||||
elif os.path.isfile(p) and has_file_allowed_extension(p, IMG_EXTENSIONS): | ||||||
elif os.path.isfile(p) and has_file_allowed_extension(str(p), IMG_EXTENSIONS): | ||||||
out.append([p, label]) | ||||||
else: | ||||||
raise TypeError(f"Unexpected file path type: {p}.") | ||||||
return None, out | ||||||
else: | ||||||
classes, class_to_idx = cls._find_classes(data) | ||||||
# TODO: define num_classes elsewhere. This is a bad assumption since the list of | ||||||
# labels might not contain the complete set of ids so that you can infer the total | ||||||
# number of classes to train in your dataset. | ||||||
dataset.num_classes = len(classes) | ||||||
return classes, make_dataset(data, class_to_idx, IMG_EXTENSIONS, None) | ||||||
|
||||||
|
@@ -318,6 +345,14 @@ def __init__( | |||||
if self._predict_ds: | ||||||
self.set_dataset_attribute(self._predict_ds, 'num_classes', self.num_classes) | ||||||
|
||||||
def set_block_viz_window(self, value: bool) -> None: | ||||||
"""Setter method to switch on/off matplotlib to pop up windows.""" | ||||||
self.data_fetcher.block_viz_window = value | ||||||
|
||||||
@staticmethod | ||||||
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: | ||||||
return MatplotlibVisualization(*args, **kwargs) | ||||||
|
||||||
@property | ||||||
def num_classes(self) -> int: | ||||||
if self._num_classes is None: | ||||||
|
@@ -494,3 +529,72 @@ def from_filepaths( | |||||
seed=seed, | ||||||
**kwargs | ||||||
) | ||||||
|
||||||
|
||||||
class MatplotlibVisualization(BaseVisualization): | ||||||
"""Process and show the image batch and its associated label using matplotlib. | ||||||
""" | ||||||
max_cols: int = 4 # maximum number of columns we accept | ||||||
block_viz_window: bool = True # parameter to allow user to block visualisation windows | ||||||
|
||||||
@staticmethod | ||||||
def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: | ||||||
out: np.ndarray | ||||||
if isinstance(img, Image.Image): | ||||||
out = np.array(img) | ||||||
elif isinstance(img, torch.Tensor): | ||||||
out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() | ||||||
else: | ||||||
raise TypeError(f"Unknown image type. Got: {type(img)}.") | ||||||
return out | ||||||
|
||||||
def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): | ||||||
# define the image grid | ||||||
cols: int = min(num_samples, self.max_cols) | ||||||
rows: int = num_samples // cols | ||||||
|
||||||
if not _MATPLOTLIB_AVAILABLE: | ||||||
raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# create figure and set title | ||||||
fig, axs = plt.subplots(rows, cols) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Raise an exception if matplotlib isn't available here: _MATPLOTLIB_AVAILABLE There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, check if that's appropriate |
||||||
fig.suptitle(title) | ||||||
|
||||||
for i, ax in enumerate(axs.ravel()): | ||||||
# unpack images and labels | ||||||
if isinstance(data, list): | ||||||
_img, _label = data[i] | ||||||
elif isinstance(data, tuple): | ||||||
imgs, labels = data | ||||||
_img, _label = imgs[i], labels[i] | ||||||
else: | ||||||
raise TypeError(f"Unknown data type. Got: {type(data)}.") | ||||||
# convert images to numpy | ||||||
_img: np.ndarray = self._to_numpy(_img) | ||||||
if isinstance(_label, torch.Tensor): | ||||||
_label = _label.squeeze().tolist() | ||||||
# show image and set label as subplot title | ||||||
ax.imshow(_img) | ||||||
ax.set_title(str(_label)) | ||||||
ax.axis('off') | ||||||
plt.show(block=self.block_viz_window) | ||||||
|
||||||
def show_load_sample(self, samples: List[Any], running_stage: RunningStage): | ||||||
win_title: str = f"{running_stage} - show_load_sample" | ||||||
self._show_images_and_labels(samples, len(samples), win_title) | ||||||
|
||||||
def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): | ||||||
win_title: str = f"{running_stage} - show_pre_tensor_transform" | ||||||
self._show_images_and_labels(samples, len(samples), win_title) | ||||||
|
||||||
def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): | ||||||
win_title: str = f"{running_stage} - show_to_tensor_transform" | ||||||
self._show_images_and_labels(samples, len(samples), win_title) | ||||||
|
||||||
def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): | ||||||
win_title: str = f"{running_stage} - show_post_tensor_transform" | ||||||
self._show_images_and_labels(samples, len(samples), win_title) | ||||||
|
||||||
def show_per_batch_transform(self, batch: List[Any], running_stage): | ||||||
win_title: str = f"{running_stage} - show_per_batch_transform" | ||||||
self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this logic to
_show_batch
to reduce duplicated code and raise a MisConfigurationError is the provided names aren't in _Preprocess_funcs.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done