Skip to content

Commit

Permalink
[Improvement] browse_dataset.py (open-mmlab#304)
Browse files Browse the repository at this point in the history
* Create browse_transform.py

Upgrade browse_transform
对pipeline中的transform过程实现了可视化,以及可以将transform中key参数的变化打印出来。

* Update browse_dataset.py

* Delete browse_transform.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

* Update browse_dataset.py

修改了215行result_i = [result['dataset_sample'] for result in intermediate_imgs]通过lint

* fix some error

Co-authored-by: huanghaian <[email protected]>
  • Loading branch information
2 people authored and triple-Mu committed Dec 6, 2022
1 parent 18ea62d commit 8cb655f
Showing 1 changed file with 199 additions and 32 deletions.
231 changes: 199 additions & 32 deletions tools/analysis_tools/browse_dataset.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import sys
from typing import Tuple

import cv2
import mmcv
import numpy as np
from mmdet.models.utils import mask2ndarray
from mmdet.structures.bbox import BaseBoxes
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer

from mmyolo.registry import DATASETS, VISUALIZERS
from mmyolo.utils import register_all_modules


# TODO: Support for printing the change in key of results
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--phase',
'-p',
default='train',
type=str,
choices=['train', 'test', 'val'],
help='phase of dataset to visualize, accept "train" "test" and "val".'
' Defaults to "train".')
parser.add_argument(
'--mode',
'-m',
default='transformed',
type=str,
choices=['original', 'transformed', 'pipeline'],
help='display mode; display original pictures or '
'transformed pictures or comparison pictures. "original" '
'means show images load from disk; "transformed" means '
'to show images after transformed; "pipeline" means show all '
'the intermediate images. Defaults to "transformed".')
parser.add_argument(
'--output-dir',
default=None,
type=str,
help='If there is no display interface, you can save it')
help='If there is no display interface, you can save it.')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--show-number',
'-n',
type=int,
default=sys.maxsize,
help='number of images selected to visualize, '
'must bigger than 0. if the number is bigger than length '
'of dataset, show all the images in dataset; '
'default "sys.maxsize", show all images in dataset')
parser.add_argument(
'--show-interval',
'-i',
type=float,
default=3,
help='the interval of show (s)')
Expand All @@ -40,49 +76,180 @@ def parse_args():
return args


def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear
scales according the short edge length. You can also specify the minimum
scale and the maximum scale to limit the linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_scale (int): The minimum scale. Defaults to 0.3.
max_scale (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)


def make_grid(imgs, names):
    """Concatenate pictures side by side into one captioned picture.

    Each picture is surrounded by a white border so that all of them share
    the same height (1.1 times the tallest picture plus a strip at the bottom
    reserved for the caption). The padded pictures are joined left to right
    and a three-line caption (execution index, name, original height/width)
    is drawn under each one with the current ``Visualizer`` instance.

    Args:
        imgs (list): Pictures to join; each must expose ``.shape`` with
            height and width as the first two entries. The list passed in is
            left untouched.
        names (list): One caption name per picture, in the same order.

    Returns:
        The image returned by ``Visualizer.get_image()`` after the captions
        have been drawn.
    """
    visualizer = Visualizer.get_current_instance()
    ori_shapes = [img.shape[:2] for img in imgs]
    max_height = int(max(img.shape[0] for img in imgs) * 1.1)
    min_width = min(img.shape[1] for img in imgs)
    horizontal_gap = min_width // 10
    img_scale = _get_adaptive_scale((max_height, min_width))

    # These do not depend on the individual picture, so compute them once.
    pad_width = horizontal_gap // 2
    caption_height = int(img_scale * 30 * 2)

    # Collect padded pictures in a new list instead of overwriting entries of
    # the caller's ``imgs``.
    padded_imgs = []
    texts = []
    text_positions = []
    start_x = 0
    for i, img in enumerate(imgs):
        height, width = ori_shapes[i]
        pad_top = (max_height - height) // 2
        pad_bottom = max_height - height - pad_top + caption_height
        # make border
        padded_imgs.append(
            cv2.copyMakeBorder(
                img,
                pad_top,
                pad_bottom,
                pad_width,
                pad_width,
                cv2.BORDER_CONSTANT,
                value=(255, 255, 255)))
        texts.append(f'execution: {i}\n{names[i]}\n{ori_shapes[i]}')
        text_positions.append([start_x + width // 2 + pad_width, max_height])
        # Advance by the real padded width. ``horizontal_gap`` itself is one
        # pixel larger than ``2 * pad_width`` whenever it is odd, which would
        # shift every following caption off centre.
        start_x += width + 2 * pad_width

    display_img = np.concatenate(padded_imgs, axis=1)
    visualizer.set_image(display_img)
    # Font size follows the size of the assembled picture.
    img_scale = _get_adaptive_scale(display_img.shape[:2])
    visualizer.draw_texts(
        texts,
        positions=np.array(text_positions),
        font_sizes=img_scale * 7,
        colors='black',
        horizontal_alignments='center',
        font_families='monospace')
    return visualizer.get_image()


class InspectCompose(Compose):
    """Compose transforms sequentially and record intermediate results.

    All transforms except the last one are applied in order. After each of
    them, the last transform is run on the current result and the
    ``'data_samples'`` entry of its output is appended to
    ``intermediate_imgs`` together with the name of the transform that was
    just applied. The last transform is presumably a packing transform (the
    inline comment names ``PackDetInputs``) -- confirm against the configs
    in use.

    Args:
        transforms: Transforms forwarded unchanged to ``Compose``.
        intermediate_imgs (list): List owned by the caller; records are
            appended to it while a sample is processed.
    """

    def __init__(self, transforms, intermediate_imgs):
        super().__init__(transforms=transforms)
        self.intermediate_imgs = intermediate_imgs

    def __call__(self, data):
        """Run the pipeline on ``data`` and record each intermediate state.

        Args:
            data (dict): The sample to transform.

        Returns:
            The result of the last non-final transform (not the output of
            the final transform), or ``None`` as soon as any transform
            returns ``None``.
        """
        if 'img' in data:
            self.intermediate_imgs.append({
                'name': 'original',
                'img': data['img'].copy()
            })
        # Every transform except the final one.
        self.ptransforms = list(self.transforms[:-1])
        for t in self.ptransforms:
            data = t(data)
            # Check for a dropped sample before touching ``data``; iterating
            # over ``None`` below would raise TypeError.
            if data is None:
                # Discard the partial records of this sample so they do not
                # get mixed with the records of the next call.
                self.intermediate_imgs.clear()
                return None
            # Keep the same meta_keys in the PackDetInputs
            self.transforms[-1].meta_keys = list(data)
            data_sample = self.transforms[-1](data)
            if 'img' in data:
                self.intermediate_imgs.append({
                    'name':
                    t.__class__.__name__,
                    'dataset_sample':
                    data_sample['data_samples']
                })
        return data


def main():
    """Browse a dataset: draw its ground truth and show and/or save it.

    The dataset of the selected phase is built from the config and its
    pipeline is replaced by an ``InspectCompose`` that records the state
    after every transform. For each sample the recorded states are drawn
    with the configured visualizer; depending on ``--mode`` the first state,
    the last state, or a grid of all states is displayed and, when
    ``--output-dir`` is given, written to disk.

    Raises:
        ValueError: If ``--show-number`` is not greater than 0.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # register all modules in mmyolo into the registries
    register_all_modules()

    dataset_cfg = cfg.get(args.phase + '_dataloader').get('dataset')
    dataset = DATASETS.build(dataset_cfg)
    visualizer = VISUALIZERS.build(cfg.visualizer)
    visualizer.dataset_meta = dataset.metainfo

    intermediate_imgs = []
    # TODO: The dataset wrapper occasion is not considered here
    dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
                                      intermediate_imgs)

    # init visualization image number
    # Raise explicitly: an ``assert`` is stripped when Python runs with -O.
    if args.show_number <= 0:
        raise ValueError(
            f'--show-number must be greater than 0, got {args.show_number}')
    display_number = min(args.show_number, len(dataset))

    progress_bar = ProgressBar(display_number)
    # Iterating the dataset runs the pipeline, which fills
    # ``intermediate_imgs``; the yielded item itself is not needed.
    for i, _ in zip(range(display_number), dataset):
        # Only records carrying a packed data sample can be drawn; the
        # optional 'original' record holds a raw image under 'img' instead.
        records = [
            result for result in intermediate_imgs
            if 'dataset_sample' in result
        ]
        # Reset the shared list now so the next sample starts from scratch.
        intermediate_imgs.clear()

        if not records:
            # Nothing was recorded for this sample, so there is nothing to
            # draw.
            progress_bar.update()
            continue

        image_i = []
        for record in records:
            datasample = record['dataset_sample']
            image = datasample.img
            gt_instances = datasample.gt_instances
            image = image[..., [2, 1, 0]]  # bgr to rgb
            gt_bboxes = gt_instances.get('bboxes', None)
            if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes):
                gt_instances.bboxes = gt_bboxes.tensor
            gt_masks = gt_instances.get('masks', None)
            if gt_masks is not None:
                masks = mask2ndarray(gt_masks)
                # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool``
                # is the equivalent dtype.
                gt_instances.masks = masks.astype(bool)
            datasample.gt_instances = gt_instances
            visualizer.add_datasample(
                'result',
                image,
                datasample,
                draw_pred=False,
                draw_gt=True,
                show=False)
            image_i.append(visualizer.get_image())

        if args.mode == 'original':
            image = image_i[0]
        elif args.mode == 'transformed':
            image = image_i[-1]
        else:
            image = make_grid(
                list(image_i), [record['name'] for record in records])

        # get filename from dataset or just use index as filename
        last_sample = records[-1]['dataset_sample']
        if hasattr(last_sample, 'img_path'):
            filename = osp.basename(last_sample.img_path)
        else:
            # some datasets do not provide an image path
            filename = f'{i}.jpg'
        out_file = osp.join(args.output_dir,
                            filename) if args.output_dir is not None else None

        if out_file is not None:
            # Reverse the channel order back before writing to disk.
            mmcv.imwrite(image[..., ::-1], out_file)

        if not args.not_show:
            visualizer.show(
                image, win_name=filename, wait_time=args.show_interval)

        progress_bar.update()


Expand Down

0 comments on commit 8cb655f

Please sign in to comment.