From 24a76e59e5fd51ae8368f2614fba7e468572c512 Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Fri, 9 Jul 2021 16:36:58 +0800 Subject: [PATCH 1/9] Fix rotation and dim hacks between nuscbox and our cambox --- ...d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py | 2 +- mmdet3d/datasets/nuscenes_mono_dataset.py | 12 +++++++++--- mmdet3d/datasets/pipelines/transforms_3d.py | 2 ++ mmdet3d/models/dense_heads/fcos_mono3d_head.py | 4 ++++ tools/data_converter/nuscenes_converter.py | 9 +++++++-- 5 files changed, 23 insertions(+), 6 deletions(-) diff --git a/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py b/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py index 3b7eb99fce..167c139c6a 100644 --- a/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py +++ b/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py @@ -25,7 +25,7 @@ with_label_3d=True, with_bbox_depth=True), dict(type='Resize', img_scale=(1600, 900), keep_ratio=True), - dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.0), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='DefaultFormatBundle3D', class_names=class_names), diff --git a/mmdet3d/datasets/nuscenes_mono_dataset.py b/mmdet3d/datasets/nuscenes_mono_dataset.py index 8096c288a0..5a34aa7051 100644 --- a/mmdet3d/datasets/nuscenes_mono_dataset.py +++ b/mmdet3d/datasets/nuscenes_mono_dataset.py @@ -173,9 +173,6 @@ def _parse_ann_info(self, img_info, ann_info): gt_masks_ann.append(ann.get('segmentation', None)) # 3D annotations in camera coordinates bbox_cam3d = np.array(ann['bbox_cam3d']).reshape(1, -1) - # change orientation to local yaw - bbox_cam3d[0, 6] = -np.arctan2( - bbox_cam3d[0, 0], bbox_cam3d[0, 2]) + bbox_cam3d[0, 6] velo_cam3d = np.array(ann['velo_cam3d']).reshape(1, 2) nan_mask = np.isnan(velo_cam3d[:, 0]) velo_cam3d[nan_mask] = [0.0, 0.0] @@ -666,6 +663,10 @@ def output_to_nusc_box(detection): box_dims = box3d.dims.numpy() box_yaw = box3d.yaw.numpy() + # convert the dim/rot to nuscbox convention + box_dims[:, [1, 2]] = box_dims[:, [2, 1]] + box_yaw = -box_yaw - np.pi / 2 + box_list = [] for i in range(len(box3d)): q1 = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw[i]) @@ -778,6 +779,11 @@ def nusc_box_to_cam_box3d(boxes): rots = torch.Tensor([b.orientation.yaw_pitch_roll[0] for b in boxes]).view(-1, 1) velocity = torch.Tensor([b.velocity[:2] for b in boxes]).view(-1, 2) + + # convert nusbox to cambox convention + dims[:, [1, 2]] = dims[:, [2, 1]] + rots = -np.pi / 2.0 - rots + boxes_3d = torch.cat([locs, dims, rots, velocity], dim=1).cuda() cam_boxes3d = CameraInstance3DBoxes( boxes_3d, box_dim=9, origin=(0.5, 0.5, 0.5)) diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index e75e63fb6e..45a0a82fa3 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -121,6 +121,8 @@ def random_flip_data_3d(self, input_dict, direction='horizontal'): w = input_dict['img_shape'][1] input_dict['centers2d'][..., 0] = \ w - input_dict['centers2d'][..., 0] + # input_dict['cam_intrinsic'][0][2] = \ + # w - input_dict['cam_intrinsic'][0][2] def __call__(self, input_dict): """Call function to flip points, values in the ``bbox3d_fields`` and \ diff --git a/mmdet3d/models/dense_heads/fcos_mono3d_head.py b/mmdet3d/models/dense_heads/fcos_mono3d_head.py index 6d99e3b6f9..a1f0c4f8ee 100644 --- a/mmdet3d/models/dense_heads/fcos_mono3d_head.py +++ b/mmdet3d/models/dense_heads/fcos_mono3d_head.py @@ -862,6 +862,10 @@ def _get_target_single(self, gt_bboxes, gt_labels, gt_bboxes_3d, attr_labels.new_full( (num_points,), self.attr_background_label) + # change orientation to local yaw + gt_bboxes_3d[..., 6] = -torch.atan2( + gt_bboxes_3d[..., 0], gt_bboxes_3d[..., 2]) + gt_bboxes_3d[..., 6] + areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * ( gt_bboxes[:, 3] - gt_bboxes[:, 1]) areas = areas[None].repeat(num_points, 1) diff --git a/tools/data_converter/nuscenes_converter.py b/tools/data_converter/nuscenes_converter.py index d7da7306c2..bd02a02679 100644 --- a/tools/data_converter/nuscenes_converter.py +++ b/tools/data_converter/nuscenes_converter.py @@ -483,8 +483,13 @@ def get_2d_boxes(nusc, # If mono3d=True, add 3D annotations in camera coordinates if mono3d and (repro_rec is not None): loc = box.center.tolist() - dim = box.wlh.tolist() - rot = [box.orientation.yaw_pitch_roll[0]] + + dim = box.wlh + dim[[1, 2]] = dim[[2, 1]] # convert wlh to our whl + dim = dim.tolist() + + rot = box.orientation.yaw_pitch_roll[0] + rot = [-rot - np.pi / 2] # convert the rot to our cam coordinate global_velo2d = nusc.box_velocity(box.token)[:2] global_velo3d = np.array([*global_velo2d, 0.0]) From f5e66e6416838853edd38978bdbba5eb31f833ef Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Sun, 11 Jul 2021 18:13:54 +0800 Subject: [PATCH 2/9] Remove hack in the mono browse, fix cam_intrinsic in the RandomFlip3D --- mmdet3d/datasets/pipelines/transforms_3d.py | 6 +++--- tools/misc/browse_dataset.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index 45a0a82fa3..477208bf9b 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -118,11 +118,11 @@ def random_flip_data_3d(self, input_dict, direction='horizontal'): if 'centers2d' in input_dict: assert self.sync_2d is True and direction == 'horizontal', \ 'Only support sync_2d=True and horizontal flip with images' - w = input_dict['img_shape'][1] + w = input_dict['ori_shape'][1] input_dict['centers2d'][..., 0] = \ w - input_dict['centers2d'][..., 0] - # input_dict['cam_intrinsic'][0][2] = \ - # w - input_dict['cam_intrinsic'][0][2] + input_dict['cam_intrinsic'][0][2] = \ + w - input_dict['cam_intrinsic'][0][2] def __call__(self, input_dict): """Call function to flip points, values in the ``bbox3d_fields`` and \ diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py index 6bc8e994ba..007883d837 100644 --- a/tools/misc/browse_dataset.py +++ b/tools/misc/browse_dataset.py @@ -159,10 +159,6 @@ def show_proj_bbox_img(idx, img_metas=img_metas, show=show) elif isinstance(gt_bboxes, CameraInstance3DBoxes): - # TODO: remove the hack of box from NuScenesMonoDataset - if is_nus_mono: - from mmdet3d.core.bbox import mono_cam_box2vis - gt_bboxes = mono_cam_box2vis(gt_bboxes) show_multi_modality_result( img, gt_bboxes, From b0f84a3dfb09dd3e0c73d7d41ab4c5a8fcdd465a Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Mon, 12 Jul 2021 08:58:04 +0800 Subject: [PATCH 3/9] Apply a more suitable hack transformation (dim & yaw hack) --- ...fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py | 2 +- mmdet3d/datasets/nuscenes_mono_dataset.py | 8 ++++---- tools/data_converter/nuscenes_converter.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py b/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py index 167c139c6a..3b7eb99fce 100644 --- a/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py +++ b/configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py @@ -25,7 +25,7 @@ with_label_3d=True, with_bbox_depth=True), dict(type='Resize', img_scale=(1600, 900), keep_ratio=True), - dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.0), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='DefaultFormatBundle3D', class_names=class_names), diff --git a/mmdet3d/datasets/nuscenes_mono_dataset.py b/mmdet3d/datasets/nuscenes_mono_dataset.py index 5a34aa7051..99644375ef 100644 --- a/mmdet3d/datasets/nuscenes_mono_dataset.py +++ b/mmdet3d/datasets/nuscenes_mono_dataset.py @@ -664,8 +664,8 @@ def output_to_nusc_box(detection): box_yaw = box3d.yaw.numpy() # convert the dim/rot to nuscbox convention - box_dims[:, [1, 2]] = box_dims[:, [2, 1]] - box_yaw = -box_yaw - np.pi / 2 + box_dims[:, [0, 1, 2]] = box_dims[:, [2, 0, 1]] + box_yaw = -box_yaw box_list = [] for i in range(len(box3d)): @@ -781,8 +781,8 @@ def nusc_box_to_cam_box3d(boxes): velocity = torch.Tensor([b.velocity[:2] for b in boxes]).view(-1, 2) # convert nusbox to cambox convention - dims[:, [1, 2]] = dims[:, [2, 1]] - rots = -np.pi / 2.0 - rots + dims[:, [0, 1, 2]] = dims[:, [2, 0, 1]] + rots = -rots boxes_3d = torch.cat([locs, dims, rots, velocity], dim=1).cuda() cam_boxes3d = CameraInstance3DBoxes( diff --git a/tools/data_converter/nuscenes_converter.py b/tools/data_converter/nuscenes_converter.py index bd02a02679..f585a11969 100644 --- a/tools/data_converter/nuscenes_converter.py +++ b/tools/data_converter/nuscenes_converter.py @@ -485,11 +485,11 @@ def get_2d_boxes(nusc, loc = box.center.tolist() dim = box.wlh - dim[[1, 2]] = dim[[2, 1]] # convert wlh to our whl + dim[[0, 1, 2]] = dim[[1, 2, 0]] # convert wlh to our lhw dim = dim.tolist() rot = box.orientation.yaw_pitch_roll[0] - rot = [-rot - np.pi / 2] # convert the rot to our cam coordinate + rot = [-rot] # convert the rot to our cam coordinate global_velo2d = nusc.box_velocity(box.token)[:2] global_velo3d = np.array([*global_velo2d, 0.0]) From 8f5c1ef303d2e7f7e808286cfaec17e01bc98e1d Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Mon, 12 Jul 2021 18:00:59 +0800 Subject: [PATCH 4/9] Fix incorrect transformation of dim in post-processing --- mmdet3d/datasets/nuscenes_mono_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet3d/datasets/nuscenes_mono_dataset.py b/mmdet3d/datasets/nuscenes_mono_dataset.py index 99644375ef..f5ed089821 100644 --- a/mmdet3d/datasets/nuscenes_mono_dataset.py +++ b/mmdet3d/datasets/nuscenes_mono_dataset.py @@ -781,7 +781,7 @@ def nusc_box_to_cam_box3d(boxes): velocity = torch.Tensor([b.velocity[:2] for b in boxes]).view(-1, 2) # convert nusbox to cambox convention - dims[:, [0, 1, 2]] = dims[:, [2, 0, 1]] + dims[:, [0, 1, 2]] = dims[:, [1, 2, 0]] rots = -rots boxes_3d = torch.cat([locs, dims, rots, velocity], dim=1).cuda() From 89f5da4194b22ac76d0b477b2d3fff125f3c2142 Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Tue, 13 Jul 2021 16:24:03 +0800 Subject: [PATCH 5/9] Remove transformation from global to local yaw in kitti mono dataset --- mmdet3d/datasets/kitti_mono_dataset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mmdet3d/datasets/kitti_mono_dataset.py b/mmdet3d/datasets/kitti_mono_dataset.py index ec2975a081..5b9d31b12d 100644 --- a/mmdet3d/datasets/kitti_mono_dataset.py +++ b/mmdet3d/datasets/kitti_mono_dataset.py @@ -88,9 +88,6 @@ def _parse_ann_info(self, img_info, ann_info): gt_masks_ann.append(ann.get('segmentation', None)) # 3D annotations in camera coordinates bbox_cam3d = np.array(ann['bbox_cam3d']).reshape(-1, ) - # change orientation to local yaw - bbox_cam3d[6] = -np.arctan2(bbox_cam3d[0], - bbox_cam3d[2]) + bbox_cam3d[6] gt_bboxes_cam3d.append(bbox_cam3d) # 2.5D annotations in camera coordinates center2d = ann['center2d'][:2] From deeb495975bbac2130e5eb5c1dd08c0a7a20bda7 Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Wed, 14 Jul 2021 10:26:14 +0800 Subject: [PATCH 6/9] Add comments for cam_intrinsic modification in the doc and code --- docs/demo.md | 2 ++ mmdet3d/datasets/pipelines/transforms_3d.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/docs/demo.md b/docs/demo.md index 8b280dd92d..22563353e0 100644 --- a/docs/demo.md +++ b/docs/demo.md @@ -70,6 +70,8 @@ Example on nuScenes data using [FCOS3D](https://github.com/open-mmlab/mmdetectio python demo/mono_det_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK__1532402927637525.jpg demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK__1532402927637525_mono3d.coco.json configs/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune.py checkpoints/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210427_091419-35aaaad0.pth ``` +Note that when visualizing results of monocular 3D detection for flipped images, the camera intrinsic matrix should also be modified accordingly. See more details and examples in PR [#744](https://github.com/open-mmlab/mmdetection3d/pull/744). + ### 3D Segmentation To test a 3D segmentor on point cloud data, simply run: diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index 456100476d..6788df68c4 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -122,6 +122,11 @@ def random_flip_data_3d(self, input_dict, direction='horizontal'): w = input_dict['ori_shape'][1] input_dict['centers2d'][..., 0] = \ w - input_dict['centers2d'][..., 0] + # need to modify the horizontal position of camera center + # along u-axis in the image (flip like centers2d) + # ['cam_intrinsic'][0][2] = c_u + # see more details and examples at + # https://github.com/open-mmlab/mmdetection3d/pull/744 input_dict['cam_intrinsic'][0][2] = \ w - input_dict['cam_intrinsic'][0][2] From 22781649e9abf3ab778c636c6862b17dbd7b5ac6 Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Sat, 17 Jul 2021 19:53:24 +0800 Subject: [PATCH 7/9] Fix typos and invalid links --- docs/compatibility.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index cca2522bcb..a6921d7893 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -10,12 +10,12 @@ In order to fix the problem that the priority of EvalHook is too low, all hook p ### Unified parameter initialization -To unify the parameter initialization in OpenMMLab projects, MMCV supports `BaseModule` that accepts `init_cfg` to allow the modules' parameters initialized in a flexible and unified manner. Now the users need to explicitly call `model.init_weights()` in the training script to initialize the model (as in [here](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/train.py#L183), previously this was handled by the detector. Please refer to PR #622 for details. +To unify the parameter initialization in OpenMMLab projects, MMCV supports `BaseModule` that accepts `init_cfg` to allow the modules' parameters initialized in a flexible and unified manner. Now the users need to explicitly call `model.init_weights()` in the training script to initialize the model (as in [here](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/train.py#L183), previously this was handled by the detector. Please refer to PR [#622](https://github.com/open-mmlab/mmdetection3d/pull/622) for details. ### BackgroundPointsFilter -We modified the dataset aumentation function `BackgroundPointsFilter`(in [here](https://github.com/open-mmlab/mmdetection3d/blob/mmdet3d/datasets/pipelines/transforms_3d.py#L1101)). In previous version of MMdetection3D, `BackgroundPointsFilter` changes the gt_bboxes_3d's bottom center to the gravity center. In MMDetection3D 0.15.0, -`BackgroundPointsFilter` will not change it. Please refer to PR #609 for details. +We modified the dataset augmentation function `BackgroundPointsFilter`([here](https://github.com/open-mmlab/mmdetection3d/blob/v0.15.0/mmdet3d/datasets/pipelines/transforms_3d.py#L1132)). In previous version of MMdetection3D, `BackgroundPointsFilter` changes the gt_bboxes_3d's bottom center to the gravity center. In MMDetection3D 0.15.0, +`BackgroundPointsFilter` will not change it. Please refer to PR [#609](https://github.com/open-mmlab/mmdetection3d/pull/609) for details. ### Enhance `IndoorPatchPointSample` transform @@ -45,7 +45,7 @@ We have trained a [VoteNet](https://github.com/open-mmlab/mmdetection3d/blob/mas ### SUNRGBD dataset for ImVoteNet -We adopt a new pre-processing procedure for the SUNRGBD dataset in order to support ImVoteNet, which is a multi-modality method requiring both image and point cloud data. In previous versions of MMDetection3D, SUNRGBD dataset was only used for point cloud based 3D detection methods. In MMDetection3D 0.12.0, we add ImVoteNet to our model zoo, thus updating SUNRGBD correspondingly by adding image-related pre-processing steps. Specificly, we made these changes: +We adopt a new pre-processing procedure for the SUNRGBD dataset in order to support ImVoteNet, which is a multi-modality method requiring both image and point cloud data. In previous versions of MMDetection3D, SUNRGBD dataset was only used for point cloud based 3D detection methods. In MMDetection3D 0.12.0, we add ImVoteNet to our model zoo, thus updating SUNRGBD correspondingly by adding image-related pre-processing steps. Specifically, we made these changes: - Fix a bug in the image file path in meta data. - Convert calibration matrices from double to float to avoid type mismatch in further operations. From 8c3ae136ae2469b5712a8211b34aee4b8ab921af Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Sat, 17 Jul 2021 20:34:09 +0800 Subject: [PATCH 8/9] Add compatibility doc for fixing nus hacks --- docs/compatibility.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/compatibility.md b/docs/compatibility.md index a6921d7893..1dac3ce245 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -2,6 +2,14 @@ This document provides detailed descriptions of the BC-breaking changes in MMDetection3D. +## MMDetection3D 0.16.0 + +### NuScenes coco-style data pre-processing + +We remove the rotation and dimension hack in the monocular 3D detection on nuScenes. Specifically, we transform the rotation and dimension of boxes defined by nuScenes devkit to the coordinate system of our `CameraInstance3DBoxes` in the pre-processing and transform them back in the post-processing. In this way, we can remove the corresponding [hack](https://github.com/open-mmlab/mmdetection3d/pull/744/files#diff-5bee5062bd84e6fa25a2fdd71353f6f283dfdc4a66a0316c3b1ca26078c978b6L165) used in the visualization tools. The modification also guarantees the correctness of all the operations based on our `CameraInstance3DBoxes` (such as NMS and flip augmentation) when training monocular 3D detectors. + +The modification only influences nuScenes coco-style json files. Please re-run the nuScenes data preparation script if necessary. See more details in the PR [#744](https://github.com/open-mmlab/mmdetection3d/pull/744). + ## MMDetection3D 0.15.0 ### MMCV Version From 15ab1ee8d41197ef2f8bdb03e95a777c5589cd63 Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Wed, 21 Jul 2021 17:02:34 +0800 Subject: [PATCH 9/9] Update benchmark --- configs/fcos3d/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/fcos3d/README.md b/configs/fcos3d/README.md index 264b03d31b..de1aae6b29 100644 --- a/configs/fcos3d/README.md +++ b/configs/fcos3d/README.md @@ -58,6 +58,6 @@ We also provide visualization functions to show the monocular 3D detection resul | Backbone | Lr schd | Mem (GB) | Inf time (fps) | mAP | NDS | Download | | :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | -|[ResNet101 w/ DCN](./fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py)|1x|8.69||29.9|37.3|[model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_20210425_181341-8d5a21fe.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_20210425_181341.log.json)| -|[above w/ finetune](./fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune.py)|1x|8.69||32.1|39.3|[model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210427_091419-35aaaad0.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210427_091419.log.json)| -|above w/ tta|1x|8.69||33.1|40.0|| +|[ResNet101 w/ DCN](./fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py)|1x|8.69||29.8|37.7|[model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_20210715_235813-4bed5239.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_20210715_235813.log.json)| +|[above w/ finetune](./fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune.py)|1x|8.69||32.1|39.5|[model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645-8d806dc2.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645.log.json)| +|above w/ tta|1x|8.69||33.1|40.3||