Skip to content

Commit e9f48a4

Browse files
authored
[Enhance] Replace BEV IoU with 3D IoU (#1902)
* add iou3d * revert deprecated python function * fix lint * replace 3d iou/nms calls for bev iou/nms
1 parent 7e6f462 commit e9f48a4

File tree

10 files changed

+372
-312
lines changed

10 files changed

+372
-312
lines changed

docs/en/understand_mmcv/ops.md

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
2424
- MaskedConv
2525
- MinAreaPolygon
2626
- NMS
27+
- NMS3D
2728
- PointsInPolygons
2829
- PSAMask
2930
- RiRoIAlignRotated

docs/zh_cn/understand_mmcv/ops.md

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
2323
- MaskedConv
2424
- MinAreaPolygon
2525
- NMS
26+
- NMS3D
2627
- PointsInPolygons
2728
- PSAMask
2829
- RotatedFeatureAlign

mmcv/ops/__init__.py

100644100755
+12-10
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from .group_points import GroupAll, QueryAndGroup, grouping_operation
2929
from .info import (get_compiler_version, get_compiling_cuda_version,
3030
get_onnxruntime_op_path)
31-
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
31+
from .iou3d import (boxes_iou3d, boxes_iou_bev, nms3d, nms3d_normal, nms_bev,
32+
nms_normal_bev)
3233
from .knn import knn
3334
from .masked_conv import MaskedConv2d, masked_conv2d
3435
from .min_area_polygons import min_area_polygons
@@ -89,13 +90,14 @@
8990
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
9091
'border_align', 'gather_points', 'furthest_point_sample',
9192
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
92-
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
93-
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
94-
'SparseConv2d', 'SparseConv3d', 'SparseConvTranspose2d',
95-
'SparseConvTranspose3d', 'SparseInverseConv2d', 'SparseInverseConv3d',
96-
'SubMConv2d', 'SubMConv3d', 'SparseModule', 'SparseSequential',
97-
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
98-
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
99-
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
100-
'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
93+
'boxes_iou3d', 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'nms3d',
94+
'nms3d_normal', 'Voxelization', 'voxelization', 'dynamic_scatter',
95+
'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d', 'SparseConv3d',
96+
'SparseConvTranspose2d', 'SparseConvTranspose3d', 'SparseInverseConv2d',
97+
'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d', 'SparseModule',
98+
'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d',
99+
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
100+
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
101+
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
102+
'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
101103
]

mmcv/ops/csrc/common/cuda/iou3d_cuda_kernel.cuh

+85-88
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,17 @@ __device__ int check_rect_cross(const Point &p1, const Point &p2,
5050
}
5151

5252
__device__ inline int check_in_box2d(const float *box, const Point &p) {
53-
// params: box (5) [x1, y1, x2, y2, angle]
54-
const float MARGIN = 1e-5;
55-
56-
float center_x = (box[0] + box[2]) / 2;
57-
float center_y = (box[1] + box[3]) / 2;
58-
float angle_cos = cos(-box[4]),
59-
angle_sin =
60-
sin(-box[4]); // rotate the point in the opposite direction of box
61-
float rot_x =
62-
(p.x - center_x) * angle_cos - (p.y - center_y) * angle_sin + center_x;
63-
float rot_y =
64-
(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y;
65-
66-
return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN &&
67-
rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN);
53+
// params: box (7) [x, y, z, dx, dy, dz, heading]
54+
const float MARGIN = 1e-2;
55+
56+
float center_x = box[0], center_y = box[1];
57+
// rotate the point in the opposite direction of box
58+
float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]);
59+
float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
60+
float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
61+
62+
return (fabs(rot_x) < box[3] / 2 + MARGIN &&
63+
fabs(rot_y) < box[4] / 2 + MARGIN);
6864
}
6965

7066
__device__ inline int intersection(const Point &p1, const Point &p0,
@@ -116,16 +112,19 @@ __device__ inline int point_cmp(const Point &a, const Point &b,
116112
}
117113

118114
__device__ inline float box_overlap(const float *box_a, const float *box_b) {
119-
// params: box_a (5) [x1, y1, x2, y2, angle]
120-
// params: box_b (5) [x1, y1, x2, y2, angle]
115+
// params box_a: [x, y, z, dx, dy, dz, heading]
116+
// params box_b: [x, y, z, dx, dy, dz, heading]
121117

122-
float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3],
123-
a_angle = box_a[4];
124-
float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3],
125-
b_angle = box_b[4];
118+
float a_angle = box_a[6], b_angle = box_b[6];
119+
float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
120+
a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
121+
float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
122+
float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
123+
float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
124+
float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
126125

127-
Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2);
128-
Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2);
126+
Point center_a(box_a[0], box_a[1]);
127+
Point center_b(box_b[0], box_b[1]);
129128

130129
Point box_a_corners[5];
131130
box_a_corners[0].set(a_x1, a_y1);
@@ -209,50 +208,36 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
209208
}
210209

211210
__device__ inline float iou_bev(const float *box_a, const float *box_b) {
212-
// params: box_a (5) [x1, y1, x2, y2, angle]
213-
// params: box_b (5) [x1, y1, x2, y2, angle]
214-
float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
215-
float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
211+
// params box_a: [x, y, z, dx, dy, dz, heading]
212+
// params box_b: [x, y, z, dx, dy, dz, heading]
213+
float sa = box_a[3] * box_a[4];
214+
float sb = box_b[3] * box_b[4];
216215
float s_overlap = box_overlap(box_a, box_b);
217216
return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
218217
}
219218

220-
__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
221-
const int num_a, const float *boxes_a, const int num_b,
222-
const float *boxes_b, float *ans_overlap) {
223-
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
224-
if (a_idx >= num_a || b_idx >= num_b) {
225-
return;
226-
}
227-
const float *cur_box_a = boxes_a + a_idx * 5;
228-
const float *cur_box_b = boxes_b + b_idx * 5;
229-
float s_overlap = box_overlap(cur_box_a, cur_box_b);
230-
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
231-
}
232-
}
233-
234-
__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
235-
const float *boxes_a,
236-
const int num_b,
237-
const float *boxes_b,
238-
float *ans_iou) {
219+
__global__ void iou3d_boxes_iou3d_forward_cuda_kernel(const int num_a,
220+
const float *boxes_a,
221+
const int num_b,
222+
const float *boxes_b,
223+
float *ans_iou) {
239224
CUDA_2D_KERNEL_LOOP(b_idx, num_b, a_idx, num_a) {
240225
if (a_idx >= num_a || b_idx >= num_b) {
241226
return;
242227
}
243228

244-
const float *cur_box_a = boxes_a + a_idx * 5;
245-
const float *cur_box_b = boxes_b + b_idx * 5;
229+
const float *cur_box_a = boxes_a + a_idx * 7;
230+
const float *cur_box_b = boxes_b + b_idx * 7;
246231
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
247232
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
248233
}
249234
}
250235

251-
__global__ void nms_forward_cuda_kernel(const int boxes_num,
252-
const float nms_overlap_thresh,
253-
const float *boxes,
254-
unsigned long long *mask) {
255-
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
236+
__global__ void iou3d_nms3d_forward_cuda_kernel(const int boxes_num,
237+
const float nms_overlap_thresh,
238+
const float *boxes,
239+
unsigned long long *mask) {
240+
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
256241
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
257242
const int blocks =
258243
(boxes_num + THREADS_PER_BLOCK_NMS - 1) / THREADS_PER_BLOCK_NMS;
@@ -264,25 +249,29 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
264249
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
265250
THREADS_PER_BLOCK_NMS);
266251

267-
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
252+
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
268253

269254
if (threadIdx.x < col_size) {
270-
block_boxes[threadIdx.x * 5 + 0] =
271-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
272-
block_boxes[threadIdx.x * 5 + 1] =
273-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
274-
block_boxes[threadIdx.x * 5 + 2] =
275-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
276-
block_boxes[threadIdx.x * 5 + 3] =
277-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
278-
block_boxes[threadIdx.x * 5 + 4] =
279-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
255+
block_boxes[threadIdx.x * 7 + 0] =
256+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
257+
block_boxes[threadIdx.x * 7 + 1] =
258+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
259+
block_boxes[threadIdx.x * 7 + 2] =
260+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
261+
block_boxes[threadIdx.x * 7 + 3] =
262+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
263+
block_boxes[threadIdx.x * 7 + 4] =
264+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
265+
block_boxes[threadIdx.x * 7 + 5] =
266+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
267+
block_boxes[threadIdx.x * 7 + 6] =
268+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
280269
}
281270
__syncthreads();
282271

283272
if (threadIdx.x < row_size) {
284273
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
285-
const float *cur_box = boxes + cur_box_idx * 5;
274+
const float *cur_box = boxes + cur_box_idx * 7;
286275

287276
int i = 0;
288277
unsigned long long t = 0;
@@ -291,7 +280,7 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
291280
start = threadIdx.x + 1;
292281
}
293282
for (i = start; i < col_size; i++) {
294-
if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
283+
if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
295284
t |= 1ULL << i;
296285
}
297286
}
@@ -303,20 +292,24 @@ __global__ void nms_forward_cuda_kernel(const int boxes_num,
303292
}
304293

305294
__device__ inline float iou_normal(float const *const a, float const *const b) {
306-
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
307-
float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
295+
// params: a: [x, y, z, dx, dy, dz, heading]
296+
// params: b: [x, y, z, dx, dy, dz, heading]
297+
298+
float left = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2),
299+
right = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2);
300+
float top = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2),
301+
bottom = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2);
308302
float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
309303
float interS = width * height;
310-
float Sa = (a[2] - a[0]) * (a[3] - a[1]);
311-
float Sb = (b[2] - b[0]) * (b[3] - b[1]);
304+
float Sa = a[3] * a[4];
305+
float Sb = b[3] * b[4];
312306
return interS / fmaxf(Sa + Sb - interS, EPS);
313307
}
314308

315-
__global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
316-
const float nms_overlap_thresh,
317-
const float *boxes,
318-
unsigned long long *mask) {
319-
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
309+
__global__ void iou3d_nms3d_normal_forward_cuda_kernel(
310+
const int boxes_num, const float nms_overlap_thresh, const float *boxes,
311+
unsigned long long *mask) {
312+
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
320313
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
321314

322315
const int blocks =
@@ -329,25 +322,29 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
329322
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
330323
THREADS_PER_BLOCK_NMS);
331324

332-
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
325+
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
333326

334327
if (threadIdx.x < col_size) {
335-
block_boxes[threadIdx.x * 5 + 0] =
336-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
337-
block_boxes[threadIdx.x * 5 + 1] =
338-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
339-
block_boxes[threadIdx.x * 5 + 2] =
340-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
341-
block_boxes[threadIdx.x * 5 + 3] =
342-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
343-
block_boxes[threadIdx.x * 5 + 4] =
344-
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
328+
block_boxes[threadIdx.x * 7 + 0] =
329+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
330+
block_boxes[threadIdx.x * 7 + 1] =
331+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
332+
block_boxes[threadIdx.x * 7 + 2] =
333+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
334+
block_boxes[threadIdx.x * 7 + 3] =
335+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
336+
block_boxes[threadIdx.x * 7 + 4] =
337+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
338+
block_boxes[threadIdx.x * 7 + 5] =
339+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
340+
block_boxes[threadIdx.x * 7 + 6] =
341+
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
345342
}
346343
__syncthreads();
347344

348345
if (threadIdx.x < row_size) {
349346
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
350-
const float *cur_box = boxes + cur_box_idx * 5;
347+
const float *cur_box = boxes + cur_box_idx * 7;
351348

352349
int i = 0;
353350
unsigned long long t = 0;
@@ -356,7 +353,7 @@ __global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
356353
start = threadIdx.x + 1;
357354
}
358355
for (i = start; i < col_size; i++) {
359-
if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
356+
if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
360357
t |= 1ULL << i;
361358
}
362359
}

0 commit comments

Comments
 (0)