Skip to content

Commit 3f38b02

Browse files
committed
remove nms ops from mmcv.ops.iou3d from mmdetection3d
1 parent 52b1924 commit 3f38b02

File tree

8 files changed

+95
-42
lines changed

8 files changed

+95
-42
lines changed

mmdet3d/core/post_processing/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from mmdet.core.post_processing import (merge_aug_bboxes, merge_aug_masks,
33
merge_aug_proposals, merge_aug_scores,
44
multiclass_nms)
5-
from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms
5+
from .box3d_nms import (aligned_3d_nms, box3d_multiclass_nms, circle_nms,
6+
nms_bev, nms_normal_bev)
67
from .merge_augs import merge_aug_bboxes_3d
78

89
__all__ = [
910
'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
1011
'merge_aug_scores', 'merge_aug_masks', 'box3d_multiclass_nms',
11-
'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms'
12+
'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms', 'nms_bev',
13+
'nms_normal_bev'
1214
]

mmdet3d/core/post_processing/box3d_nms.py

+65-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import numba
33
import numpy as np
44
import torch
5-
from mmcv.ops import nms_bev as nms_gpu
6-
from mmcv.ops import nms_normal_bev as nms_normal_gpu
5+
from mmcv.ops import nms, nms_rotated
6+
7+
from ..bbox import xywhr2xyxyr
78

89

910
def box3d_multiclass_nms(mlvl_bboxes,
@@ -61,9 +62,9 @@ def box3d_multiclass_nms(mlvl_bboxes,
6162
_bboxes_for_nms = mlvl_bboxes_for_nms[cls_inds, :]
6263

6364
if cfg.use_rotate_nms:
64-
nms_func = nms_gpu
65+
nms_func = nms_bev
6566
else:
66-
nms_func = nms_normal_gpu
67+
nms_func = nms_normal_bev
6768

6869
selected = nms_func(_bboxes_for_nms, _scores, cfg.nms_thr)
6970
_mlvl_bboxes = mlvl_bboxes[cls_inds, :]
@@ -224,3 +225,63 @@ def circle_nms(dets, thresh, post_max_size=83):
224225
return keep[:post_max_size]
225226

226227
return keep
228+
229+
230+
# This function duplicates functionality of mmcv.ops.iou_3d.nms_bev
231+
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated.
232+
# Nms api will be unified in mmdetection3d one day.
233+
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
234+
"""NMS function GPU implementation (for BEV boxes). The overlap of two
235+
boxes for IoU calculation is defined as the exact overlapping area of the
236+
two boxes. In this function, one can also set ``pre_max_size`` and
237+
``post_max_size``.
238+
239+
Args:
240+
boxes (torch.Tensor): Input boxes with the shape of [N, 5]
241+
([x1, y1, x2, y2, ry]).
242+
scores (torch.Tensor): Scores of boxes with the shape of [N].
243+
thresh (float): Overlap threshold of NMS.
244+
pre_max_size (int, optional): Max size of boxes before NMS.
245+
Default: None.
246+
post_max_size (int, optional): Max size of boxes after NMS.
247+
Default: None.
248+
249+
Returns:
250+
torch.Tensor: Indexes after NMS.
251+
"""
252+
assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
253+
order = scores.sort(0, descending=True)[1]
254+
if pre_max_size is not None:
255+
order = order[:pre_max_size]
256+
boxes = boxes[order].contiguous()
257+
# xyxyr -> back to xywhr
258+
# note: better skip this step before nms_bev call in the future
259+
boxes = torch.stack(
260+
((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2,
261+
boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1], boxes[:, 4]),
262+
dim=-1)
263+
264+
keep = nms_rotated(boxes, scores, thresh)[1]
265+
if post_max_size is not None:
266+
keep = keep[:post_max_size]
267+
return keep
268+
269+
270+
# This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev
271+
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms.
272+
# Nms api will be unified in mmdetection3d one day.
273+
def nms_normal_bev(boxes, scores, thresh):
274+
"""Normal NMS function GPU implementation (for BEV boxes). The overlap of
275+
two boxes for IoU calculation is defined as the exact overlapping area of
276+
the two boxes WITH their yaw angle set to 0.
277+
278+
Args:
279+
boxes (torch.Tensor): Input boxes with shape (N, 5).
280+
scores (torch.Tensor): Scores of predicted boxes with shape (N).
281+
thresh (float): Overlap threshold of NMS.
282+
283+
Returns:
284+
torch.Tensor: Remaining indices with scores in descending order.
285+
"""
286+
assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
287+
return nms(xywhr2xyxyr(boxes)[:, :-1], scores, thresh)[1]

mmdet3d/core/post_processing/merge_augs.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import torch
3-
from mmcv.ops import nms_bev as nms_gpu
4-
from mmcv.ops import nms_normal_bev as nms_normal_gpu
53

4+
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
65
from ..bbox import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr
76

87

@@ -52,9 +51,9 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):
5251

5352
# TODO: use a more elegent way to deal with nms
5453
if test_cfg.use_rotate_nms:
55-
nms_func = nms_gpu
54+
nms_func = nms_bev
5655
else:
57-
nms_func = nms_normal_gpu
56+
nms_func = nms_normal_bev
5857

5958
merged_bboxes = []
6059
merged_scores = []

mmdet3d/models/dense_heads/centerpoint_head.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
import torch
55
from mmcv.cnn import ConvModule, build_conv_layer
6-
from mmcv.ops import nms_bev as nms_gpu
76
from mmcv.runner import BaseModule, force_fp32
87
from torch import nn
98

109
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
1110
xywhr2xyxyr)
11+
from mmdet3d.core.post_processing import nms_bev
1212
from mmdet3d.models import builder
1313
from mmdet3d.models.builder import HEADS, build_loss
1414
from mmdet3d.models.utils import clip_sigmoid
@@ -747,9 +747,9 @@ def get_task_detections(self, num_class_with_bg, batch_cls_preds,
747747
for i, (box_preds, cls_preds, cls_labels) in enumerate(
748748
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
749749

750-
# Apply NMS in birdeye view
750+
# Apply NMS in bird eye view
751751

752-
# get highest score per prediction, than apply nms
752+
# get the highest score per prediction, then apply nms
753753
# to remove overlapped box.
754754
if num_class_with_bg == 1:
755755
top_scores = cls_preds.squeeze(-1)
@@ -778,7 +778,7 @@ def get_task_detections(self, num_class_with_bg, batch_cls_preds,
778778
box_preds[:, :], self.bbox_coder.code_size).bev)
779779
# the nms in 3d detection just remove overlap boxes.
780780

781-
selected = nms_gpu(
781+
selected = nms_bev(
782782
boxes_for_nms,
783783
top_scores,
784784
thresh=self.test_cfg['nms_thr'],

mmdet3d/models/dense_heads/parta2_rpn_head.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from __future__ import division
3-
42
import numpy as np
53
import torch
6-
from mmcv.ops import nms_bev as nms_gpu
7-
from mmcv.ops import nms_normal_bev as nms_normal_gpu
84
from mmcv.runner import force_fp32
95

106
from mmdet3d.core import limit_period, xywhr2xyxyr
7+
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
118
from mmdet.models import HEADS
129
from .anchor3d_head import Anchor3DHead
1310

@@ -261,9 +258,9 @@ def class_agnostic_nms(self, mlvl_bboxes, mlvl_bboxes_for_nms,
261258
_scores = mlvl_max_scores[score_thr_inds]
262259
_bboxes_for_nms = mlvl_bboxes_for_nms[score_thr_inds, :]
263260
if cfg.use_rotate_nms:
264-
nms_func = nms_gpu
261+
nms_func = nms_bev
265262
else:
266-
nms_func = nms_normal_gpu
263+
nms_func = nms_normal_bev
267264
selected = nms_func(_bboxes_for_nms, _scores, cfg.nms_thr)
268265

269266
_mlvl_bboxes = mlvl_bboxes[score_thr_inds, :]
@@ -288,7 +285,6 @@ def class_agnostic_nms(self, mlvl_bboxes, mlvl_bboxes_for_nms,
288285
scores = torch.cat(scores, dim=0)
289286
cls_scores = torch.cat(cls_scores, dim=0)
290287
labels = torch.cat(labels, dim=0)
291-
dir_scores = torch.cat(dir_scores, dim=0)
292288
if bboxes.shape[0] > max_num:
293289
_, inds = scores.sort(descending=True)
294290
inds = inds[:max_num]

mmdet3d/models/dense_heads/point_rpn_head.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import torch
3-
from mmcv.ops import nms_bev as nms_gpu
4-
from mmcv.ops import nms_normal_bev as nms_normal_gpu
53
from mmcv.runner import BaseModule, force_fp32
64
from torch import nn as nn
75

86
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
97
LiDARInstance3DBoxes)
8+
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
109
from mmdet.core import build_bbox_coder, multi_apply
1110
from mmdet.models import HEADS, build_loss
1211

@@ -19,7 +18,7 @@ class PointRPNHead(BaseModule):
1918
num_classes (int): Number of classes.
2019
train_cfg (dict): Train configs.
2120
test_cfg (dict): Test configs.
22-
pred_layer_cfg (dict, optional): Config of classfication and
21+
pred_layer_cfg (dict, optional): Config of classification and
2322
regression prediction layers. Defaults to None.
2423
enlarge_width (float, optional): Enlarge bbox for each side to ignore
2524
close points. Defaults to 0.1.
@@ -121,7 +120,7 @@ def forward(self, feat_dict):
121120
batch_size, -1, self._get_cls_out_channels())
122121
point_box_preds = self.reg_layers(feat_reg).reshape(
123122
batch_size, -1, self._get_reg_out_channels())
124-
return (point_box_preds, point_cls_preds)
123+
return point_box_preds, point_cls_preds
125124

126125
@force_fp32(apply_to=('bbox_preds'))
127126
def loss(self,
@@ -159,7 +158,7 @@ def loss(self,
159158
semantic_targets = mask_targets
160159
semantic_targets[negative_mask] = self.num_classes
161160
semantic_points_label = semantic_targets
162-
# for ignore, but now we do not have ignore label
161+
# for ignore, but now we do not have ignored label
163162
semantic_loss_weight = negative_mask.float() + positive_mask.float()
164163
semantic_loss = self.cls_loss(semantic_points,
165164
semantic_points_label.reshape(-1),
@@ -220,7 +219,7 @@ def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d):
220219
gt_bboxes_3d = gt_bboxes_3d[valid_gt]
221220
gt_labels_3d = gt_labels_3d[valid_gt]
222221

223-
# transform the bbox coordinate to the pointcloud coordinate
222+
# transform the bbox coordinate to the point cloud coordinate
224223
gt_bboxes_3d_tensor = gt_bboxes_3d.tensor.clone()
225224
gt_bboxes_3d_tensor[..., 2] += gt_bboxes_3d_tensor[..., 5] / 2
226225

@@ -233,7 +232,6 @@ def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d):
233232
points[..., 0:3], mask_targets)
234233

235234
positive_mask = (points_mask.max(1)[0] > 0)
236-
negative_mask = (points_mask.max(1)[0] == 0)
237235
# add ignore_mask
238236
extend_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(self.enlarge_width)
239237
points_mask, _ = self._assign_targets_by_points_inside(
@@ -297,9 +295,9 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points,
297295
nms_cfg = self.test_cfg.nms_cfg if not self.training \
298296
else self.train_cfg.nms_cfg
299297
if nms_cfg.use_rotate_nms:
300-
nms_func = nms_gpu
298+
nms_func = nms_bev
301299
else:
302-
nms_func = nms_normal_gpu
300+
nms_func = nms_normal_bev
303301

304302
num_bbox = bbox.shape[0]
305303
bbox = input_meta['box_type_3d'](

mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import torch
44
from mmcv.cnn import ConvModule, normal_init
55
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential
6-
from mmcv.ops import nms_bev as nms_gpu
7-
from mmcv.ops import nms_normal_bev as nms_normal_gpu
86
from mmcv.runner import BaseModule
97
from torch import nn as nn
108

119
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
1210
rotation_3d_in_axis, xywhr2xyxyr)
11+
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
1312
from mmdet3d.models.builder import build_loss
1413
from mmdet3d.ops import make_sparse_convmodule
1514
from mmdet.core import build_bbox_coder, multi_apply
@@ -582,9 +581,9 @@ def multi_class_nms(self,
582581
torch.Tensor: Selected indices.
583582
"""
584583
if use_rotate_nms:
585-
nms_func = nms_gpu
584+
nms_func = nms_bev
586585
else:
587-
nms_func = nms_normal_gpu
586+
nms_func = nms_normal_bev
588587

589588
assert box_probs.shape[
590589
1] == self.num_classes, f'box_probs shape: {str(box_probs.shape)}'

mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import torch
44
from mmcv.cnn import ConvModule, normal_init
55
from mmcv.cnn.bricks import build_conv_layer
6-
from mmcv.ops import nms_bev as nms_gpu
7-
from mmcv.ops import nms_normal_bev as nms_normal_gpu
86
from mmcv.runner import BaseModule
97
from torch import nn as nn
108

119
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
1210
rotation_3d_in_axis, xywhr2xyxyr)
11+
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
1312
from mmdet3d.models.builder import build_loss
1413
from mmdet3d.ops import build_sa_module
1514
from mmdet.core import build_bbox_coder, multi_apply
@@ -239,7 +238,7 @@ def forward(self, feats):
239238
rcnn_reg = self.conv_reg(x_reg)
240239
rcnn_cls = rcnn_cls.transpose(1, 2).contiguous().squeeze(dim=1)
241240
rcnn_reg = rcnn_reg.transpose(1, 2).contiguous().squeeze(dim=1)
242-
return (rcnn_cls, rcnn_reg)
241+
return rcnn_cls, rcnn_reg
243242

244243
def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets,
245244
pos_gt_bboxes, reg_mask, label_weights, bbox_weights):
@@ -483,7 +482,7 @@ def get_bboxes(self,
483482
local_roi_boxes[..., 0:3] = 0
484483
rcnn_boxes3d = self.bbox_coder.decode(local_roi_boxes, bbox_pred)
485484
rcnn_boxes3d[..., 0:3] = rotation_3d_in_axis(
486-
rcnn_boxes3d[..., 0:3].unsqueeze(1), (roi_ry), axis=2).squeeze(1)
485+
rcnn_boxes3d[..., 0:3].unsqueeze(1), roi_ry, axis=2).squeeze(1)
487486
rcnn_boxes3d[:, 0:3] += roi_xyz
488487

489488
# post processing
@@ -492,7 +491,6 @@ def get_bboxes(self,
492491
cur_class_labels = class_labels[batch_id]
493492
cur_cls_score = cls_score[roi_batch_id == batch_id].view(-1)
494493

495-
cur_box_prob = cls_score[batch_id]
496494
cur_box_prob = cur_cls_score.unsqueeze(1)
497495
cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id]
498496
keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d,
@@ -524,7 +522,7 @@ def multi_class_nms(self,
524522
merging these two functions in the future.
525523
526524
Args:
527-
box_probs (torch.Tensor): Predicted boxes probabitilies in
525+
box_probs (torch.Tensor): Predicted boxes probabilities in
528526
shape (N,).
529527
box_preds (torch.Tensor): Predicted boxes in shape (N, 7+C).
530528
score_thr (float): Threshold of scores.
@@ -537,9 +535,9 @@ def multi_class_nms(self,
537535
torch.Tensor: Selected indices.
538536
"""
539537
if use_rotate_nms:
540-
nms_func = nms_gpu
538+
nms_func = nms_bev
541539
else:
542-
nms_func = nms_normal_gpu
540+
nms_func = nms_normal_bev
543541

544542
assert box_probs.shape[
545543
1] == self.num_classes, f'box_probs shape: {str(box_probs.shape)}'

0 commit comments

Comments
 (0)