[Fix] Unify camera poses (#653)
* refactor K and Rt to depth2img for SUN RGB-D

* fix lint

* update 3 tests

* fix extra calib key and comments

* remove calib from browse_dataset

* fix cam to depth; rename return_z
filaPro authored Jun 30, 2021
1 parent 23071a5 commit ff62af6
Showing 15 changed files with 136 additions and 159 deletions.
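
The gist of the change: instead of shipping a SUN RGB-D `calib` dict (intrinsics K plus extrinsics Rt) through the data pipeline, a single `depth2img` projection matrix is precomputed and stored in `img_metas`, mirroring the existing `lidar2img` convention used for KITTI. Below is a minimal numpy sketch of the new composition, following the inference.py hunk further down; the calibration values here are placeholders, not real SUN RGB-D numbers.

import numpy as np

# Placeholder calibration; real values come from the SUN RGB-D info files.
K = np.array([[529.5, 0., 365.], [0., 529.5, 265.], [0., 0., 1.]], np.float32)
Rt = np.eye(3, dtype=np.float32)  # placeholder depth-frame extrinsics

# Fixed axis swap between depth and camera frames, matching
# Coord3DMode.convert_point for DEPTH -> CAM.
cam_from_depth = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32)

# One matrix replaces the (K, Rt) pair that the old 'calib' key carried.
depth2img = K @ cam_from_depth @ Rt.T

# Projecting depth-frame points is now a single matmul plus a divide.
pts = np.array([[1.0, 2.5, 0.4]], np.float32)  # (x, y, z) in depth coords
uvw = pts @ depth2img.T
uv = uvw[:, :2] / uvw[:, 2:3]  # pixel coordinates
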
6 changes: 3 additions & 3 deletions configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py
@@ -193,7 +193,7 @@
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'points', 'gt_bboxes_3d',
'gt_labels_3d', 'calib'
'gt_labels_3d'
])
]

@@ -230,7 +230,7 @@
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img', 'points', 'calib'])
dict(type='Collect3D', keys=['img', 'points'])
]),
]
# construct a pipeline for data and gt loading in show function
@@ -247,7 +247,7 @@
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img', 'points', 'calib'])
dict(type='Collect3D', keys=['img', 'points'])
]

data = dict(
25 changes: 10 additions & 15 deletions mmdet3d/apis/inference.py
@@ -155,23 +155,25 @@ def inference_multi_modality_detector(model, pcd, image, ann_file):
bbox_fields=[],
mask_fields=[],
seg_fields=[])

# depth map points to image conversion
if box_mode_3d == Box3DMode.DEPTH:
data.update(dict(calib=info['calib']))

data = test_pipeline(data)

# TODO: this code is dataset-specific. Move lidar2img and
# depth2img to .pkl annotations in the future.
# LiDAR to image conversion
if box_mode_3d == Box3DMode.LIDAR:
rect = info['calib']['R0_rect'].astype(np.float32)
Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
P2 = info['calib']['P2'].astype(np.float32)
lidar2img = P2 @ rect @ Trv2c
data['img_metas'][0].data['lidar2img'] = lidar2img
# Depth to image conversion
elif box_mode_3d == Box3DMode.DEPTH:
data['calib'][0]['Rt'] = data['calib'][0]['Rt'].astype(np.float32)
data['calib'][0]['K'] = data['calib'][0]['K'].astype(np.float32)
rt_mat = info['calib']['Rt']
# follow Coord3DMode.convert_point
rt_mat = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]
]) @ rt_mat.transpose(1, 0)
depth2img = info['calib']['K'] @ rt_mat
data['img_metas'][0].data['depth2img'] = depth2img

data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
@@ -182,9 +184,6 @@ def inference_multi_modality_detector(model, pcd, image, ann_file):
data['img_metas'] = data['img_metas'][0].data
data['points'] = data['points'][0].data
data['img'] = data['img'][0].data
if box_mode_3d == Box3DMode.DEPTH:
data['calib'][0]['Rt'] = data['calib'][0]['Rt'][0].data
data['calib'][0]['K'] = data['calib'][0]['K'][0].data

# forward the model
with torch.no_grad():
@@ -411,17 +410,13 @@ def show_proj_det_result_meshlab(data,
box_mode='lidar',
show=show)
elif box_mode == Box3DMode.DEPTH:
if 'calib' not in data.keys():
raise NotImplementedError(
'camera calibration information is not provided')

show_bboxes = DepthInstance3DBoxes(pred_bboxes, origin=(0.5, 0.5, 0))

show_multi_modality_result(
img,
None,
show_bboxes,
data['calib'][0],
None,
out_dir,
file_name,
box_mode='depth',
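
The refactor above is behavior-preserving: composing K @ (flip @ Rt.T) once yields the same pixel coordinates as the old two-step route through camera coordinates. A small self-contained numpy check with made-up calibration values (none of these numbers come from the repo):

import numpy as np

K = np.array([[529.5, 0., 365.], [0., 529.5, 265.], [0., 0., 1.]], np.float32)
theta = 0.1  # made-up extrinsic rotation about the depth z-axis
Rt = np.array([[np.cos(theta), -np.sin(theta), 0.],
               [np.sin(theta), np.cos(theta), 0.],
               [0., 0., 1.]], np.float32)
flip = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32)

pts_depth = np.array([[0.4, 1.8, 0.3],
                      [-0.2, 2.6, 1.1]], np.float32)

# Old route: depth -> camera coordinates, then intrinsics.
pts_cam = pts_depth @ (flip @ Rt.T).T
uvw_old = pts_cam @ K.T

# New route: one unified matrix, as now built in inference.py.
depth2img = K @ flip @ Rt.T
uvw_new = pts_depth @ depth2img.T

assert np.allclose(uvw_old, uvw_new, atol=1e-4)
uv = uvw_new[:, :2] / uvw_new[:, 2:3]  # identical pixel coordinates
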
77 changes: 48 additions & 29 deletions mmdet3d/core/bbox/box_np_ops.py
@@ -50,7 +50,8 @@ def corners_nd(dims, origin=0.5):
Args:
dims (np.ndarray, shape=[N, ndim]): Array of length per dim
origin (list or array or float): origin point relate to smallest point.
origin (list or array or float, optional): origin point relate to
smallest point. Defaults to 0.5
Returns:
np.ndarray, shape=[N, 2 ** ndim, ndim]: Returned corners.
@@ -102,7 +103,10 @@ def center_to_corner_box2d(centers, dims, angles=None, origin=0.5):
Args:
centers (np.ndarray): Locations in kitti label file with shape (N, 2).
dims (np.ndarray): Dimensions in kitti label file with shape (N, 2).
angles (np.ndarray): Rotation_y in kitti label file with shape (N).
angles (np.ndarray, optional): Rotation_y in kitti label file with
shape (N). Defaults to None.
origin (list or array or float, optional): origin point relate to
smallest point. Defaults to 0.5.
Returns:
np.ndarray: Corners with the shape of (N, 4, 2).
@@ -173,7 +177,7 @@ def rotation_3d_in_axis(points, angles, axis=0):
Args:
points (np.ndarray, shape=[N, point_size, 3]]):
angles (np.ndarray, shape=[N]]):
axis (int): Axis to rotate at.
axis (int, optional): Axis to rotate at. Defaults to 0.
Returns:
np.ndarray: Rotated points.
@@ -208,10 +212,13 @@ def center_to_corner_box3d(centers,
Args:
centers (np.ndarray): Locations in kitti label file with shape (N, 3).
dims (np.ndarray): Dimensions in kitti label file with shape (N, 3).
angles (np.ndarray): Rotation_y in kitti label file with shape (N).
origin (list or array or float): Origin point relate to smallest point.
use (0.5, 1.0, 0.5) in camera and (0.5, 0.5, 0) in lidar.
axis (int): Rotation axis. 1 for camera and 2 for lidar.
angles (np.ndarray, optional): Rotation_y in kitti label file with
shape (N). Defaults to None.
origin (list or array or float, optional): Origin point relate to
smallest point. Use (0.5, 1.0, 0.5) in camera and (0.5, 0.5, 0)
in lidar. Defaults to (0.5, 1.0, 0.5).
axis (int, optional): Rotation axis. 1 for camera and 2 for lidar.
Defaults to 1.
Returns:
np.ndarray: Corners with the shape of (N, 8, 3).
@@ -308,8 +315,8 @@ def rotation_points_single_angle(points, angle, axis=0):
Args:
points (np.ndarray, shape=[N, 3]]):
angles (np.ndarray, shape=[1]]):
axis (int): Axis to rotate at.
angle (np.ndarray, shape=[1]]):
axis (int, optional): Axis to rotate at. Defaults to 0.
Returns:
np.ndarray: Rotated points.
@@ -341,7 +348,8 @@ def points_cam2img(points_3d, proj_mat, with_depth=False):
Args:
points_3d (np.ndarray): Points in shape (N, 3)
proj_mat (np.ndarray): Transformation matrix between coordinates.
with_depth (bool): Whether to keep depth in the output.
with_depth (bool, optional): Whether to keep depth in the output.
Defaults to False.
Returns:
np.ndarray: Points in image coordinates with shape [N, 2].
@@ -420,8 +428,10 @@ def points_in_rbbox(points, rbbox, z_axis=2, origin=(0.5, 0.5, 0)):
Args:
points (np.ndarray, shape=[N, 3+dim]): Points to query.
rbbox (np.ndarray, shape=[M, 7]): Boxes3d with rotation.
z_axis (int): Indicate which axis is height.
origin (tuple[int]): Indicate the position of box center.
z_axis (int, optional): Indicate which axis is height.
Defaults to 2.
origin (tuple[int], optional): Indicate the position of
box center. Defaults to (0.5, 0.5, 0).
Returns:
np.ndarray, shape=[N, M]: Indices of points in each box.
@@ -479,11 +489,13 @@ def create_anchors_3d_range(feature_size,
anchor_range (torch.Tensor | list[float]): Range of anchors with
shape [6]. The order is consistent with that of anchors, i.e.,
(x_min, y_min, z_min, x_max, y_max, z_max).
sizes (list[list] | np.ndarray | torch.Tensor): Anchor size with
shape [N, 3], in order of x, y, z.
rotations (list[float] | np.ndarray | torch.Tensor): Rotations of
anchors in a single feature grid.
dtype (type): Data type. Default to np.float32.
sizes (list[list] | np.ndarray | torch.Tensor, optional):
Anchor size with shape [N, 3], in order of x, y, z.
Defaults to ((1.6, 3.9, 1.56), ).
rotations (list[float] | np.ndarray | torch.Tensor, optional):
Rotations of anchors in a single feature grid.
Defaults to (0, np.pi / 2).
dtype (type, optional): Data type. Default to np.float32.
Returns:
np.ndarray: Range based anchors with shape of \
@@ -520,7 +532,8 @@ def center_to_minmax_2d(centers, dims, origin=0.5):
Args:
centers (np.ndarray): Center points.
dims (np.ndarray): Dimensions.
origin (list or array or float): origin point relate to smallest point.
origin (list or array or float, optional): Origin point relate
to smallest point. Defaults to 0.5.
Returns:
np.ndarray: Minmax points.
@@ -559,6 +572,8 @@ def iou_jit(boxes, query_boxes, mode='iou', eps=0.0):
Args:
boxes (np.ndarray): Input bounding boxes with shape of (N, 4).
query_boxes (np.ndarray): Query boxes with shape of (K, 4).
mode (str, optional): IoU mode. Defaults to 'iou'.
eps (float, optional): Value added to denominator. Defaults to 0.
Returns:
np.ndarray: Overlap between boxes and query_boxes
@@ -648,8 +663,10 @@ def get_frustum(bbox_image, C, near_clip=0.001, far_clip=100):
Args:
bbox_image (list[int]): box in image coordinates.
C (np.ndarray): Intrinsics.
near_clip (float): Nearest distance of frustum.
far_clip (float): Farthest distance of frustum.
near_clip (float, optional): Nearest distance of frustum.
Defaults to 0.001.
far_clip (float, optional): Farthest distance of frustum.
Defaults to 100.
Returns:
np.ndarray, shape=[8, 3]: coordinates of frustum corners.
@@ -742,12 +759,12 @@ def points_in_convex_polygon_3d_jit(points,
Args:
points (np.ndarray): Input points with shape of (num_points, 3).
polygon_surfaces (np.ndarray): Polygon surfaces with shape of \
(num_polygon, max_num_surfaces, max_num_points_of_surface, 3). \
All surfaces' normal vector must direct to internal. \
polygon_surfaces (np.ndarray): Polygon surfaces with shape of
(num_polygon, max_num_surfaces, max_num_points_of_surface, 3).
All surfaces' normal vector must direct to internal.
Max_num_points_of_surface must at least 3.
num_surfaces (np.ndarray): Number of surfaces a polygon contains \
shape of (num_polygon).
num_surfaces (np.ndarray, optional): Number of surfaces a polygon
contains shape of (num_polygon). Defaults to None.
Returns:
np.ndarray: Result matrix with the shape of [num_points, num_polygon].
@@ -772,7 +789,8 @@ def points_in_convex_polygon_jit(points, polygon, clockwise=True):
points (np.ndarray): Input points with the shape of [num_points, 2].
polygon (np.ndarray): Input polygon with the shape of
[num_polygon, num_points_of_polygon, 2].
clockwise (bool): Indicate polygon is clockwise.
clockwise (bool, optional): Indicate polygon is clockwise. Defaults
to True.
Returns:
np.ndarray: Result matrix with the shape of [num_points, num_polygon].
@@ -821,10 +839,11 @@ def boxes3d_to_corners3d_lidar(boxes3d, bottom_center=True):
2 -------- 1
Args:
boxes3d (np.ndarray): Boxes with shape of (N, 7) \
[x, y, z, w, l, h, ry] in LiDAR coords, see the definition of ry \
boxes3d (np.ndarray): Boxes with shape of (N, 7)
[x, y, z, w, l, h, ry] in LiDAR coords, see the definition of ry
in KITTI dataset.
bottom_center (bool): Whether z is on the bottom center of object.
bottom_center (bool, optional): Whether z is on the bottom center
of object. Defaults to True.
Returns:
np.ndarray: Box corners with the shape of [N, 8, 3].
10 changes: 0 additions & 10 deletions mmdet3d/core/bbox/structures/coord_3d_mode.py
@@ -227,21 +227,11 @@ def convert_point(point, src, dst, rt_mat=None):
if rt_mat is None:
rt_mat = arr.new_tensor([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
elif src == Coord3DMode.DEPTH and dst == Coord3DMode.CAM:
# LIDAR-CAM conversion is different from DEPTH-CAM conversion
# because SUNRGB-D camera calibration files are different from
# that of KITTI, and currently we keep this hack
if rt_mat is None:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
else:
rt_mat = rt_mat.new_tensor(
[[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ \
rt_mat.transpose(1, 0)
elif src == Coord3DMode.CAM and dst == Coord3DMode.DEPTH:
if rt_mat is None:
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
else:
rt_mat = rt_mat @ rt_mat.new_tensor([[1, 0, 0], [0, 0, 1],
[0, -1, 0]])
elif src == Coord3DMode.LIDAR and dst == Coord3DMode.DEPTH:
if rt_mat is None:
rt_mat = arr.new_tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
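
With the special-cased else branches gone, a caller-supplied rt_mat is applied as-is for DEPTH <-> CAM, the same as for every other coordinate pair; the SUN RGB-D axis swap is now composed by the caller (as inference.py does above). A sketch of the resulting convention, with placeholder extrinsics:

import torch
from mmdet3d.core import Coord3DMode

pts_depth = torch.tensor([[1.0, 2.0, 0.5]])
flip = torch.tensor([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]])
Rt = torch.eye(3)  # placeholder extrinsics

# The caller composes the full depth->cam matrix once and passes it in;
# convert_point applies it directly instead of rewriting it internally.
pts_cam = Coord3DMode.convert_point(
    pts_depth, Coord3DMode.DEPTH, Coord3DMode.CAM, rt_mat=flip @ Rt.t())
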
9 changes: 7 additions & 2 deletions mmdet3d/core/bbox/structures/utils.py
@@ -111,12 +111,14 @@ def get_box_type(box_type):
return box_type_3d, box_mode_3d


def points_cam2img(points_3d, proj_mat):
def points_cam2img(points_3d, proj_mat, with_depth=False):
"""Project points from camera coordicates to image coordinates.
Args:
points_3d (torch.Tensor): Points in shape (N, 3)
points_3d (torch.Tensor): Points in shape (N, 3).
proj_mat (torch.Tensor): Transformation matrix between coordinates.
with_depth (bool, optional): Whether to keep depth in the output.
Defaults to False.
Returns:
torch.Tensor: Points in image coordinates with shape [N, 2].
@@ -141,6 +143,9 @@ def points_cam2img(points_3d, proj_mat):
[points_3d, points_3d.new_ones(*points_shape)], dim=-1)
point_2d = torch.matmul(points_4, proj_mat.t())
point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]

if with_depth:
return torch.cat([point_2d_res, point_2d[..., 2:3]], dim=-1)
return point_2d_res


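
A short usage sketch of the new with_depth flag (the projection matrix here is a placeholder): the returned third component is the projected depth, which lets callers filter out points behind the camera.

import torch
from mmdet3d.core.bbox import points_cam2img

proj_mat = torch.eye(4)  # placeholder; e.g. a padded depth2img or KITTI P2
pts_cam = torch.tensor([[0.5, 0.2, 3.0],
                        [0.1, 0.0, -1.0]])  # second point is behind the camera

uvd = points_cam2img(pts_cam, proj_mat, with_depth=True)  # shape (N, 3)
front = uvd[:, 2] > 0
uv = uvd[front, :2]  # keep only pixels with positive depth
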
18 changes: 3 additions & 15 deletions mmdet3d/core/visualizer/image_vis.py
@@ -120,6 +120,7 @@ def draw_lidar_bbox3d_on_img(bboxes3d,
return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness)


# TODO: remove third parameter in all functions here in favour of img_metas
def draw_depth_bbox3d_on_img(bboxes3d,
raw_img,
calibs,
@@ -137,35 +138,22 @@ def draw_depth_bbox3d_on_img(bboxes3d,
color (tuple[int]): The color to draw bboxes. Default: (0, 255, 0).
thickness (int, optional): The thickness of bboxes. Default: 1.
"""
from mmdet3d.core import Coord3DMode
from mmdet3d.core.bbox import points_cam2img
from mmdet3d.models import apply_3d_transformation

img = raw_img.copy()
calibs = copy.deepcopy(calibs)
img_metas = copy.deepcopy(img_metas)
corners_3d = bboxes3d.corners
num_bbox = corners_3d.shape[0]
points_3d = corners_3d.reshape(-1, 3)
assert ('Rt' in calibs.keys() and 'K' in calibs.keys()), \
'Rt and K matrix should be provided as camera caliberation information'
if not isinstance(calibs['Rt'], torch.Tensor):
calibs['Rt'] = torch.from_numpy(np.array(calibs['Rt']))
if not isinstance(calibs['K'], torch.Tensor):
calibs['K'] = torch.from_numpy(np.array(calibs['K']))
calibs['Rt'] = calibs['Rt'].reshape(3, 3).float().cpu()
calibs['K'] = calibs['K'].reshape(3, 3).float().cpu()

# first reverse the data transformations
xyz_depth = apply_3d_transformation(
points_3d, 'DEPTH', img_metas, reverse=True)

# then convert from depth coords to camera coords
xyz_cam = Coord3DMode.convert_point(
xyz_depth, Coord3DMode.DEPTH, Coord3DMode.CAM, rt_mat=calibs['Rt'])

# project to 2d to get image coords (uv)
uv_origin = points_cam2img(xyz_cam, calibs['K'])
uv_origin = points_cam2img(xyz_depth,
xyz_depth.new_tensor(img_metas['depth2img']))
uv_origin = (uv_origin - 1).round()
imgfov_pts_2d = uv_origin[..., :2].reshape(num_bbox, 8, 2).numpy()

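
After this hunk the K/Rt juggling reduces to one call: reverse the 3D augmentations, then project with img_metas['depth2img'] via points_cam2img. A hedged example of the simplified caller; every input below is a synthetic placeholder, not real dataset output.

import numpy as np
from mmdet3d.core import DepthInstance3DBoxes
from mmdet3d.core.visualizer.image_vis import draw_depth_bbox3d_on_img

img = np.zeros((530, 730, 3), dtype=np.uint8)  # blank stand-in image
boxes = DepthInstance3DBoxes(
    np.array([[1.0, 2.5, 0.1, 0.6, 0.6, 1.0, 0.0]], np.float32),
    origin=(0.5, 0.5, 0))

K = np.array([[529.5, 0., 365.], [0., 529.5, 265.], [0., 0., 1.]], np.float32)
flip = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32)
img_metas = dict(depth2img=K @ flip)  # identity extrinsics, no 3D augs recorded

# The third (calibs) argument is kept for signature compatibility but is
# no longer read; the projection uses img_metas['depth2img'] directly.
drawn = draw_depth_bbox3d_on_img(boxes, img, None, img_metas)
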
3 changes: 2 additions & 1 deletion mmdet3d/datasets/pipelines/formating.py
@@ -100,6 +100,7 @@ class Collect3D(object):
- 'ori_shape': original shape of the image as a tuple (h, w, c)
- 'pad_shape': image shape after padding
- 'lidar2img': transform from lidar to image
- 'depth2img': transform from depth to image
- 'pcd_horizontal_flip': a boolean indicating if point cloud is \
flipped horizontally
- 'pcd_vertical_flip': a boolean indicating if point cloud is \
@@ -134,7 +135,7 @@ class Collect3D(object):
def __init__(self,
keys,
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'pad_shape', 'scale_factor', 'flip',
'depth2img', 'pad_shape', 'scale_factor', 'flip',
'cam_intrinsic', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'rect', 'Trv2c', 'P2', 'pcd_trans',
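
Because 'depth2img' now sits in the default meta_keys, configs collect only the plain data keys and the matrix rides along inside img_metas automatically; a brief sketch of the convention (names are illustrative):

# Test-time collection for a SUN RGB-D multi-modality pipeline: no 'calib'
# key any more; the projection matrix travels inside img_metas instead.
collect = dict(type='Collect3D', keys=['img', 'points'])

# Downstream code then reads it per sample, e.g.:
# depth2img = img_metas[0]['depth2img']
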
