-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathbox3d_nms.py
288 lines (250 loc) · 10.5 KB
/
box3d_nms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
# Copyright (c) OpenMMLab. All rights reserved.
import numba
import numpy as np
import torch
from mmcv.ops import nms, nms_rotated
def box3d_multiclass_nms(mlvl_bboxes,
                         mlvl_bboxes_for_nms,
                         mlvl_scores,
                         score_thr,
                         max_num,
                         cfg,
                         mlvl_dir_scores=None,
                         mlvl_attr_scores=None,
                         mlvl_bboxes2d=None):
    """Run class-wise NMS on 3D boxes using their BEV counterparts.

    For every foreground class, boxes scoring above ``score_thr`` are passed
    to ``nms_bev`` (when ``cfg.use_rotate_nms`` is truthy) or to
    ``nms_normal_bev`` (otherwise) with ``cfg.nms_thr``. Survivors of all
    classes are concatenated and, if there are more than ``max_num`` of them,
    only the ``max_num`` highest-scoring ones are returned.

    Args:
        mlvl_bboxes (torch.Tensor): Multi-level boxes with shape (N, M).
            M is the dimensions of boxes.
        mlvl_bboxes_for_nms (torch.Tensor): Multi-level BEV boxes with shape
            (N, 5) ([x1, y1, x2, y2, ry]) that are handed to the NMS
            function. N is the number of boxes.
        mlvl_scores (torch.Tensor): Multi-level scores with shape
            (N, C + 1). N is the number of boxes. C is the number of
            classes; the last column is not used by this function.
        score_thr (float): Boxes whose class score is not strictly greater
            than this value are dropped before NMS.
        max_num (int): Maximum number of boxes that will be kept.
        cfg (dict): NMS configuration; ``cfg.use_rotate_nms`` and
            ``cfg.nms_thr`` are read via attribute access.
        mlvl_dir_scores (torch.Tensor, optional): Multi-level scores
            of direction classifier. Defaults to None.
        mlvl_attr_scores (torch.Tensor, optional): Multi-level scores
            of attribute classifier. Defaults to None.
        mlvl_bboxes2d (torch.Tensor, optional): Multi-level 2D bounding
            boxes. Defaults to None.

    Returns:
        tuple[torch.Tensor]: ``(bboxes, scores, labels)`` followed, in this
        order, by the direction scores, attribute scores and 2D boxes, each
        present only when the corresponding input was not None.
    """
    with_dir = mlvl_dir_scores is not None
    with_attr = mlvl_attr_scores is not None
    with_2d = mlvl_bboxes2d is not None

    # Foreground class ids are [0, num_classes - 1]; the final score column
    # is never visited by the loop below.
    num_classes = mlvl_scores.shape[1] - 1

    kept_bboxes, kept_scores, kept_labels = [], [], []
    kept_dir, kept_attr, kept_2d = [], [], []

    for cls_id in range(num_classes):
        # Boolean mask of the boxes that are confident enough for this class.
        cls_mask = mlvl_scores[:, cls_id] > score_thr
        if not cls_mask.any():
            continue

        cls_scores = mlvl_scores[cls_mask, cls_id]
        cls_bev = mlvl_bboxes_for_nms[cls_mask, :]
        nms_func = nms_bev if cfg.use_rotate_nms else nms_normal_bev
        # ``keep`` indexes into the masked (per-class) subset.
        keep = nms_func(cls_bev, cls_scores, cfg.nms_thr)

        kept_bboxes.append(mlvl_bboxes[cls_mask, :][keep])
        kept_scores.append(cls_scores[keep])
        kept_labels.append(
            mlvl_bboxes.new_full((len(keep), ), cls_id, dtype=torch.long))
        if with_dir:
            kept_dir.append(mlvl_dir_scores[cls_mask][keep])
        if with_attr:
            kept_attr.append(mlvl_attr_scores[cls_mask][keep])
        if with_2d:
            kept_2d.append(mlvl_bboxes2d[cls_mask][keep])

    dir_scores = attr_scores = bboxes2d = None
    if not kept_bboxes:
        # No class produced any candidate: return empty tensors that share
        # dtype/device with ``mlvl_scores``.
        empty = mlvl_scores.new_zeros
        bboxes = empty((0, mlvl_bboxes.size(-1)))
        scores = empty((0, ))
        labels = empty((0, ), dtype=torch.long)
        if with_dir:
            dir_scores = empty((0, ))
        if with_attr:
            attr_scores = empty((0, ))
        if with_2d:
            bboxes2d = empty((0, 4))
    else:
        bboxes = torch.cat(kept_bboxes, dim=0)
        scores = torch.cat(kept_scores, dim=0)
        labels = torch.cat(kept_labels, dim=0)
        if with_dir:
            dir_scores = torch.cat(kept_dir, dim=0)
        if with_attr:
            attr_scores = torch.cat(kept_attr, dim=0)
        if with_2d:
            bboxes2d = torch.cat(kept_2d, dim=0)

        if bboxes.shape[0] > max_num:
            # Too many survivors: keep only the top ``max_num`` by score.
            top = scores.sort(descending=True)[1][:max_num]
            bboxes = bboxes[top, :]
            labels = labels[top]
            scores = scores[top]
            if with_dir:
                dir_scores = dir_scores[top]
            if with_attr:
                attr_scores = attr_scores[top]
            if with_2d:
                bboxes2d = bboxes2d[top]

    optional = ((with_dir, dir_scores), (with_attr, attr_scores),
                (with_2d, bboxes2d))
    extras = tuple(value for present, value in optional if present)
    return (bboxes, scores, labels) + extras
def aligned_3d_nms(boxes, scores, classes, thresh):
    """Greedy class-aware NMS for axis-aligned 3D boxes.

    The highest-scoring remaining box is picked repeatedly; every remaining
    box of the same class whose 3D IoU with it exceeds ``thresh`` is dropped.
    Boxes of a different class are never suppressed by the picked box.

    Args:
        boxes (torch.Tensor): Aligned boxes with shape [n, 6], read as
            (x1, y1, z1, x2, y2, z2).
        scores (torch.Tensor): Score of each box.
        classes (torch.Tensor): Class of each box.
        thresh (float): IoU threshold for nms.

    Returns:
        torch.Tensor: Indices of selected boxes (dtype ``torch.long``), in
        the order they were picked, i.e. by descending score.
    """
    lower = boxes[:, 0:3]
    upper = boxes[:, 3:6]
    extent = upper - lower
    volume = extent[:, 0] * extent[:, 1] * extent[:, 2]

    # Ascending argsort: the best remaining candidate is always the last one.
    order = torch.argsort(scores)
    picked = []
    while order.shape[0] > 0:
        best = order[-1]
        rest = order[:-1]
        picked.append(best)

        # Intersection cuboid between ``best`` and every remaining box;
        # negative side lengths mean "no overlap" and are clamped to zero.
        overlap = (torch.min(upper[best], upper[rest]) -
                   torch.max(lower[best], lower[rest])).clamp(min=0)
        inter = overlap[:, 0] * overlap[:, 1] * overlap[:, 2]

        iou = inter / (volume[best] + volume[rest] - inter)
        # Zero out the IoU for boxes belonging to a different class so they
        # survive this round.
        iou = iou * (classes[best] == classes[rest]).float()

        order = rest[iou <= thresh]

    return boxes.new_tensor(picked, dtype=torch.long)
@numba.jit(nopython=True)
def circle_nms(dets, thresh, post_max_size=83):
    """Circular NMS.

    An object is only counted as positive if no other center
    with a higher confidence exists within a radius r using a
    bird-eye view distance metric.

    Detections are visited from the highest to the lowest score. Every
    visited detection that has not been suppressed yet is kept, and each
    lower-scored detection whose squared center distance to it is at most
    ``thresh`` is marked as suppressed.

    Args:
        dets (np.ndarray): Detection results with the shape of [N, 3];
            the columns are read as (x, y, score). The body relies on
            NumPy array methods (``argsort()[::-1].astype(np.int32)``),
            so a NumPy array is expected here rather than a torch.Tensor.
        thresh (float): Value of threshold. It is compared against the
            *squared* center distance (no square root is taken).
        post_max_size (int, optional): Max number of prediction to be kept.
            Defaults to 83.

    Returns:
        list[int]: Indexes of the detections to be kept, ordered from the
        highest to the lowest score and truncated to ``post_max_size``
        entries.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    scores = dets[:, 2]
    order = scores.argsort()[::-1].astype(np.int32)  # highest->lowest
    ndets = dets.shape[0]
    # suppressed[k] == 1 marks detection k as removed by a stronger neighbor.
    suppressed = np.zeros((ndets), dtype=np.int32)
    keep = []
    for _i in range(ndets):
        i = order[_i]  # start with highest score box
        if suppressed[i] == 1:  # already removed by a higher-scored center
            continue
        keep.append(i)
        # Only lower-scored detections (later in ``order``) can be suppressed.
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            # calculate (squared) center distance between i and j box
            dist = (x1[i] - x1[j])**2 + (y1[i] - y1[j])**2
            # the squared distance plays the role an overlap ratio has in
            # box-based NMS
            if dist <= thresh:
                suppressed[j] = 1
    # Truncate only when more than ``post_max_size`` detections survived.
    if post_max_size < len(keep):
        return keep[:post_max_size]
    return keep
# This function duplicates the functionality of mmcv.ops.iou_3d.nms_bev
# from mmcv<=1.5, but uses the CUDA ops from mmcv.ops.nms.nms_rotated.
# The NMS API will be unified in mmdetection3d one day.
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
    """Rotated NMS for BEV boxes, delegated to ``mmcv.ops.nms_rotated``.

    Boxes are sorted by descending score, optionally truncated to the
    ``pre_max_size`` best ones, converted from corner form to center form and
    passed to ``nms_rotated``. The surviving positions are mapped back to
    indices of the original ``boxes`` and optionally truncated to
    ``post_max_size`` entries.

    Args:
        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
            ([x1, y1, x2, y2, ry]).
        scores (torch.Tensor): Scores of boxes with the shape of [N].
        thresh (float): Overlap threshold of NMS.
        pre_max_size (int, optional): Max size of boxes before NMS.
            Default: None.
        post_max_size (int, optional): Max size of boxes after NMS.
            Default: None.

    Returns:
        torch.Tensor: Indexes after NMS, referring to rows of the input
        ``boxes``.
    """
    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'

    _, order = scores.sort(0, descending=True)
    if pre_max_size is not None:
        order = order[:pre_max_size]
    sorted_boxes = boxes[order].contiguous()
    sorted_scores = scores[order]

    # xyxyr -> back to xywhr
    # note: better skip this step before nms_bev call in the future
    x1, y1, x2, y2 = (sorted_boxes[:, col] for col in range(4))
    yaw = sorted_boxes[:, 4]
    center_boxes = torch.stack(
        ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1, yaw), dim=-1)

    # The second element of nms_rotated's result is used as positions into
    # the sorted boxes; ``order`` translates them to original row indices.
    kept_sorted = nms_rotated(center_boxes, sorted_scores, thresh)[1]
    keep = order[kept_sorted]
    return keep if post_max_size is None else keep[:post_max_size]
# This function duplicates the functionality of mmcv.ops.iou_3d.nms_normal_bev
# from mmcv<=1.5, but uses the CUDA ops from mmcv.ops.nms.nms.
# The NMS API will be unified in mmdetection3d one day.
def nms_normal_bev(boxes, scores, thresh):
    """Axis-aligned NMS for BEV boxes, delegated to ``mmcv.ops.nms``.

    The last column of ``boxes`` (the yaw angle) is dropped, so the overlap
    is computed as if every box had its yaw set to 0.

    Args:
        boxes (torch.Tensor): Input boxes with shape (N, 5).
        scores (torch.Tensor): Scores of predicted boxes with shape (N).
        thresh (float): Overlap threshold of NMS.

    Returns:
        torch.Tensor: Remaining indices, i.e. the second element of the
        result returned by ``mmcv.ops.nms``.
    """
    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
    # Keep only the four corner coordinates; the yaw column is ignored.
    corner_boxes = boxes[:, :-1]
    nms_result = nms(corner_boxes, scores, thresh)
    return nms_result[1]