Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support RTMDet Ins Segmentation Inference #583

Merged
merged 16 commits into from
Mar 2, 2023
31 changes: 31 additions & 0 deletions configs/rtmdet/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
_base_ = './rtmdet_s_syncbn_fast_8xb32-300e_coco.py'

widen_factor = 0.5

model = dict(
bbox_head=dict(
type='RTMDetInsSepBNHead',
head_module=dict(
type='RTMDetInsSepBNHeadModule',
use_sigmoid_cls=True,
widen_factor=widen_factor),
loss_mask=dict(
type='mmdet.DiceLoss', loss_weight=2.0, eps=5e-6,
reduction='mean')),
test_cfg=dict(
multi_label=True,
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100,
mask_thr_binary=0.5))

_base_.test_pipeline[-2] = dict(
type='LoadAnnotations', with_bbox=True, with_mask=True, _scope_='mmdet')

val_dataloader = dict(dataset=dict(pipeline=_base_.test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(metric=['bbox', 'segm'])
test_evaluator = val_evaluator
7 changes: 7 additions & 0 deletions mmyolo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@ def yolov5_collate(data_batch: Sequence,
"""
batch_imgs = []
batch_bboxes_labels = []
batch_masks = []
for i in range(len(data_batch)):
datasamples = data_batch[i]['data_samples']
inputs = data_batch[i]['inputs']
batch_imgs.append(inputs)

gt_bboxes = datasamples.gt_instances.bboxes.tensor
gt_labels = datasamples.gt_instances.labels
if 'masks' in datasamples.gt_instances:
masks = datasamples.gt_instances.masks.to_tensor(
dtype=torch.bool, device=gt_bboxes.device)
batch_masks.append(masks)
batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
dim=1)
Expand All @@ -36,6 +41,8 @@ def yolov5_collate(data_batch: Sequence,
'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
}
}
if len(batch_masks) > 0:
collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)

if use_ms_training:
collated_results['inputs'] = batch_imgs
Expand Down
6 changes: 4 additions & 2 deletions mmyolo/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ def forward(self, data: dict, training: bool = False) -> dict:
inputs, data_samples = batch_aug(inputs, data_samples)

img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
data_samples = {
data_samples_output = {
'bboxes_labels': data_samples['bboxes_labels'],
'img_metas': img_metas
}
if 'masks' in data_samples:
data_samples_output['masks'] = data_samples['masks']

return {'inputs': inputs, 'data_samples': data_samples}
return {'inputs': inputs, 'data_samples': data_samples_output}


@MODELS.register_module()
Expand Down
4 changes: 3 additions & 1 deletion mmyolo/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ppyoloe_head import PPYOLOEHead, PPYOLOEHeadModule
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
from .rtmdet_ins_head import RTMDetInsSepBNHead, RTMDetInsSepBNHeadModule
from .rtmdet_rotated_head import (RTMDetRotatedHead,
RTMDetRotatedSepBNHeadModule)
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
Expand All @@ -14,5 +15,6 @@
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule'
'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
'RTMDetInsSepBNHeadModule'
]
Loading