Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add visualisation callback for image classification #228

Merged
merged 30 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7850c17
resolve bug
tchaton Apr 19, 2021
dec2e16
update
tchaton Apr 19, 2021
98b0564
draft version of the image classification callback
edgarriba Apr 19, 2021
18f6339
Merge branch 'master' into feat/viz_callback
edgarriba Apr 19, 2021
108ff3b
Merge branch 'master' into feat/viz_callback
edgarriba Apr 19, 2021
d2b423b
move callback under base class
edgarriba Apr 19, 2021
c2587fd
add crashing tests
edgarriba Apr 20, 2021
c2095be
add more tests
edgarriba Apr 20, 2021
b7d448d
more tests
edgarriba Apr 20, 2021
3dd9b19
more tests
edgarriba Apr 20, 2021
bd42635
fix issues with files directory loading and add better error message
edgarriba Apr 20, 2021
32ac9b3
fix visualisation test
edgarriba Apr 20, 2021
0ef5d2a
fix data outpye type
edgarriba Apr 20, 2021
d444841
fix tests with from_paths
edgarriba Apr 20, 2021
3a0b088
add matplotlib check import
edgarriba Apr 20, 2021
ce64c40
improve tests on from_folders
edgarriba Apr 20, 2021
be47408
add missing imports
edgarriba Apr 20, 2021
b2a19bf
fixed test_classification
edgarriba Apr 20, 2021
48ea3c2
implement all hooks
edgarriba Apr 20, 2021
76457f7
fix more tests
edgarriba Apr 20, 2021
159c002
fix more tests
edgarriba Apr 20, 2021
b339bc9
add matplotlib in reaquirements
edgarriba Apr 20, 2021
0123284
remove useless test
edgarriba Apr 20, 2021
a48f149
Merge branch 'master' into feat/viz_callback
edgarriba Apr 20, 2021
d125c5d
implement function filtering to visualize
edgarriba Apr 20, 2021
21efbca
fix comments
edgarriba Apr 21, 2021
9e8138f
Merge branch 'master' into feat/viz_callback
edgarriba Apr 21, 2021
afeedac
add setter method to block windows to show with matplotlib
edgarriba Apr 22, 2021
0761bfb
sync with master and fix conflicts
edgarriba Apr 22, 2021
f05c0d4
remove unused variable
edgarriba Apr 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.classification import ImageClassificationData, ImageClassificationDataVisualizer, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
2 changes: 1 addition & 1 deletion flash/vision/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from flash.vision.classification.data import ImageClassificationData
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataVisualizer
from flash.vision.classification.model import ImageClassifier
66 changes: 58 additions & 8 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,23 @@
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, make_dataset

from flash.core.utils import _is_overriden
from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization # for viz
from flash.data.data_module import DataModule
from flash.data.data_pipeline import DataPipeline
from flash.data.process import Preprocess
from flash.data.utils import _PREPROCESS_FUNCS
from flash.utils.imports import _KORNIA_AVAILABLE

if _KORNIA_AVAILABLE:
import kornia.augmentation as K
import kornia.geometry.transform as T
import kornia as K
else:
from torchvision import transforms as T

# TODO(edgar): check if is available: if _MATPLOTLIB_AVAILABLE
import matplotlib.pyplot as plt


class ImageClassificationPreprocess(Preprocess):

Expand Down Expand Up @@ -253,14 +258,16 @@ def _check_transforms(transform: Dict[str, Union[nn.Module, Callable]]) -> Dict[

@staticmethod
def default_train_transforms():
image_size = ImageClassificationData.image_size
image_size: Tuple[int, int] = ImageClassificationData.image_size
if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size), K.RandomHorizontalFlip()),
"post_tensor_transform": nn.Sequential(
K.augmentation.RandomResizedCrop(image_size), K.augmentation.RandomHorizontalFlip()
),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand All @@ -273,14 +280,14 @@ def default_train_transforms():

@staticmethod
def default_val_transforms():
image_size = ImageClassificationData.image_size
image_size: Tuple[int, int] = ImageClassificationData.image_size
if _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1":
# Better approach as all transforms are applied on tensor directly
return {
"to_tensor_transform": torchvision.transforms.ToTensor(),
"post_tensor_transform": nn.Sequential(K.RandomResizedCrop(image_size)),
"post_tensor_transform": nn.Sequential(K.augmentation.RandomResizedCrop(image_size)),
"per_batch_transform_on_device": nn.Sequential(
K.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
)
}
else:
Expand Down Expand Up @@ -527,3 +534,46 @@ def from_filepaths(
seed=seed,
**kwargs
)


class _CustomBaseVisualization(BaseVisualization):
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
"""Process and show the image batch and its associated label.
"""
max_cols: int = 4 # maximum number of columns we accept

def show_per_batch_transform(self, batch: List[Any], running_stage):
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
# get the batch data
img, label = batch[0]

# define the image grid
cols: int = min(img.shape[0], self.max_cols)
rows: int = img.shape[0] // cols

fig, axs = plt.subplots(rows, cols)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raise an exception if matplotlib isn't available here: _MATPLOTLIB_AVAILABLE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, check if that's appropriate

fig.suptitle(str(running_stage))

for i, ax in enumerate(axs.ravel()):
_img, _label = img[i], label[i]
ax.imshow(K.tensor_to_image(_img))
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
ax.set_title(_label)
ax.axis('off')
plt.show()


class ImageClassificationDataVisualizer(ImageClassificationData):
"""Base class to be used for visualizing the Image Classificatio data.

Usage:

data_viz = ImageClassificationDataVisualizer.from_filepaths(
train_filepaths=["path/img1.png", "path/img2.png"],
train_labels=[0, 1],
batch_size=2,
)
data_viz.show_train_batch()

"""

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> 'BaseDataFetcher':
return _CustomBaseVisualization(*args, **kwargs)
27 changes: 26 additions & 1 deletion tests/vision/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from PIL import Image

from flash.data.data_utils import labels_from_categorical_csv
from flash.vision import ImageClassificationData
from flash.vision import ImageClassificationData, ImageClassificationDataVisualizer


def _dummy_image_loader(_):
Expand Down Expand Up @@ -96,6 +96,31 @@ def test_from_filepaths(tmpdir):
assert labels.shape == (1, )


def test_from_filepaths_visualise(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a" / "a_1.png")
_rand_image().save(tmpdir / "a" / "a_2.png")

_rand_image().save(tmpdir / "b" / "a_1.png")
_rand_image().save(tmpdir / "b" / "a_2.png")

data_viz = ImageClassificationDataVisualizer.from_filepaths(
train_filepaths=[tmpdir / "a", tmpdir / "b"],
train_labels=[0, 1],
val_filepaths=[tmpdir / "a", tmpdir / "b"],
val_labels=[0, 1],
test_filepaths=[tmpdir / "a", tmpdir / "b"],
test_labels=[0, 1],
batch_size=2,
)
data_viz.show_train_batch()
data_viz.show_val_batch()
data_viz.show_test_batch()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to decide how we validate this functionality during tests since it involves matplotlib visualization

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



def test_categorical_csv_labels(tmpdir):
train_dir = Path(tmpdir / "some_dataset")
train_dir.mkdir()
Expand Down