Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support RTMDet Ins Segmentation Inference #583

Merged
merged 16 commits into from
Mar 2, 2023
31 changes: 31 additions & 0 deletions configs/rtmdet/rtmdet-ins_s_syncbn_fast_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
_base_ = './rtmdet_s_syncbn_fast_8xb32-300e_coco.py'

widen_factor = 0.5

model = dict(
bbox_head=dict(
type='RTMDetInsSepBNHead',
head_module=dict(
type='RTMDetInsSepBNHeadModule',
use_sigmoid_cls=True,
widen_factor=widen_factor),
loss_mask=dict(
type='mmdet.DiceLoss', loss_weight=2.0, eps=5e-6,
reduction='mean')),
test_cfg=dict(
multi_label=True,
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100,
mask_thr_binary=0.5))

_base_.test_pipeline[-2] = dict(
type='LoadAnnotations', with_bbox=True, with_mask=True, _scope_='mmdet')

val_dataloader = dict(dataset=dict(pipeline=_base_.test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(metric=['bbox', 'segm'])
test_evaluator = val_evaluator
7 changes: 7 additions & 0 deletions mmyolo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@ def yolov5_collate(data_batch: Sequence,
"""
batch_imgs = []
batch_bboxes_labels = []
batch_masks = []
for i in range(len(data_batch)):
datasamples = data_batch[i]['data_samples']
inputs = data_batch[i]['inputs']
batch_imgs.append(inputs)

gt_bboxes = datasamples.gt_instances.bboxes.tensor
gt_labels = datasamples.gt_instances.labels
if 'masks' in datasamples.gt_instances:
masks = datasamples.gt_instances.masks.to_tensor(
dtype=torch.bool, device=gt_bboxes.device)
batch_masks.append(masks)
batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
dim=1)
Expand All @@ -36,6 +41,8 @@ def yolov5_collate(data_batch: Sequence,
'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
}
}
if len(batch_masks) > 0:
collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)

if use_ms_training:
collated_results['inputs'] = batch_imgs
Expand Down
6 changes: 4 additions & 2 deletions mmyolo/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ def forward(self, data: dict, training: bool = False) -> dict:
inputs, data_samples = batch_aug(inputs, data_samples)

img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
data_samples = {
data_samples_output = {
'bboxes_labels': data_samples['bboxes_labels'],
'img_metas': img_metas
}
if 'masks' in data_samples:
data_samples_output['masks'] = data_samples['masks']

return {'inputs': inputs, 'data_samples': data_samples}
return {'inputs': inputs, 'data_samples': data_samples_output}


@MODELS.register_module()
Expand Down
4 changes: 3 additions & 1 deletion mmyolo/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .rotated_rtmdet_head import (RotatedRTMDetHead,
RotatedRTMDetSepBNHeadModule)
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
from .rtmdet_ins_head import RTMDetInsSepBNHead, RTMDetInsSepBNHeadModule
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
Expand All @@ -14,5 +15,6 @@
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
'RotatedRTMDetHead', 'RotatedRTMDetSepBNHeadModule'
'RotatedRTMDetHead', 'RotatedRTMDetSepBNHeadModule', 'RTMDetInsSepBNHead',
'RTMDetInsSepBNHeadModule'
]
Loading