Skip to content

Commit

Permalink
[Enhance] Remove mmcv.iou3d from mmdetection3d (#1403)
Browse files Browse the repository at this point in the history
* remove iou3d_boxes_overlap_bev_forward

* remove nms ops from mmcv.ops.iou3d from mmdetection3d
  • Loading branch information
filaPro authored Apr 20, 2022
1 parent 4b73f74 commit 6dd5d32
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 56 deletions.
21 changes: 9 additions & 12 deletions mmdet3d/core/bbox/structures/base_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

import numpy as np
import torch
from mmcv._ext import iou3d_boxes_overlap_bev_forward as boxes_overlap_bev_gpu
from mmcv.ops import points_in_boxes_all, points_in_boxes_part
from mmcv.ops import box_iou_rotated, points_in_boxes_all, points_in_boxes_part

from .utils import limit_period, xywhr2xyxyr
from .utils import limit_period


class BaseInstance3DBoxes(object):
Expand Down Expand Up @@ -447,7 +446,7 @@ def overlaps(cls, boxes1, boxes2, mode='iou'):
mode (str, optional): Mode of iou calculation. Defaults to 'iou'.
Returns:
torch.Tensor: Calculated iou of boxes' heights.
torch.Tensor: Calculated 3D overlaps of the boxes.
"""
assert isinstance(boxes1, BaseInstance3DBoxes)
assert isinstance(boxes2, BaseInstance3DBoxes)
Expand All @@ -464,15 +463,13 @@ def overlaps(cls, boxes1, boxes2, mode='iou'):
# height overlap
overlaps_h = cls.height_overlaps(boxes1, boxes2)

# obtain BEV boxes in XYXYR format
boxes1_bev = xywhr2xyxyr(boxes1.bev)
boxes2_bev = xywhr2xyxyr(boxes2.bev)

# bev overlap
overlaps_bev = boxes1_bev.new_zeros(
(boxes1_bev.shape[0], boxes2_bev.shape[0])).cuda() # (N, M)
boxes_overlap_bev_gpu(boxes1_bev.contiguous().cuda(),
boxes2_bev.contiguous().cuda(), overlaps_bev)
iou2d = box_iou_rotated(boxes1.bev, boxes2.bev)
areas1 = (boxes1.bev[:, 2] * boxes1.bev[:, 3]).unsqueeze(1).expand(
rows, cols)
areas2 = (boxes2.bev[:, 2] * boxes2.bev[:, 3]).unsqueeze(0).expand(
rows, cols)
overlaps_bev = iou2d * (areas1 + areas2) / (1 + iou2d)

# 3d overlaps
overlaps_3d = overlaps_bev.to(boxes1.device) * overlaps_h
Expand Down
6 changes: 4 additions & 2 deletions mmdet3d/core/post_processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from mmdet.core.post_processing import (merge_aug_bboxes, merge_aug_masks,
merge_aug_proposals, merge_aug_scores,
multiclass_nms)
from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms
from .box3d_nms import (aligned_3d_nms, box3d_multiclass_nms, circle_nms,
nms_bev, nms_normal_bev)
from .merge_augs import merge_aug_bboxes_3d

__all__ = [
'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
'merge_aug_scores', 'merge_aug_masks', 'box3d_multiclass_nms',
'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms'
'aligned_3d_nms', 'merge_aug_bboxes_3d', 'circle_nms', 'nms_bev',
'nms_normal_bev'
]
69 changes: 65 additions & 4 deletions mmdet3d/core/post_processing/box3d_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import numba
import numpy as np
import torch
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.ops import nms, nms_rotated

from ..bbox import xywhr2xyxyr


def box3d_multiclass_nms(mlvl_bboxes,
Expand Down Expand Up @@ -61,9 +62,9 @@ def box3d_multiclass_nms(mlvl_bboxes,
_bboxes_for_nms = mlvl_bboxes_for_nms[cls_inds, :]

if cfg.use_rotate_nms:
nms_func = nms_gpu
nms_func = nms_bev
else:
nms_func = nms_normal_gpu
nms_func = nms_normal_bev

selected = nms_func(_bboxes_for_nms, _scores, cfg.nms_thr)
_mlvl_bboxes = mlvl_bboxes[cls_inds, :]
Expand Down Expand Up @@ -224,3 +225,63 @@ def circle_nms(dets, thresh, post_max_size=83):
return keep[:post_max_size]

return keep


# This function duplicates functionality of mmcv.ops.iou_3d.nms_bev
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms_rotated.
# Nms api will be unified in mmdetection3d one day.
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
    """Rotated NMS for BEV boxes, built on ``mmcv.ops.nms_rotated``.

    The overlap of two boxes for IoU calculation is defined as the exact
    overlapping area of the two (rotated) boxes. One can also limit the
    number of candidates with ``pre_max_size`` and the number of results
    with ``post_max_size``.

    Args:
        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
            ([x1, y1, x2, y2, ry]).
        scores (torch.Tensor): Scores of boxes with the shape of [N].
        thresh (float): Overlap threshold of NMS.
        pre_max_size (int, optional): Max number of top-scoring boxes
            kept before NMS. Default: None.
        post_max_size (int, optional): Max number of boxes kept after
            NMS. Default: None.

    Returns:
        torch.Tensor: Indexes of the kept boxes, referring to rows of the
            ``boxes`` / ``scores`` tensors passed in by the caller.
    """
    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
    order = scores.sort(0, descending=True)[1]
    if pre_max_size is not None:
        order = order[:pre_max_size]
    boxes = boxes[order].contiguous()
    # Scores must be permuted / truncated exactly like the boxes, otherwise
    # each box would be paired with the score of a different box (and the
    # two tensors would differ in length when ``pre_max_size`` is set).
    scores = scores[order]

    # xyxyr -> back to xywhr
    # note: better skip this step before nms_bev call in the future
    boxes = torch.stack(
        ((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2,
         boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1], boxes[:, 4]),
        dim=-1)

    keep = nms_rotated(boxes, scores, thresh)[1]
    # ``keep`` indexes the sorted/truncated tensors; map it back so that
    # callers can index their original tensors with the result.
    keep = order[keep]
    if post_max_size is not None:
        keep = keep[:post_max_size]
    return keep


# This function duplicates functionality of mmcv.ops.iou_3d.nms_normal_bev
# from mmcv<=1.5, but using cuda ops from mmcv.ops.nms.nms.
# Nms api will be unified in mmdetection3d one day.
def nms_normal_bev(boxes, scores, thresh):
    """Axis-aligned NMS for BEV boxes, built on ``mmcv.ops.nms``.

    The overlap of two boxes for IoU calculation is defined as the exact
    overlapping area of the two boxes WITH their yaw angle set to 0.

    Args:
        boxes (torch.Tensor): Input boxes with shape (N, 5), in the same
            [x1, y1, x2, y2, ry] layout that ``nms_bev`` accepts (callers
            pass the same tensor to either function).
        scores (torch.Tensor): Scores of predicted boxes with shape (N).
        thresh (float): Overlap threshold of NMS.

    Returns:
        torch.Tensor: Remaining indices with scores in descending order.
    """
    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
    # The input is already in corner format, so only the yaw column is
    # dropped. Converting with xywhr2xyxyr here would reinterpret corners
    # as centre/size and distort every box.
    return nms(boxes[:, :-1], scores, thresh)[1]
7 changes: 3 additions & 4 deletions mmdet3d/core/post_processing/merge_augs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu

from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from ..bbox import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr


Expand Down Expand Up @@ -52,9 +51,9 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):

# TODO: use a more elegant way to deal with nms
if test_cfg.use_rotate_nms:
nms_func = nms_gpu
nms_func = nms_bev
else:
nms_func = nms_normal_gpu
nms_func = nms_normal_bev

merged_bboxes = []
merged_scores = []
Expand Down
8 changes: 4 additions & 4 deletions mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.ops import nms_bev as nms_gpu
from mmcv.runner import BaseModule, force_fp32
from torch import nn

from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev
from mmdet3d.models import builder
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.utils import clip_sigmoid
Expand Down Expand Up @@ -747,9 +747,9 @@ def get_task_detections(self, num_class_with_bg, batch_cls_preds,
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):

# Apply NMS in birdeye view
# Apply NMS in bird eye view

# get highest score per prediction, than apply nms
# get the highest score per prediction, then apply nms
# to remove overlapped box.
if num_class_with_bg == 1:
top_scores = cls_preds.squeeze(-1)
Expand Down Expand Up @@ -778,7 +778,7 @@ def get_task_detections(self, num_class_with_bg, batch_cls_preds,
box_preds[:, :], self.bbox_coder.code_size).bev)
# the nms in 3d detection just remove overlap boxes.

selected = nms_gpu(
selected = nms_bev(
boxes_for_nms,
top_scores,
thresh=self.test_cfg['nms_thr'],
Expand Down
10 changes: 3 additions & 7 deletions mmdet3d/models/dense_heads/parta2_rpn_head.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division

import numpy as np
import torch
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.runner import force_fp32

from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet.models import HEADS
from .anchor3d_head import Anchor3DHead

Expand Down Expand Up @@ -261,9 +258,9 @@ def class_agnostic_nms(self, mlvl_bboxes, mlvl_bboxes_for_nms,
_scores = mlvl_max_scores[score_thr_inds]
_bboxes_for_nms = mlvl_bboxes_for_nms[score_thr_inds, :]
if cfg.use_rotate_nms:
nms_func = nms_gpu
nms_func = nms_bev
else:
nms_func = nms_normal_gpu
nms_func = nms_normal_bev
selected = nms_func(_bboxes_for_nms, _scores, cfg.nms_thr)

_mlvl_bboxes = mlvl_bboxes[score_thr_inds, :]
Expand All @@ -288,7 +285,6 @@ def class_agnostic_nms(self, mlvl_bboxes, mlvl_bboxes_for_nms,
scores = torch.cat(scores, dim=0)
cls_scores = torch.cat(cls_scores, dim=0)
labels = torch.cat(labels, dim=0)
dir_scores = torch.cat(dir_scores, dim=0)
if bboxes.shape[0] > max_num:
_, inds = scores.sort(descending=True)
inds = inds[:max_num]
Expand Down
16 changes: 7 additions & 9 deletions mmdet3d/models/dense_heads/point_rpn_head.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn

from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS, build_loss

Expand All @@ -19,7 +18,7 @@ class PointRPNHead(BaseModule):
num_classes (int): Number of classes.
train_cfg (dict): Train configs.
test_cfg (dict): Test configs.
pred_layer_cfg (dict, optional): Config of classfication and
pred_layer_cfg (dict, optional): Config of classification and
regression prediction layers. Defaults to None.
enlarge_width (float, optional): Enlarge bbox for each side to ignore
close points. Defaults to 0.1.
Expand Down Expand Up @@ -121,7 +120,7 @@ def forward(self, feat_dict):
batch_size, -1, self._get_cls_out_channels())
point_box_preds = self.reg_layers(feat_reg).reshape(
batch_size, -1, self._get_reg_out_channels())
return (point_box_preds, point_cls_preds)
return point_box_preds, point_cls_preds

@force_fp32(apply_to=('bbox_preds'))
def loss(self,
Expand Down Expand Up @@ -159,7 +158,7 @@ def loss(self,
semantic_targets = mask_targets
semantic_targets[negative_mask] = self.num_classes
semantic_points_label = semantic_targets
# for ignore, but now we do not have ignore label
# for ignore, but now we do not have ignored label
semantic_loss_weight = negative_mask.float() + positive_mask.float()
semantic_loss = self.cls_loss(semantic_points,
semantic_points_label.reshape(-1),
Expand Down Expand Up @@ -220,7 +219,7 @@ def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d):
gt_bboxes_3d = gt_bboxes_3d[valid_gt]
gt_labels_3d = gt_labels_3d[valid_gt]

# transform the bbox coordinate to the pointcloud coordinate
# transform the bbox coordinate to the point cloud coordinate
gt_bboxes_3d_tensor = gt_bboxes_3d.tensor.clone()
gt_bboxes_3d_tensor[..., 2] += gt_bboxes_3d_tensor[..., 5] / 2

Expand All @@ -233,7 +232,6 @@ def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d):
points[..., 0:3], mask_targets)

positive_mask = (points_mask.max(1)[0] > 0)
negative_mask = (points_mask.max(1)[0] == 0)
# add ignore_mask
extend_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(self.enlarge_width)
points_mask, _ = self._assign_targets_by_points_inside(
Expand Down Expand Up @@ -297,9 +295,9 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points,
nms_cfg = self.test_cfg.nms_cfg if not self.training \
else self.train_cfg.nms_cfg
if nms_cfg.use_rotate_nms:
nms_func = nms_gpu
nms_func = nms_bev
else:
nms_func = nms_normal_gpu
nms_func = nms_normal_bev

num_bbox = bbox.shape[0]
bbox = input_meta['box_type_3d'](
Expand Down
7 changes: 3 additions & 4 deletions mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import torch
from mmcv.cnn import ConvModule, normal_init
from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.runner import BaseModule
from torch import nn as nn

from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss
from mmdet3d.ops import make_sparse_convmodule
from mmdet.core import build_bbox_coder, multi_apply
Expand Down Expand Up @@ -582,9 +581,9 @@ def multi_class_nms(self,
torch.Tensor: Selected indices.
"""
if use_rotate_nms:
nms_func = nms_gpu
nms_func = nms_bev
else:
nms_func = nms_normal_gpu
nms_func = nms_normal_bev

assert box_probs.shape[
1] == self.num_classes, f'box_probs shape: {str(box_probs.shape)}'
Expand Down
14 changes: 6 additions & 8 deletions mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import torch
from mmcv.cnn import ConvModule, normal_init
from mmcv.cnn.bricks import build_conv_layer
from mmcv.ops import nms_bev as nms_gpu
from mmcv.ops import nms_normal_bev as nms_normal_gpu
from mmcv.runner import BaseModule
from torch import nn as nn

from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
Expand Down Expand Up @@ -239,7 +238,7 @@ def forward(self, feats):
rcnn_reg = self.conv_reg(x_reg)
rcnn_cls = rcnn_cls.transpose(1, 2).contiguous().squeeze(dim=1)
rcnn_reg = rcnn_reg.transpose(1, 2).contiguous().squeeze(dim=1)
return (rcnn_cls, rcnn_reg)
return rcnn_cls, rcnn_reg

def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets,
pos_gt_bboxes, reg_mask, label_weights, bbox_weights):
Expand Down Expand Up @@ -483,7 +482,7 @@ def get_bboxes(self,
local_roi_boxes[..., 0:3] = 0
rcnn_boxes3d = self.bbox_coder.decode(local_roi_boxes, bbox_pred)
rcnn_boxes3d[..., 0:3] = rotation_3d_in_axis(
rcnn_boxes3d[..., 0:3].unsqueeze(1), (roi_ry), axis=2).squeeze(1)
rcnn_boxes3d[..., 0:3].unsqueeze(1), roi_ry, axis=2).squeeze(1)
rcnn_boxes3d[:, 0:3] += roi_xyz

# post processing
Expand All @@ -492,7 +491,6 @@ def get_bboxes(self,
cur_class_labels = class_labels[batch_id]
cur_cls_score = cls_score[roi_batch_id == batch_id].view(-1)

cur_box_prob = cls_score[batch_id]
cur_box_prob = cur_cls_score.unsqueeze(1)
cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id]
keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d,
Expand Down Expand Up @@ -524,7 +522,7 @@ def multi_class_nms(self,
merging these two functions in the future.
Args:
box_probs (torch.Tensor): Predicted boxes probabitilies in
box_probs (torch.Tensor): Predicted boxes probabilities in
shape (N,).
box_preds (torch.Tensor): Predicted boxes in shape (N, 7+C).
score_thr (float): Threshold of scores.
Expand All @@ -537,9 +535,9 @@ def multi_class_nms(self,
torch.Tensor: Selected indices.
"""
if use_rotate_nms:
nms_func = nms_gpu
nms_func = nms_bev
else:
nms_func = nms_normal_gpu
nms_func = nms_normal_bev

assert box_probs.shape[
1] == self.num_classes, f'box_probs shape: {str(box_probs.shape)}'
Expand Down
Loading

0 comments on commit 6dd5d32

Please sign in to comment.