Skip to content

Commit

Permalink
[feature] Visualizer compatible with MultiTaskDataSample
Browse files Browse the repository at this point in the history
  • Loading branch information
haofengsiji committed Jul 10, 2023
1 parent 7d850df commit afc3d6f
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 28 deletions.
89 changes: 62 additions & 27 deletions mmpretrain/visualization/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mmengine.visualization.utils import img_from_canvas

from mmpretrain.registry import VISUALIZERS
from mmpretrain.structures import DataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
from .utils import create_figure, get_adaptive_scale


Expand Down Expand Up @@ -114,33 +114,9 @@ def visualize_cls(self,
texts = []
self.set_image(image)

if draw_gt and 'gt_label' in data_sample:
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
prefix = 'Ground truth: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

if draw_pred and 'pred_label' in data_sample:
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {data_sample.pred_score[i].item():.2f}' for i in idx
]

if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
self.draw_gt(data_sample, classes, draw_gt, texts)

labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
self.draw_pred(data_sample, classes, draw_pred, draw_score, texts)

img_scale = get_adaptive_scale(image.shape[:2])
text_cfg = {
Expand All @@ -167,6 +143,65 @@ def visualize_cls(self,

return drawn_img

def draw_pred(self,
data_sample: DataSample,
classes: Optional[Sequence[str]],
draw_pred: bool,
draw_score: bool,
texts: Sequence[str],
parent_task: str = ''):
if isinstance(data_sample, MultiTaskDataSample):
for task in data_sample.tasks:
sub_task = f'{parent_task}_{task}' if parent_task else task
self.draw_pred(
data_sample.get(task), classes, draw_pred, draw_score,
texts, sub_task)
else:
if draw_pred and 'pred_label' in data_sample:
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {data_sample.pred_score[i].item():.2f}'
for i in idx
]

if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]

labels = [
str(idx[i]) + score_labels[i] + class_labels[i]
for i in range(len(idx))
]
prefix = f'{parent_task} Prediction: ' if parent_task \
else 'Prediction: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

def draw_gt(self,
data_sample: DataSample,
classes: Optional[Sequence[str]],
draw_gt: bool,
texts: Sequence[str],
parent_task: str = ''):
if isinstance(data_sample, MultiTaskDataSample):
for task in data_sample.tasks:
sub_task = f'{parent_task}_{task}' if parent_task else task
self.draw_gt(
data_sample.get(task), classes, draw_gt, texts, sub_task)
else:
if draw_gt and 'gt_label' in data_sample:
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
labels = [
str(idx[i]) + class_labels[i] for i in range(len(idx))
]
prefix = f'{parent_task} Ground truth: ' if parent_task \
else 'Ground truth: '
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

@master_only
def visualize_image_retrieval(self,
image: np.ndarray,
Expand Down
42 changes: 41 additions & 1 deletion tests/test_visualization/test_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import torch

from mmpretrain.structures import DataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
from mmpretrain.visualization import UniversalVisualizer


Expand Down Expand Up @@ -123,6 +123,46 @@ def draw_texts(text, font_sizes, *_, **__):
data_sample,
rescale_factor=2.)

def test_visualize_multitask_cls(self):
image = np.ones((1000, 1000, 3), np.uint8)
gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1}
data_sample = MultiTaskDataSample()
task_sample = DataSample().set_gt_label(
gt_label['task1']).set_pred_label(1).set_pred_score(
torch.tensor([0.1, 0.8, 0.1]))
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name]).set_pred_label(2).set_pred_score(
torch.tensor([0.1, 0.4, 0.5]))
data_sample.task0.set_field(task_sample, task_name)

# Test show
def mock_show(drawn_img, win_name, wait_time):
self.assertFalse((image == drawn_img).all())
self.assertEqual(win_name, 'test_cls')
self.assertEqual(wait_time, 0)

with patch.object(self.vis, 'show', mock_show):
self.vis.visualize_cls(
image=image,
data_sample=data_sample,
show=True,
name='test_cls',
step=2)

# Test storage backend.
save_file = osp.join(self.tmpdir.name,
'vis_data/vis_image/test_cls_2.png')
self.assertTrue(osp.exists(save_file))

# Test out_file
out_file = osp.join(self.tmpdir.name, 'results_2.png')
self.vis.visualize_cls(
image=image, data_sample=data_sample, out_file=out_file)
self.assertTrue(osp.exists(out_file))

def test_visualize_image_retrieval(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])
Expand Down

0 comments on commit afc3d6f

Please sign in to comment.