From 4b3c26aab553dcb8acd5632aed67b0804a80ee9a Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 5 Aug 2020 17:01:53 +0800 Subject: [PATCH 01/10] add primitive head --- mmdet3d/models/dense_heads/primitive_head.py | 816 +++++++++++++++++++ tests/test_heads.py | 101 +++ 2 files changed, 917 insertions(+) create mode 100644 mmdet3d/models/dense_heads/primitive_head.py diff --git a/mmdet3d/models/dense_heads/primitive_head.py b/mmdet3d/models/dense_heads/primitive_head.py new file mode 100644 index 0000000000..801a0de2ad --- /dev/null +++ b/mmdet3d/models/dense_heads/primitive_head.py @@ -0,0 +1,816 @@ +import torch +from mmcv.cnn import ConvModule +from torch import nn as nn + +from mmdet3d.models.builder import build_loss +from mmdet3d.models.model_utils import VoteModule +from mmdet3d.ops import PointSAModule, furthest_point_sample +from mmdet.core import multi_apply +from mmdet.models import HEADS + + +@HEADS.register_module() +class PrimitiveHead(nn.Module): + r"""Bbox head of `Votenet `_. + + Args: + num_dim (int): The dimension of primitive. + num_classes (int): The number of class. + primitive_mode (str): The mode of primitive. + bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and + decoding boxes. + train_cfg (dict): Config for training. + test_cfg (dict): Config for testing. + vote_moudule_cfg (dict): Config of VoteModule for point-wise votes. + vote_aggregation_cfg (dict): Config of vote aggregation layer. + feat_channels (tuple[int]): Convolution channels of + prediction layer. + conv_cfg (dict): Config of convolution in prediction layer. + norm_cfg (dict): Config of BN in prediction layer. + objectness_loss (dict): Config of objectness loss. + center_loss (dict): Config of center loss. + semantic_loss (dict): Config of point-wise semantic segmentation loss. 
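+
+    Example:
+        A minimal config sketch for the ``'z'`` mode, mirroring the settings
+        used in ``tests/test_heads.py`` below; the numbers are illustrative
+        rather than required values::
+
+            head = PrimitiveHead(
+                num_dim=2,
+                num_classes=18,
+                primitive_mode='z',
+                vote_moudule_cfg=dict(
+                    in_channels=256,
+                    vote_per_seed=1,
+                    gt_per_seed=1,
+                    conv_channels=(256, 256),
+                    conv_cfg=dict(type='Conv1d'),
+                    norm_cfg=dict(type='BN1d'),
+                    norm_feats=True,
+                    vote_loss=dict(
+                        type='ChamferDistance',
+                        mode='l1',
+                        reduction='none',
+                        loss_dst_weight=10.0)),
+                vote_aggregation_cfg=dict(
+                    num_point=64,
+                    radius=0.3,
+                    num_sample=16,
+                    mlp_channels=[256, 128, 128, 128],
+                    use_xyz=True,
+                    normalize_xyz=True),
+                objectness_loss=dict(
+                    type='CrossEntropyLoss',
+                    class_weight=[0.4, 0.6],
+                    reduction='mean',
+                    loss_weight=1.0),
+                center_loss=dict(
+                    type='ChamferDistance',
+                    mode='l1',
+                    reduction='none',
+                    loss_src_weight=1.0,
+                    loss_dst_weight=1.0),
+                semantic_loss=dict(
+                    type='CrossEntropyLoss',
+                    reduction='none',
+                    loss_weight=1.0),
+                train_cfg=dict(
+                    dist_thresh=0.2,
+                    var_thresh=1e-2,
+                    lower_thresh=1e-6,
+                    num_point=100,
+                    num_point_line=10,
+                    line_thresh=0.2))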
+ """ + + def __init__(self, + num_dim, + num_classes, + primitive_mode, + train_cfg=None, + test_cfg=None, + vote_moudule_cfg=None, + vote_aggregation_cfg=None, + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=None, + center_loss=None, + semantic_loss=None): + super(PrimitiveHead, self).__init__() + self.num_dim = num_dim + self.num_classes = num_classes + self.primitive_mode = primitive_mode + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.gt_per_seed = vote_moudule_cfg['gt_per_seed'] + self.num_proposal = vote_aggregation_cfg['num_point'] + + self.objectness_loss = build_loss(objectness_loss) + self.center_loss = build_loss(center_loss) + self.semantic_loss = build_loss(semantic_loss) + + assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[ + 'in_channels'] + + # Existence flag prediction + self.flag_conv = ConvModule( + 256, + 128, + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=True) + self.flag_pred = torch.nn.Conv1d(128, 2, 1) + + self.vote_module = VoteModule(**vote_moudule_cfg) + self.vote_aggregation = PointSAModule(**vote_aggregation_cfg) + + prev_channel = vote_aggregation_cfg['mlp_channels'][-1] + conv_pred_list = list() + for k in range(len(feat_channels)): + conv_pred_list.append( + ConvModule( + prev_channel, + feat_channels[k], + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=True)) + prev_channel = feat_channels[k] + self.conv_pred = nn.Sequential(*conv_pred_list) + + conv_out_channel = 3 + num_dim + num_classes + self.conv_pred.add_module('conv_out', + nn.Conv1d(prev_channel, conv_out_channel, 1)) + + def init_weights(self): + """Initialize weights of VoteHead.""" + pass + + def forward(self, feat_dict, sample_mod): + """Forward pass. + + Note: + The forward of VoteHead is devided into 4 steps: + + 1. Generate vote_points from seed_points. + 2. Aggregate vote_points. + 3. Predict primitive cue and score. + 4. Decode predictions. + + Args: + feat_dict (dict): Feature dict from backbone. + sample_mod (str): Sample mode for vote aggregation layer. + valid modes are "vote", "seed" and "random". + + Returns: + dict: Predictions of primitive head. + """ + assert sample_mod in ['vote', 'seed', 'random'] + + seed_points = feat_dict['fp_xyz_net0'][-1] + seed_features = feat_dict['hd_feature'] + results = {} + + # net_flag = F.relu(self.bn_flag1(self.conv_flag1(seed_features))) + # net_flag = self.conv_flag2(net_flag) + net_flag = self.flag_conv(seed_features) + net_flag = self.flag_pred(net_flag) + + results['pred_flag_' + self.primitive_mode] = net_flag + + # 1. generate vote_points from seed_points + vote_points, vote_features = self.vote_module(seed_points, + seed_features) + results['vote_' + self.primitive_mode] = vote_points + results['vote_features_' + self.primitive_mode] = vote_features + + # 2. 
aggregate vote_points + if sample_mod == 'vote': + # use fps in vote_aggregation + sample_indices = None + elif sample_mod == 'seed': + # FPS on seed and choose the votes corresponding to the seeds + sample_indices = furthest_point_sample(seed_points, + self.num_proposal) + elif sample_mod == 'random': + # Random sampling from the votes + batch_size, num_seed = seed_points.shape[:2] + sample_indices = seed_points.new_tensor( + torch.randint(0, num_seed, (batch_size, self.num_proposal)), + dtype=torch.int32) + else: + raise NotImplementedError + + vote_aggregation_ret = self.vote_aggregation(vote_points, + vote_features, + sample_indices) + aggregated_points, features, aggregated_indices = vote_aggregation_ret + results['aggregated_points_' + self.primitive_mode] = aggregated_points + results['aggregated_features_' + self.primitive_mode] = features + results['aggregated_indices_' + + self.primitive_mode] = aggregated_indices + + # 3. predict bbox and score + predictions = self.conv_pred(features) + + # 4. decode predictions + newcenter, decode_res = self.primitive_decode_scores( + predictions, results, self.num_classes, mode=self.primitive_mode) + results.update(decode_res) + + return results + + def loss(self, + bbox_preds, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + img_metas=None, + gt_bboxes_ignore=None): + """Compute loss. + + Args: + bbox_preds (dict): Predictions from forward of primitive head. + points (list[torch.Tensor]): Input points. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \ + bboxes of each sample. + gt_labels_3d (list[torch.Tensor]): Labels of each sample. + pts_semantic_mask (None | list[torch.Tensor]): Point-wise + semantic mask. + pts_instance_mask (None | list[torch.Tensor]): Point-wise + instance mask. + img_metas (list[dict]): Contain pcd and img's meta info. + gt_bboxes_ignore (None | list[torch.Tensor]): Specify + which bounding. + + Returns: + dict: Losses of Votenet. + """ + targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask) + + (point_mask, point_sem, point_offset) = targets + + losses = {} + flag_loss = self.compute_flag_loss( + bbox_preds, point_mask, mode=self.primitive_mode) * 30 + losses['flag_loss_' + self.primitive_mode] = flag_loss + + # calculate vote loss + vote_loss = self.vote_module.get_loss( + bbox_preds['seed_points'], + bbox_preds['vote_' + self.primitive_mode], + bbox_preds['seed_indices'], point_mask, point_offset) + losses['vote_loss_' + self.primitive_mode] = vote_loss + + center_loss, size_loss, sem_cls_loss = self.compute_primitivesem_loss( + bbox_preds, + point_mask, + point_offset, + point_sem, + mode=self.primitive_mode) + losses['center_loss_' + self.primitive_mode] = center_loss + losses['size_loss_' + self.primitive_mode] = size_loss + losses['sem_loss_' + self.primitive_mode] = sem_cls_loss + losses['surface_loss_' + self.primitive_mode] = center_loss * 0.5 + \ + size_loss * 0.5 + sem_cls_loss + return losses + + def get_targets(self, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + bbox_preds=None): + """Generate targets of vote head. + + Args: + points (list[torch.Tensor]): Points of each batch. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \ + bboxes of each batch. + gt_labels_3d (list[torch.Tensor]): Labels of each batch. + pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic + label of each batch. 
+ pts_instance_mask (None | list[torch.Tensor]): Point-wise instance + label of each batch. + bbox_preds (torch.Tensor): Predictions of primitive head. + + Returns: + tuple[torch.Tensor]: Targets of primitive head. + """ + valid_gt_masks = list() + gt_num = list() + for index in range(len(gt_labels_3d)): + if len(gt_labels_3d[index]) == 0: + fake_box = gt_bboxes_3d[index].tensor.new_zeros( + 1, gt_bboxes_3d[index].tensor.shape[-1]) + gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) + gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) + valid_gt_masks.append(gt_labels_3d[index].new_zeros(1)) + gt_num.append(1) + else: + valid_gt_masks.append(gt_labels_3d[index].new_ones( + gt_labels_3d[index].shape)) + gt_num.append(gt_labels_3d[index].shape[0]) + + if pts_semantic_mask is None: + pts_semantic_mask = [None for i in range(len(gt_labels_3d))] + pts_instance_mask = [None for i in range(len(gt_labels_3d))] + + (point_mask, point_sem, + point_offset) = multi_apply(self.get_targets_single, points, + gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask) + + point_mask = torch.stack(point_mask) + point_sem = torch.stack(point_sem) + point_offset = torch.stack(point_offset) + + return (point_mask, point_sem, point_offset) + + def get_targets_single(self, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None): + """Generate targets of vote head for single batch. + + Args: + points (torch.Tensor): Points of each batch. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \ + boxes of each batch. + gt_labels_3d (torch.Tensor): Labels of each batch. + pts_semantic_mask (None | torch.Tensor): Point-wise semantic + label of each batch. + pts_instance_mask (None | torch.Tensor): Point-wise instance + label of each batch. + + Returns: + tuple[torch.Tensor]: Targets of primitive head. 
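+                The tuple contains a point-wise primitive mask of shape
+                (num_points, ), point-wise semantic targets of shape
+                (num_points, 3 + num_dim + 1) for the 'z' / 'xy' modes or
+                (num_points, 4) for the 'line' mode, and point-wise center
+                offsets of shape (num_points, 3).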
+ """ + gt_bboxes_3d = gt_bboxes_3d.to(points.device) + num_points = points.shape[0] + device = points.device + + if self.primitive_mode == 'z': + point_boundary_mask_z = torch.zeros(num_points).to(device) + point_boundary_offset_z = torch.zeros([num_points, 3]).to(device) + point_boundary_sem_z = torch.zeros( + [num_points, 3 + self.num_dim + 1]).to(device) + elif self.primitive_mode == 'xy': + point_boundary_mask_xy = torch.zeros(num_points).to(device) + point_boundary_offset_xy = torch.zeros([num_points, 3]).to(device) + point_boundary_sem_xy = torch.zeros( + [num_points, 3 + self.num_dim + 1]).to(device) + elif self.primitive_mode == 'line': + point_line_mask = torch.zeros(num_points).to(device) + point_line_offset = torch.zeros([num_points, 3]).to(device) + point_line_sem = torch.zeros([num_points, 3 + 1]).to(device) + else: + NotImplementedError + + instance_flag = torch.nonzero( + pts_semantic_mask != self.num_classes).squeeze(1) + instance_labels = pts_instance_mask[instance_flag].unique() + + for i, i_instance in enumerate(instance_labels): + ind = instance_flag[pts_instance_mask[instance_flag] == i_instance] + x = points[ind, :3] + + # Corners + corners = gt_bboxes_3d.corners[i][[0, 1, 3, 2, 4, 5, 7, 6]] + xmin, ymin, zmin = corners.min(0)[0] + xmax, ymax, zmax = corners.max(0)[0] + + # Get lower four lines + plane_lower_temp = torch.as_tensor([0, 0, 1, + -corners[6, -1]]).to(device) + para_points = corners[[1, 3, 5, 7]] + newd = torch.sum(para_points * plane_lower_temp[:3], 1) + if self.check_upright(para_points) and \ + plane_lower_temp[0] + plane_lower_temp[1] < \ + self.train_cfg['lower_thresh']: + plane_lower = torch.as_tensor([0, 0, 1, plane_lower_temp[-1] + ]).to(device) + plane_upper = torch.as_tensor([0, 0, 1, + -torch.mean(newd)]).to(device) + else: + import pdb + pdb.set_trace() + print('error with upright') + if self.check_z(plane_upper, para_points) is False: + import pdb + pdb.set_trace() + + # Get the boundary points here + alldist = torch.abs( + torch.sum(x * plane_lower[:3], 1) + plane_lower[-1]) + mind = alldist.min() + sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + + # Get lower four lines + line_sel1, line_sel2, line_sel3, line_sel4 = self.get_linesel( + x[sel], xmin, xmax, ymin, ymax) + if self.primitive_mode == 'line': + if torch.sum(line_sel1) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel1]] = 1.0 + linecenter = torch.mean(x[sel][line_sel1], axis=0) + linecenter[1] = (ymin + ymax) / 2.0 + point_line_offset[ + ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel2) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel2]] = 1.0 + linecenter = torch.mean(x[sel][line_sel2], axis=0) + linecenter[1] = (ymin + ymax) / 2.0 + point_line_offset[ + ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel3) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel3]] = 1.0 + linecenter = torch.mean(x[sel][line_sel3], axis=0) + linecenter[0] = (xmin + xmax) / 2.0 + point_line_offset[ + ind[sel][line_sel3]] = linecenter - x[sel][line_sel3] + point_line_sem[ind[sel][line_sel3]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + 
pts_semantic_mask[ind][0] + ]).to(device) + if torch.sum(line_sel4) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel4]] = 1.0 + linecenter = torch.mean(x[sel][line_sel4], axis=0) + linecenter[0] = (xmin + xmax) / 2.0 + point_line_offset[ + ind[sel][line_sel4]] = linecenter - x[sel][line_sel4] + point_line_sem[ind[sel][line_sel4]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + + # Set the surface labels here + if self.primitive_mode == 'z': + if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( + alldist[sel]) < self.train_cfg['var_thresh']: + center = torch.as_tensor([ + (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, + torch.mean(x[sel][:, 2]) + ]).to(device) + sel_global = ind[sel] + point_boundary_mask_z[sel_global] = 1.0 + point_boundary_sem_z[sel_global] = torch.as_tensor([ + center[0], center[1], center[2], xmax - xmin, + ymax - ymin, (pts_semantic_mask[ind][0]) + ]).to(device) + point_boundary_offset_z[sel_global] = center - x[sel] + + # Get the boundary points here + alldist = torch.abs( + torch.sum(x * plane_upper[:3], 1) + plane_upper[-1]) + mind = alldist.min() + sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + + # Get upper four lines + line_sel1, line_sel2, line_sel3, line_sel4 = self.get_linesel( + x[sel], xmin, xmax, ymin, ymax) + + if self.primitive_mode == 'line': + if torch.sum(line_sel1) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel1]] = 1.0 + linecenter = torch.mean(x[sel][line_sel1], axis=0) + linecenter[1] = (ymin + ymax) / 2.0 + point_line_offset[ + ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel2) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel2]] = 1.0 + linecenter = torch.mean(x[sel][line_sel2], axis=0) + linecenter[1] = (ymin + ymax) / 2.0 + point_line_offset[ + ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel3) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel3]] = 1.0 + linecenter = torch.mean(x[sel][line_sel3], axis=0) + linecenter[0] = (xmin + xmax) / 2.0 + point_line_offset[ + ind[sel][line_sel3]] = linecenter - x[sel][line_sel3] + point_line_sem[ind[sel][line_sel3]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel4) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel4]] = 1.0 + linecenter = torch.mean(x[sel][line_sel4], axis=0) + linecenter[0] = (xmin + xmax) / 2.0 + point_line_offset[ + ind[sel][line_sel4]] = linecenter - x[sel][line_sel4] + point_line_sem[ind[sel][line_sel4]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + + if self.primitive_mode == 'z': + if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( + alldist[sel]) < self.train_cfg['var_thresh']: + center = torch.as_tensor([ + (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, + torch.mean(x[sel][:, 2]) + ]).to(device) + sel_global = ind[sel] + point_boundary_mask_z[sel_global] = 1.0 + point_boundary_sem_z[sel_global] = torch.as_tensor([ + center[0], center[1], center[2], xmax - xmin, + 
ymax - ymin, (pts_semantic_mask[ind][0]) + ]).to(device) + point_boundary_offset_z[sel_global] = center - x[sel] + + # Get left two lines + v1 = corners[3] - corners[2] + v2 = corners[2] - corners[0] + cp = torch.cross(v1, v2) + d = -torch.dot(cp, corners[0]) + a, b, c = cp + plane_left_temp = torch.as_tensor([a, b, c, d]).to(device) + para_points = corners[[4, 5, 6, 7]] + # Normalize xy here + plane_left_temp /= torch.norm(plane_left_temp[:3]) + newd = torch.sum(para_points * plane_left_temp[:3], 1) + if plane_left_temp[2] < self.train_cfg['lower_thresh']: + plane_left = plane_left_temp + plane_right = torch.as_tensor([ + plane_left_temp[0], plane_left_temp[1], plane_left_temp[2], + -torch.mean(newd) + ]).to(device) + else: + import pdb + pdb.set_trace() + print('error with upright') + + # Get the boundary points here + alldist = torch.abs( + torch.sum(x * plane_left[:3], 1) + plane_left[-1]) + mind = alldist.min() + sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + + # Get upper four lines + line_sel1, line_sel2 = self.get_linesel2( + x[sel], ymin, ymax, zmin, zmax, axis=1) + if self.primitive_mode == 'line': + if torch.sum(line_sel1) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel1]] = 1.0 + linecenter = torch.mean(x[sel][line_sel1], axis=0) + linecenter[2] = (zmin + zmax) / 2.0 + point_line_offset[ + ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel2) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel2]] = 1.0 + linecenter = torch.mean(x[sel][line_sel2], axis=0) + linecenter[2] = (zmin + zmax) / 2.0 + point_line_offset[ + ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + + if self.primitive_mode == 'xy': + if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( + alldist[sel]) < self.train_cfg['var_thresh']: + center = torch.as_tensor([ + torch.mean(x[sel][:, 0]), + torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + ]).to(device) + sel_global = ind[sel] + point_boundary_mask_xy[sel_global] = 1.0 + point_boundary_sem_xy[sel_global] = torch.as_tensor([ + center[0], center[1], center[2], zmax - zmin, + (pts_semantic_mask[ind][0]) + ]).to(device) + point_boundary_offset_xy[sel_global] = center - x[sel] + + # Get the boundary points here + alldist = torch.abs( + torch.sum(x * plane_right[:3], 1) + plane_right[-1]) + mind = alldist.min() + sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + line_sel1, line_sel2 = self.get_linesel2( + x[sel], ymin, ymax, zmin, zmax, axis=1) + if self.primitive_mode == 'line': + if torch.sum(line_sel1) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel1]] = 1.0 + linecenter = torch.mean(x[sel][line_sel1], axis=0) + linecenter[2] = (zmin + zmax) / 2.0 + point_line_offset[ + ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + if torch.sum(line_sel2) > self.train_cfg['num_point_line']: + point_line_mask[ind[sel][line_sel2]] = 1.0 + linecenter = torch.mean(x[sel][line_sel2], axis=0) + linecenter[2] = (zmin + zmax) / 2.0 + point_line_offset[ + ind[sel][line_sel2]] = linecenter - 
x[sel][line_sel2] + point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + linecenter[0], linecenter[1], linecenter[2], + (pts_semantic_mask[ind][0]) + ]).to(device) + + if self.primitive_mode == 'xy': + if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( + alldist[sel]) < self.train_cfg['var_thresh']: + center = torch.as_tensor([ + torch.mean(x[sel][:, 0]), + torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + ]).to(device) + sel_global = ind[sel] + point_boundary_mask_xy[sel_global] = 1.0 + point_boundary_sem_xy[sel_global] = torch.as_tensor([ + center[0], center[1], center[2], zmax - zmin, + (pts_semantic_mask[ind][0]) + ]).to(device) + point_boundary_offset_xy[sel_global] = center - x[sel] + + # Get the boundary points here + v1 = corners[0] - corners[4] + v2 = corners[4] - corners[5] + cp = torch.cross(v1, v2) + d = -torch.dot(cp, corners[5]) + a, b, c = cp + plane_front_temp = torch.as_tensor([a, b, c, d]).to(device) + para_points = corners[[2, 3, 6, 7]] + plane_front_temp /= torch.norm(plane_front_temp[:3]) + newd = torch.sum(para_points * plane_front_temp[:3], 1) + if plane_front_temp[2] < self.train_cfg['lower_thresh']: + plane_front = plane_front_temp + plane_back = torch.as_tensor([ + plane_front_temp[0], plane_front_temp[1], + plane_front_temp[2], -torch.mean(newd) + ]).to(device) + else: + import pdb + pdb.set_trace() + print('error with upright') + + # Get the boundary points here + alldist = torch.abs( + torch.sum(x * plane_front[:3], 1) + plane_front[-1]) + mind = alldist.min() + sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + if self.primitive_mode == 'xy': + if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( + alldist[sel]) < self.train_cfg['var_thresh']: + center = torch.as_tensor([ + torch.mean(x[sel][:, 0]), + torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + ]).to(device) + sel_global = ind[sel] + point_boundary_mask_xy[sel_global] = 1.0 + point_boundary_sem_xy[sel_global] = torch.as_tensor([ + center[0], center[1], center[2], zmax - zmin, + (pts_semantic_mask[ind][0]) + ]).to(device) + point_boundary_offset_xy[sel_global] = center - x[sel] + # Get the boundary points here + alldist = torch.abs( + torch.sum(x * plane_back[:3], 1) + plane_back[-1]) + mind = alldist.min() + sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + if self.primitive_mode == 'xy': + if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( + alldist[sel]) < self.train_cfg['var_thresh']: + center = torch.as_tensor([ + torch.mean(x[sel][:, 0]), + torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + ]).to(device) + sel_global = ind[sel] + point_boundary_mask_xy[sel_global] = 1.0 + point_boundary_sem_xy[sel_global] = torch.as_tensor([ + center[0], center[1], center[2], zmax - zmin, + (pts_semantic_mask[ind][0]) + ]).to(device) + point_boundary_offset_xy[sel_global] = center - x[sel] + + if self.primitive_mode == 'z': + return (point_boundary_mask_z, point_boundary_sem_z, + point_boundary_offset_z) + elif self.primitive_mode == 'xy': + return (point_boundary_mask_xy, point_boundary_sem_xy, + point_boundary_offset_xy) + elif self.primitive_mode == 'line': + return (point_line_mask, point_line_sem, point_line_offset) + else: + NotImplementedError + + def primitive_decode_scores(self, net, end_points, num_class, mode=''): + net_transposed = net.transpose(2, 1) # (batch_size, 1024, ..) 
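+        # Channel layout of the prediction tensor, matching
+        # conv_out_channel = 3 + num_dim + num_classes in __init__:
+        #   [0:3]            -> offsets added to the aggregated points (centers)
+        #   [3:3 + num_dim]  -> size residuals (only for the 'z' and 'xy' modes)
+        #   [3 + num_dim:]   -> per-class semantic scores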
+ + base_xyz = end_points['aggregated_points_' + + mode] # (batch_size, num_proposal, 3) + center = base_xyz + net_transposed[:, :, 0: + 3] # (batch_size, num_proposal, 3) + end_points['center_' + mode] = center + + if mode in ['z', 'xy']: + end_points['size_residuals_' + mode] = net_transposed[:, :, 3:3 + + self.num_dim] + + end_points['sem_cls_scores_' + mode] = net_transposed[:, :, 3 + + self.num_dim:] + + return center, end_points + + def check_upright(self, para_points): + return (para_points[0][-1] == para_points[1][-1]) and ( + para_points[1][-1] + == para_points[2][-1]) and (para_points[2][-1] + == para_points[3][-1]) + + def check_z(self, plane_equ, para_points): + return torch.sum(para_points[:, 2] + + plane_equ[-1]) / 4.0 < self.train_cfg['lower_thresh'] + + def get_linesel(self, points, xmin, xmax, ymin, ymax): + sel1 = torch.abs(points[:, 0] - xmin) < self.train_cfg['line_thresh'] + sel2 = torch.abs(points[:, 0] - xmax) < self.train_cfg['line_thresh'] + sel3 = torch.abs(points[:, 1] - ymin) < self.train_cfg['line_thresh'] + sel4 = torch.abs(points[:, 1] - ymax) < self.train_cfg['line_thresh'] + return sel1, sel2, sel3, sel4 + + def get_linesel2(self, points, ymin, ymax, zmin, zmax, axis=0): + sel3 = torch.abs(points[:, axis] - + ymin) < self.train_cfg['line_thresh'] + sel4 = torch.abs(points[:, axis] - + ymax) < self.train_cfg['line_thresh'] + return sel3, sel4 + + def compute_flag_loss(self, end_points, point_mask, mode): + # Compute existence flag for face and edge centers + # Load ground truth votes and assign them to seed points + seed_inds = end_points['seed_indices'].long() + + seed_gt_votes_mask = torch.gather(point_mask, 1, seed_inds).float() + end_points['sem_mask'] = seed_gt_votes_mask + + sem_cls_label = torch.gather(point_mask, 1, seed_inds) + end_points['sub_point_sem_cls_label_' + mode] = sem_cls_label + + pred_flag = end_points['pred_flag_' + mode] + + sem_loss = self.objectness_loss(pred_flag, sem_cls_label.long()) + + return sem_loss + + def compute_primitivesem_loss(self, + end_points, + point_mask, + point_offset, + point_sem, + mode=''): + """Compute final geometric primitive center and semantic.""" + # Load ground truth votes and assign them to seed points + batch_size = end_points['seed_points'].shape[0] + num_seed = end_points['seed_points'].shape[1] # B,num_seed,3 + vote_xyz = end_points['center_' + mode] # B,num_seed*vote_factor,3 + seed_inds = end_points['seed_indices'].long() + + num_proposal = end_points['aggregated_points_' + + mode].shape[1] # B,num_seed,3 + + seed_gt_votes_mask = torch.gather(point_mask, 1, seed_inds) + seed_inds_expand = seed_inds.view(batch_size, num_seed, + 1).repeat(1, 1, 3) + + seed_inds_expand_sem = seed_inds.view(batch_size, num_seed, 1).repeat( + 1, 1, 4 + self.num_dim) + + seed_gt_votes = torch.gather(point_offset, 1, seed_inds_expand) + seed_gt_sem = torch.gather(point_sem, 1, seed_inds_expand_sem) + seed_gt_votes += end_points['seed_points'] + + end_points['surface_center_gt_' + mode] = seed_gt_votes + end_points['surface_sem_gt_' + mode] = seed_gt_sem + end_points['surface_mask_gt_' + mode] = seed_gt_votes_mask + + # Compute the min of min of distance + vote_xyz_reshape = vote_xyz.view(batch_size * num_proposal, -1, 3) + seed_gt_votes_reshape = seed_gt_votes.view(batch_size * num_proposal, + 1, 3) + # A predicted vote to no where is not penalized as long as there is a + # good vote near the GT vote. 
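+        # Only the distance from each ground-truth primitive center to its
+        # nearest prediction is kept: `dst_weight` carries the per-seed GT
+        # mask and `[1]` selects that term of the (Chamfer-style) center
+        # loss, which is then normalized by the number of positive seeds.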
+ center_loss = self.center_loss( + vote_xyz_reshape, + seed_gt_votes_reshape, + dst_weight=seed_gt_votes_mask.view(batch_size * num_proposal, + 1))[1] + center_loss = center_loss.sum() / ( + torch.sum(seed_gt_votes_mask.float()) + 1e-6) + + # Compute the min of min of distance + # Need to remove this soon + if mode != 'line': + size_xyz = end_points[ + 'size_residuals_' + + mode].contiguous() # B,num_seed*vote_factor,3 + size_xyz_reshape = size_xyz.view(batch_size * num_proposal, -1, + self.num_dim).contiguous() + seed_gt_votes_reshape = seed_gt_sem[:, :, 3:3 + self.num_dim].view( + batch_size * num_proposal, 1, self.num_dim).contiguous() + # A predicted vote to no where is not penalized as long as + # there is a good vote near the GT vote. + size_loss = self.center_loss( + size_xyz_reshape, + seed_gt_votes_reshape, + dst_weight=seed_gt_votes_mask.view(batch_size * num_proposal, + 1))[1] + size_loss = size_loss.sum() / ( + torch.sum(seed_gt_votes_mask.float()) + 1e-6) + else: + size_loss = torch.tensor(0) + + # 3.4 Semantic cls loss + sem_cls_label = seed_gt_sem[:, :, -1].long() + end_points['supp_sem_' + mode] = sem_cls_label + sem_cls_loss = self.semantic_loss( + end_points['sem_cls_scores_' + mode].transpose(2, 1), + sem_cls_label) + sem_cls_loss = torch.sum(sem_cls_loss * seed_gt_votes_mask.float()) / ( + torch.sum(seed_gt_votes_mask.float()) + 1e-6) + + return center_loss, size_loss, sem_cls_loss diff --git a/tests/test_heads.py b/tests/test_heads.py index 865db50ea5..ca060a853e 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -457,3 +457,104 @@ def test_free_anchor_3D_head(): gt_labels, input_metas, None) assert losses['positive_bag_loss'] >= 0 assert losses['negative_bag_loss'] >= 0 + + +def test_primitive_head(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + _setup_seed(0) + + primitive_head_cfg = dict( + type='PrimitiveHead', + num_dim=2, + num_classes=18, + primitive_mode='z', + vote_moudule_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=1, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + num_point=64, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True), + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.4, 0.6], + reduction='mean', + loss_weight=1.0), + center_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_src_weight=1.0, + loss_dst_weight=1.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='none', loss_weight=1.0), + train_cfg=dict( + dist_thresh=0.2, + var_thresh=1e-2, + lower_thresh=1e-6, + num_point=100, + num_point_line=10, + line_thresh=0.2)) + self = build_head(primitive_head_cfg).cuda() + fp_xyz = [torch.rand([2, 64, 3], dtype=torch.float32).cuda()] + hd_features = torch.rand([2, 256, 64], dtype=torch.float32).cuda() + fp_indices = [torch.randint(0, 64, [2, 64]).cuda()] + input_dict = dict( + fp_xyz_net0=fp_xyz, hd_feature=hd_features, fp_indices_net0=fp_indices) + + # test forward + ret_dict = self(input_dict, 'vote') + assert ret_dict['center_z'].shape == torch.Size([2, 64, 3]) + assert ret_dict['size_residuals_z'].shape == torch.Size([2, 64, 2]) + assert ret_dict['sem_cls_scores_z'].shape == torch.Size([2, 
64, 18]) + assert ret_dict['aggregated_points_z'].shape == torch.Size([2, 64, 3]) + + # test loss + points = torch.rand([2, 1024, 3], dtype=torch.float32).cuda() + ret_dict['seed_points'] = fp_xyz[0] + ret_dict['seed_indices'] = fp_indices[0] + + from mmdet3d.core.bbox import DepthInstance3DBoxes + gt_bboxes_3d = [ + DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda()), + DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda()) + ] + gt_labels_3d = torch.randint(0, 18, [2, 4]).cuda() + gt_labels_3d = [gt_labels_3d[0], gt_labels_3d[1]] + pts_semantic_mask = torch.randint(0, 19, [2, 1024]).cuda() + pts_semantic_mask = [pts_semantic_mask[0], pts_semantic_mask[1]] + pts_instance_mask = torch.randint(0, 4, [2, 1024]).cuda() + pts_instance_mask = [pts_instance_mask[0], pts_instance_mask[1]] + + loss_input_dict = dict( + bbox_preds=ret_dict, + points=points, + gt_bboxes_3d=gt_bboxes_3d, + gt_labels_3d=gt_labels_3d, + pts_semantic_mask=pts_semantic_mask, + pts_instance_mask=pts_instance_mask) + losses_dict = self.loss(**loss_input_dict) + + assert losses_dict['flag_loss_z'] >= 0 + assert losses_dict['vote_loss_z'] >= 0 + assert losses_dict['center_loss_z'] >= 0 + assert losses_dict['size_loss_z'] >= 0 + assert losses_dict['sem_loss_z'] >= 0 + assert losses_dict['surface_loss_z'] >= 0 From f0dfd7369e4503eeaa4b700901fec97c2e1aa92b Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 5 Aug 2020 19:22:42 +0800 Subject: [PATCH 02/10] register of primitive head --- mmdet3d/models/dense_heads/__init__.py | 6 +++++- mmdet3d/models/model_utils/vote_module.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index 5e06a97509..67ca03266c 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -1,6 +1,10 @@ from .anchor3d_head import Anchor3DHead from .free_anchor3d_head import FreeAnchor3DHead from .parta2_rpn_head import PartA2RPNHead +from .primitive_head import PrimitiveHead from .vote_head import VoteHead -__all__ = ['Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead'] +__all__ = [ + 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'PrimitiveHead', + 'VoteHead' +] diff --git a/mmdet3d/models/model_utils/vote_module.py b/mmdet3d/models/model_utils/vote_module.py index 790fa5f6b6..bfea4c8448 100644 --- a/mmdet3d/models/model_utils/vote_module.py +++ b/mmdet3d/models/model_utils/vote_module.py @@ -126,7 +126,7 @@ def get_loss(self, seed_points, vote_points, seed_indices, seed_indices_expand = seed_indices.unsqueeze(-1).repeat( 1, 1, 3 * self.gt_per_seed) seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand) - seed_gt_votes += seed_points.repeat(1, 1, 3) + seed_gt_votes += seed_points.repeat(1, 1, self.gt_per_seed) weight = seed_gt_votes_mask / (torch.sum(seed_gt_votes_mask) + 1e-6) distance = self.vote_loss( From 18a6a0b00029e388e1e8c91667c6acd285a8fbca Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Mon, 10 Aug 2020 12:48:15 +0800 Subject: [PATCH 03/10] modify primitive head --- mmdet3d/models/dense_heads/primitive_head.py | 254 +++++++++---------- 1 file changed, 120 insertions(+), 134 deletions(-) diff --git a/mmdet3d/models/dense_heads/primitive_head.py b/mmdet3d/models/dense_heads/primitive_head.py index 801a0de2ad..0580886330 100644 --- a/mmdet3d/models/dense_heads/primitive_head.py +++ b/mmdet3d/models/dense_heads/primitive_head.py @@ -11,7 +11,7 @@ @HEADS.register_module() class 
PrimitiveHead(nn.Module): - r"""Bbox head of `Votenet `_. + r"""Bbox head of `H3dnet `_. Args: num_dim (int): The dimension of primitive. @@ -64,15 +64,16 @@ def __init__(self, # Existence flag prediction self.flag_conv = ConvModule( - 256, - 128, + vote_moudule_cfg['conv_channels'][-1], + vote_moudule_cfg['conv_channels'][-1] // 2, 1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, bias=True, inplace=True) - self.flag_pred = torch.nn.Conv1d(128, 2, 1) + self.flag_pred = torch.nn.Conv1d( + vote_moudule_cfg['conv_channels'][-1] // 2, 2, 1) self.vote_module = VoteModule(**vote_moudule_cfg) self.vote_aggregation = PointSAModule(**vote_aggregation_cfg) @@ -126,8 +127,6 @@ def forward(self, feat_dict, sample_mod): seed_features = feat_dict['hd_feature'] results = {} - # net_flag = F.relu(self.bn_flag1(self.conv_flag1(seed_features))) - # net_flag = self.conv_flag2(net_flag) net_flag = self.flag_conv(seed_features) net_flag = self.flag_pred(net_flag) @@ -210,7 +209,7 @@ def loss(self, losses = {} flag_loss = self.compute_flag_loss( - bbox_preds, point_mask, mode=self.primitive_mode) * 30 + bbox_preds, point_mask, mode=self.primitive_mode) losses['flag_loss_' + self.primitive_mode] = flag_loss # calculate vote loss @@ -229,8 +228,7 @@ def loss(self, losses['center_loss_' + self.primitive_mode] = center_loss losses['size_loss_' + self.primitive_mode] = size_loss losses['sem_loss_' + self.primitive_mode] = sem_cls_loss - losses['surface_loss_' + self.primitive_mode] = center_loss * 0.5 + \ - size_loss * 0.5 + sem_cls_loss + return losses def get_targets(self, @@ -309,22 +307,21 @@ def get_targets_single(self, """ gt_bboxes_3d = gt_bboxes_3d.to(points.device) num_points = points.shape[0] - device = points.device if self.primitive_mode == 'z': - point_boundary_mask_z = torch.zeros(num_points).to(device) - point_boundary_offset_z = torch.zeros([num_points, 3]).to(device) - point_boundary_sem_z = torch.zeros( - [num_points, 3 + self.num_dim + 1]).to(device) + point_boundary_mask_z = points.new_zeros(num_points) + point_boundary_offset_z = points.new_zeros([num_points, 3]) + point_boundary_sem_z = points.new_zeros( + [num_points, 3 + self.num_dim + 1]) elif self.primitive_mode == 'xy': - point_boundary_mask_xy = torch.zeros(num_points).to(device) - point_boundary_offset_xy = torch.zeros([num_points, 3]).to(device) - point_boundary_sem_xy = torch.zeros( - [num_points, 3 + self.num_dim + 1]).to(device) + point_boundary_mask_xy = points.new_zeros(num_points) + point_boundary_offset_xy = points.new_zeros([num_points, 3]) + point_boundary_sem_xy = points.new_zeros( + [num_points, 3 + self.num_dim + 1]) elif self.primitive_mode == 'line': - point_line_mask = torch.zeros(num_points).to(device) - point_line_offset = torch.zeros([num_points, 3]).to(device) - point_line_sem = torch.zeros([num_points, 3 + 1]).to(device) + point_line_mask = points.new_zeros(num_points) + point_line_offset = points.new_zeros([num_points, 3]) + point_line_sem = points.new_zeros([num_points, 3 + 1]) else: NotImplementedError @@ -342,24 +339,21 @@ def get_targets_single(self, xmax, ymax, zmax = corners.max(0)[0] # Get lower four lines - plane_lower_temp = torch.as_tensor([0, 0, 1, - -corners[6, -1]]).to(device) + plane_lower_temp = points.new_tensor([0, 0, 1, -corners[6, -1]]) para_points = corners[[1, 3, 5, 7]] newd = torch.sum(para_points * plane_lower_temp[:3], 1) if self.check_upright(para_points) and \ plane_lower_temp[0] + plane_lower_temp[1] < \ self.train_cfg['lower_thresh']: - plane_lower = torch.as_tensor([0, 0, 1, 
plane_lower_temp[-1] - ]).to(device) - plane_upper = torch.as_tensor([0, 0, 1, - -torch.mean(newd)]).to(device) + plane_lower = points.new_tensor( + [0, 0, 1, plane_lower_temp[-1]]) + plane_upper = points.new_tensor([0, 0, 1, -torch.mean(newd)]) else: - import pdb - pdb.set_trace() - print('error with upright') + raise NotImplementedError + # print('error with upright') + if self.check_z(plane_upper, para_points) is False: - import pdb - pdb.set_trace() + raise NotImplementedError # Get the boundary points here alldist = torch.abs( @@ -368,7 +362,7 @@ def get_targets_single(self, sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] # Get lower four lines - line_sel1, line_sel2, line_sel3, line_sel4 = self.get_linesel( + line_sel1, line_sel2, line_sel3, line_sel4 = self.match_point2line( x[sel], xmin, xmax, ymin, ymax) if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: @@ -377,55 +371,54 @@ def get_targets_single(self, linecenter[1] = (ymin + ymax) / 2.0 point_line_offset[ ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 linecenter = torch.mean(x[sel][line_sel2], axis=0) linecenter[1] = (ymin + ymax) / 2.0 point_line_offset[ ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel3) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel3]] = 1.0 linecenter = torch.mean(x[sel][line_sel3], axis=0) linecenter[0] = (xmin + xmax) / 2.0 point_line_offset[ ind[sel][line_sel3]] = linecenter - x[sel][line_sel3] - point_line_sem[ind[sel][line_sel3]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel3]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], pts_semantic_mask[ind][0] - ]).to(device) + ]) if torch.sum(line_sel4) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel4]] = 1.0 linecenter = torch.mean(x[sel][line_sel4], axis=0) linecenter[0] = (xmin + xmax) / 2.0 point_line_offset[ ind[sel][line_sel4]] = linecenter - x[sel][line_sel4] - point_line_sem[ind[sel][line_sel4]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel4]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) # Set the surface labels here if self.primitive_mode == 'z': if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( alldist[sel]) < self.train_cfg['var_thresh']: - center = torch.as_tensor([ - (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, - torch.mean(x[sel][:, 2]) - ]).to(device) + center = points.new_tensor([(xmin + xmax) / 2.0, + (ymin + ymax) / 2.0, + torch.mean(x[sel][:, 2])]) sel_global = ind[sel] point_boundary_mask_z[sel_global] = 1.0 - point_boundary_sem_z[sel_global] = torch.as_tensor([ + point_boundary_sem_z[sel_global] = points.new_tensor([ center[0], center[1], center[2], xmax - xmin, ymax - ymin, (pts_semantic_mask[ind][0]) - ]).to(device) + ]) point_boundary_offset_z[sel_global] = center - x[sel] # Get the boundary points here @@ -435,7 +428,7 @@ def get_targets_single(self, sel = 
torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] # Get upper four lines - line_sel1, line_sel2, line_sel3, line_sel4 = self.get_linesel( + line_sel1, line_sel2, line_sel3, line_sel4 = self.match_point2line( x[sel], xmin, xmax, ymin, ymax) if self.primitive_mode == 'line': @@ -445,77 +438,77 @@ def get_targets_single(self, linecenter[1] = (ymin + ymax) / 2.0 point_line_offset[ ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 linecenter = torch.mean(x[sel][line_sel2], axis=0) linecenter[1] = (ymin + ymax) / 2.0 point_line_offset[ ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel3) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel3]] = 1.0 linecenter = torch.mean(x[sel][line_sel3], axis=0) linecenter[0] = (xmin + xmax) / 2.0 point_line_offset[ ind[sel][line_sel3]] = linecenter - x[sel][line_sel3] - point_line_sem[ind[sel][line_sel3]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel3]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel4) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel4]] = 1.0 linecenter = torch.mean(x[sel][line_sel4], axis=0) linecenter[0] = (xmin + xmax) / 2.0 point_line_offset[ ind[sel][line_sel4]] = linecenter - x[sel][line_sel4] - point_line_sem[ind[sel][line_sel4]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel4]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if self.primitive_mode == 'z': if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( alldist[sel]) < self.train_cfg['var_thresh']: - center = torch.as_tensor([ - (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, - torch.mean(x[sel][:, 2]) - ]).to(device) + center = points.new_tensor([(xmin + xmax) / 2.0, + (ymin + ymax) / 2.0, + torch.mean(x[sel][:, 2])]) sel_global = ind[sel] point_boundary_mask_z[sel_global] = 1.0 - point_boundary_sem_z[sel_global] = torch.as_tensor([ + point_boundary_sem_z[sel_global] = points.new_tensor([ center[0], center[1], center[2], xmax - xmin, ymax - ymin, (pts_semantic_mask[ind][0]) - ]).to(device) + ]) point_boundary_offset_z[sel_global] = center - x[sel] # Get left two lines - v1 = corners[3] - corners[2] - v2 = corners[2] - corners[0] - cp = torch.cross(v1, v2) - d = -torch.dot(cp, corners[0]) - a, b, c = cp - plane_left_temp = torch.as_tensor([a, b, c, d]).to(device) + vec1 = corners[3] - corners[2] + vec2 = corners[2] - corners[0] + surface_norm = torch.cross(vec1, vec2) + surface_dis = -torch.dot(surface_norm, corners[0]) + plane_left_temp = points.new_tensor([ + surface_norm[0], surface_norm[1], surface_norm[2], surface_dis + ]) + para_points = corners[[4, 5, 6, 7]] # Normalize xy here plane_left_temp /= torch.norm(plane_left_temp[:3]) newd = torch.sum(para_points * plane_left_temp[:3], 1) if plane_left_temp[2] < self.train_cfg['lower_thresh']: plane_left = plane_left_temp - plane_right = torch.as_tensor([ + 
plane_right = points.new_tensor([ plane_left_temp[0], plane_left_temp[1], plane_left_temp[2], -torch.mean(newd) - ]).to(device) + ]) else: - import pdb - pdb.set_trace() - print('error with upright') + raise NotImplementedError + # print('error with upright') # Get the boundary points here alldist = torch.abs( @@ -524,8 +517,9 @@ def get_targets_single(self, sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] # Get upper four lines - line_sel1, line_sel2 = self.get_linesel2( - x[sel], ymin, ymax, zmin, zmax, axis=1) + _, _, line_sel1, line_sel2 = self.match_point2line( + x[sel], xmin, xmax, ymin, ymax) + if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel1]] = 1.0 @@ -533,34 +527,34 @@ def get_targets_single(self, linecenter[2] = (zmin + zmax) / 2.0 point_line_offset[ ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 linecenter = torch.mean(x[sel][line_sel2], axis=0) linecenter[2] = (zmin + zmax) / 2.0 point_line_offset[ ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if self.primitive_mode == 'xy': if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( alldist[sel]) < self.train_cfg['var_thresh']: - center = torch.as_tensor([ + center = points.new_tensor([ torch.mean(x[sel][:, 0]), torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 - ]).to(device) + ]) sel_global = ind[sel] point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = torch.as_tensor([ + point_boundary_sem_xy[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) - ]).to(device) + ]) point_boundary_offset_xy[sel_global] = center - x[sel] # Get the boundary points here @@ -568,8 +562,9 @@ def get_targets_single(self, torch.sum(x * plane_right[:3], 1) + plane_right[-1]) mind = alldist.min() sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] - line_sel1, line_sel2 = self.get_linesel2( - x[sel], ymin, ymax, zmin, zmax, axis=1) + _, _, line_sel1, line_sel2 = self.match_point2line( + x[sel], xmin, xmax, ymin, ymax) + if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel1]] = 1.0 @@ -577,56 +572,57 @@ def get_targets_single(self, linecenter[2] = (zmin + zmax) / 2.0 point_line_offset[ ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ linecenter[0], linecenter[1], linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 linecenter = torch.mean(x[sel][line_sel2], axis=0) linecenter[2] = (zmin + zmax) / 2.0 point_line_offset[ ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = torch.as_tensor([ + point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ linecenter[0], linecenter[1], 
linecenter[2], (pts_semantic_mask[ind][0]) - ]).to(device) + ]) if self.primitive_mode == 'xy': if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( alldist[sel]) < self.train_cfg['var_thresh']: - center = torch.as_tensor([ + center = points.new_tensor([ torch.mean(x[sel][:, 0]), torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 - ]).to(device) + ]) sel_global = ind[sel] point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = torch.as_tensor([ + point_boundary_sem_xy[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) - ]).to(device) + ]) point_boundary_offset_xy[sel_global] = center - x[sel] # Get the boundary points here - v1 = corners[0] - corners[4] - v2 = corners[4] - corners[5] - cp = torch.cross(v1, v2) - d = -torch.dot(cp, corners[5]) - a, b, c = cp - plane_front_temp = torch.as_tensor([a, b, c, d]).to(device) + vec1 = corners[0] - corners[4] + vec2 = corners[4] - corners[5] + surface_norm = torch.cross(vec1, vec2) + surface_dis = -torch.dot(surface_norm, corners[5]) + plane_front_temp = points.new_tensor([ + surface_norm[0], surface_norm[1], surface_norm[2], surface_dis + ]) + para_points = corners[[2, 3, 6, 7]] plane_front_temp /= torch.norm(plane_front_temp[:3]) newd = torch.sum(para_points * plane_front_temp[:3], 1) if plane_front_temp[2] < self.train_cfg['lower_thresh']: plane_front = plane_front_temp - plane_back = torch.as_tensor([ + plane_back = points.new_tensor([ plane_front_temp[0], plane_front_temp[1], plane_front_temp[2], -torch.mean(newd) - ]).to(device) + ]) else: - import pdb - pdb.set_trace() - print('error with upright') + raise NotImplementedError + # print('error with upright') # Get the boundary points here alldist = torch.abs( @@ -636,16 +632,16 @@ def get_targets_single(self, if self.primitive_mode == 'xy': if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( alldist[sel]) < self.train_cfg['var_thresh']: - center = torch.as_tensor([ + center = points.new_tensor([ torch.mean(x[sel][:, 0]), torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 - ]).to(device) + ]) sel_global = ind[sel] point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = torch.as_tensor([ + point_boundary_sem_xy[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) - ]).to(device) + ]) point_boundary_offset_xy[sel_global] = center - x[sel] # Get the boundary points here alldist = torch.abs( @@ -655,16 +651,16 @@ def get_targets_single(self, if self.primitive_mode == 'xy': if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( alldist[sel]) < self.train_cfg['var_thresh']: - center = torch.as_tensor([ + center = points.new_tensor([ torch.mean(x[sel][:, 0]), torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 - ]).to(device) + ]) sel_global = ind[sel] point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = torch.as_tensor([ + point_boundary_sem_xy[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) - ]).to(device) + ]) point_boundary_offset_xy[sel_global] = center - x[sel] if self.primitive_mode == 'z': @@ -679,12 +675,12 @@ def get_targets_single(self, NotImplementedError def primitive_decode_scores(self, net, end_points, num_class, mode=''): - net_transposed = net.transpose(2, 1) # (batch_size, 1024, ..) 
- - base_xyz = end_points['aggregated_points_' + - mode] # (batch_size, num_proposal, 3) - center = base_xyz + net_transposed[:, :, 0: - 3] # (batch_size, num_proposal, 3) + # (batch_size, 1024, ..) + net_transposed = net.transpose(2, 1) + # (batch_size, num_proposal, 3) + base_xyz = end_points['aggregated_points_' + mode] + # (batch_size, num_proposal, 3) + center = base_xyz + net_transposed[:, :, 0:3] end_points['center_' + mode] = center if mode in ['z', 'xy']: @@ -706,20 +702,13 @@ def check_z(self, plane_equ, para_points): return torch.sum(para_points[:, 2] + plane_equ[-1]) / 4.0 < self.train_cfg['lower_thresh'] - def get_linesel(self, points, xmin, xmax, ymin, ymax): + def match_point2line(self, points, xmin, xmax, ymin, ymax): sel1 = torch.abs(points[:, 0] - xmin) < self.train_cfg['line_thresh'] sel2 = torch.abs(points[:, 0] - xmax) < self.train_cfg['line_thresh'] sel3 = torch.abs(points[:, 1] - ymin) < self.train_cfg['line_thresh'] sel4 = torch.abs(points[:, 1] - ymax) < self.train_cfg['line_thresh'] return sel1, sel2, sel3, sel4 - def get_linesel2(self, points, ymin, ymax, zmin, zmax, axis=0): - sel3 = torch.abs(points[:, axis] - - ymin) < self.train_cfg['line_thresh'] - sel4 = torch.abs(points[:, axis] - - ymax) < self.train_cfg['line_thresh'] - return sel3, sel4 - def compute_flag_loss(self, end_points, point_mask, mode): # Compute existence flag for face and edge centers # Load ground truth votes and assign them to seed points @@ -746,12 +735,11 @@ def compute_primitivesem_loss(self, """Compute final geometric primitive center and semantic.""" # Load ground truth votes and assign them to seed points batch_size = end_points['seed_points'].shape[0] - num_seed = end_points['seed_points'].shape[1] # B,num_seed,3 - vote_xyz = end_points['center_' + mode] # B,num_seed*vote_factor,3 + num_seed = end_points['seed_points'].shape[1] + vote_xyz = end_points['center_' + mode] seed_inds = end_points['seed_indices'].long() - num_proposal = end_points['aggregated_points_' + - mode].shape[1] # B,num_seed,3 + num_proposal = end_points['aggregated_points_' + mode].shape[1] seed_gt_votes_mask = torch.gather(point_mask, 1, seed_inds) seed_inds_expand = seed_inds.view(batch_size, num_seed, @@ -785,9 +773,7 @@ def compute_primitivesem_loss(self, # Compute the min of min of distance # Need to remove this soon if mode != 'line': - size_xyz = end_points[ - 'size_residuals_' + - mode].contiguous() # B,num_seed*vote_factor,3 + size_xyz = end_points['size_residuals_' + mode].contiguous() size_xyz_reshape = size_xyz.view(batch_size * num_proposal, -1, self.num_dim).contiguous() seed_gt_votes_reshape = seed_gt_sem[:, :, 3:3 + self.num_dim].view( @@ -802,7 +788,7 @@ def compute_primitivesem_loss(self, size_loss = size_loss.sum() / ( torch.sum(seed_gt_votes_mask.float()) + 1e-6) else: - size_loss = torch.tensor(0) + size_loss = torch.tensor(0).float().to(center_loss.device) # 3.4 Semantic cls loss sem_cls_label = seed_gt_sem[:, :, -1].long() From c1b047050e055c988fe77ce4b857e1a2a476e0e7 Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Mon, 17 Aug 2020 11:42:45 +0800 Subject: [PATCH 04/10] modify primitive head --- mmdet3d/models/dense_heads/__init__.py | 6 +- .../models/roi_heads/mask_heads/__init__.py | 3 +- .../mask_heads}/primitive_head.py | 349 +++++++++--------- tests/test_heads.py | 1 - 4 files changed, 187 insertions(+), 172 deletions(-) rename mmdet3d/models/{dense_heads => roi_heads/mask_heads}/primitive_head.py (74%) diff --git a/mmdet3d/models/dense_heads/__init__.py 
b/mmdet3d/models/dense_heads/__init__.py index 67ca03266c..5e06a97509 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -1,10 +1,6 @@ from .anchor3d_head import Anchor3DHead from .free_anchor3d_head import FreeAnchor3DHead from .parta2_rpn_head import PartA2RPNHead -from .primitive_head import PrimitiveHead from .vote_head import VoteHead -__all__ = [ - 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'PrimitiveHead', - 'VoteHead' -] +__all__ = ['Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead'] diff --git a/mmdet3d/models/roi_heads/mask_heads/__init__.py b/mmdet3d/models/roi_heads/mask_heads/__init__.py index 532bbfaae9..ecc8a118a5 100644 --- a/mmdet3d/models/roi_heads/mask_heads/__init__.py +++ b/mmdet3d/models/roi_heads/mask_heads/__init__.py @@ -1,3 +1,4 @@ from .pointwise_semantic_head import PointwiseSemanticHead +from .primitive_head import PrimitiveHead -__all__ = ['PointwiseSemanticHead'] +__all__ = ['PointwiseSemanticHead', 'PrimitiveHead'] diff --git a/mmdet3d/models/dense_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py similarity index 74% rename from mmdet3d/models/dense_heads/primitive_head.py rename to mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 0580886330..1966ca5498 100644 --- a/mmdet3d/models/dense_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -14,9 +14,10 @@ class PrimitiveHead(nn.Module): r"""Bbox head of `H3dnet `_. Args: - num_dim (int): The dimension of primitive. + num_dim (int): The dimension of primitive semantic information. num_classes (int): The number of class. - primitive_mode (str): The mode of primitive. + primitive_mode (str): The mode of primitive module, + avaliable mode ['z', 'xy', 'line']. bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and decoding boxes. train_cfg (dict): Config for training. @@ -47,6 +48,8 @@ def __init__(self, center_loss=None, semantic_loss=None): super(PrimitiveHead, self).__init__() + assert primitive_mode in ['z', 'xy', 'line'] + # The dimension of primitive semantic information. self.num_dim = num_dim self.num_classes = num_classes self.primitive_mode = primitive_mode @@ -62,7 +65,7 @@ def __init__(self, assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[ 'in_channels'] - # Existence flag prediction + # Primitive existence flag prediction self.flag_conv = ConvModule( vote_moudule_cfg['conv_channels'][-1], vote_moudule_cfg['conv_channels'][-1] // 2, @@ -105,14 +108,6 @@ def init_weights(self): def forward(self, feat_dict, sample_mod): """Forward pass. - Note: - The forward of VoteHead is devided into 4 steps: - - 1. Generate vote_points from seed_points. - 2. Aggregate vote_points. - 3. Predict primitive cue and score. - 4. Decode predictions. - Args: feat_dict (dict): Feature dict from backbone. sample_mod (str): Sample mode for vote aggregation layer. @@ -169,7 +164,7 @@ def forward(self, feat_dict, sample_mod): # 4. decode predictions newcenter, decode_res = self.primitive_decode_scores( - predictions, results, self.num_classes, mode=self.primitive_mode) + predictions, aggregated_points, mode=self.primitive_mode) results.update(decode_res) return results @@ -203,13 +198,16 @@ def loss(self, dict: Losses of Votenet. 
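+
+                The dict contains the existence-flag loss, the vote loss, the
+                primitive center loss, the size loss and the semantic
+                classification loss, each key suffixed with the primitive mode.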
""" targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d, - pts_semantic_mask, pts_instance_mask) + pts_semantic_mask, pts_instance_mask, + bbox_preds) - (point_mask, point_sem, point_offset) = targets + (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic, + gt_sem_cls_label, gt_primitive_mask) = targets losses = {} - flag_loss = self.compute_flag_loss( - bbox_preds, point_mask, mode=self.primitive_mode) + # Compute the loss of primitive existence flag + pred_flag = bbox_preds['pred_flag_' + self.primitive_mode] + flag_loss = self.objectness_loss(pred_flag, gt_primitive_mask.long()) losses['flag_loss_' + self.primitive_mode] = flag_loss # calculate vote loss @@ -219,12 +217,20 @@ def loss(self, bbox_preds['seed_indices'], point_mask, point_offset) losses['vote_loss_' + self.primitive_mode] = vote_loss + num_proposal = bbox_preds['aggregated_points_' + + self.primitive_mode].shape[1] + primitive_center = bbox_preds['center_' + self.primitive_mode] + if self.primitive_mode != 'line': + primitive_semantic = bbox_preds['size_residuals_' + + self.primitive_mode].contiguous() + else: + primitive_semantic = None + semancitc_scores = bbox_preds['sem_cls_scores_' + + self.primitive_mode].transpose(2, 1) center_loss, size_loss, sem_cls_loss = self.compute_primitivesem_loss( - bbox_preds, - point_mask, - point_offset, - point_sem, - mode=self.primitive_mode) + primitive_center, primitive_semantic, semancitc_scores, + num_proposal, gt_primitive_center, gt_primitive_semantic, + gt_sem_cls_label, gt_primitive_mask) losses['center_loss_' + self.primitive_mode] = center_loss losses['size_loss_' + self.primitive_mode] = size_loss losses['sem_loss_' + self.primitive_mode] = sem_cls_loss @@ -249,7 +255,7 @@ def get_targets(self, label of each batch. pts_instance_mask (None | list[torch.Tensor]): Point-wise instance label of each batch. - bbox_preds (torch.Tensor): Predictions of primitive head. + bbox_preds (dict): Predictions from forward of primitive head. Returns: tuple[torch.Tensor]: Targets of primitive head. 
@@ -282,7 +288,30 @@ def get_targets(self, point_sem = torch.stack(point_sem) point_offset = torch.stack(point_offset) - return (point_mask, point_sem, point_offset) + batch_size = point_mask.shape[0] + num_proposal = bbox_preds['aggregated_points_' + + self.primitive_mode].shape[1] + num_seed = bbox_preds['seed_points'].shape[1] + seed_inds = bbox_preds['seed_indices'].long() + seed_inds_expand = seed_inds.view(batch_size, num_seed, + 1).repeat(1, 1, 3) + seed_gt_votes = torch.gather(point_offset, 1, seed_inds_expand) + seed_gt_votes += bbox_preds['seed_points'] + gt_primitive_center = seed_gt_votes.view(batch_size * num_proposal, 1, + 3) + + seed_inds_expand_sem = seed_inds.view(batch_size, num_seed, 1).repeat( + 1, 1, 4 + self.num_dim) + seed_gt_sem = torch.gather(point_sem, 1, seed_inds_expand_sem) + gt_primitive_semantic = seed_gt_sem[:, :, 3:3 + self.num_dim].view( + batch_size * num_proposal, 1, self.num_dim).contiguous() + + gt_sem_cls_label = seed_gt_sem[:, :, -1].long() + + gt_votes_mask = torch.gather(point_mask, 1, seed_inds) + + return (point_mask, point_offset, gt_primitive_center, + gt_primitive_semantic, gt_sem_cls_label, gt_votes_mask) def get_targets_single(self, points, @@ -367,42 +396,42 @@ def get_targets_single(self, if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel1]] = 1.0 - linecenter = torch.mean(x[sel][line_sel1], axis=0) - linecenter[1] = (ymin + ymax) / 2.0 + line_center = torch.mean(x[sel][line_sel1], axis=0) + line_center[1] = (ymin + ymax) / 2.0 point_line_offset[ - ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + ind[sel][line_sel1]] = line_center - x[sel][line_sel1] point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], - (pts_semantic_mask[ind][0]) + line_center[0], line_center[1], line_center[2], + pts_semantic_mask[ind][0] ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 - linecenter = torch.mean(x[sel][line_sel2], axis=0) - linecenter[1] = (ymin + ymax) / 2.0 + line_center = torch.mean(x[sel][line_sel2], axis=0) + line_center[1] = (ymin + ymax) / 2.0 point_line_offset[ - ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + ind[sel][line_sel2]] = line_center - x[sel][line_sel2] point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], - (pts_semantic_mask[ind][0]) + line_center[0], line_center[1], line_center[2], + pts_semantic_mask[ind][0] ]) if torch.sum(line_sel3) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel3]] = 1.0 - linecenter = torch.mean(x[sel][line_sel3], axis=0) - linecenter[0] = (xmin + xmax) / 2.0 + line_center = torch.mean(x[sel][line_sel3], axis=0) + line_center[0] = (xmin + xmax) / 2.0 point_line_offset[ - ind[sel][line_sel3]] = linecenter - x[sel][line_sel3] + ind[sel][line_sel3]] = line_center - x[sel][line_sel3] point_line_sem[ind[sel][line_sel3]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], pts_semantic_mask[ind][0] ]) if torch.sum(line_sel4) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel4]] = 1.0 - linecenter = torch.mean(x[sel][line_sel4], axis=0) - linecenter[0] = (xmin + xmax) / 2.0 + line_center = torch.mean(x[sel][line_sel4], axis=0) + line_center[0] = (xmin + xmax) / 2.0 point_line_offset[ - ind[sel][line_sel4]] = linecenter - x[sel][line_sel4] + ind[sel][line_sel4]] = 
line_center - x[sel][line_sel4] point_line_sem[ind[sel][line_sel4]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) @@ -434,42 +463,42 @@ def get_targets_single(self, if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel1]] = 1.0 - linecenter = torch.mean(x[sel][line_sel1], axis=0) - linecenter[1] = (ymin + ymax) / 2.0 + line_center = torch.mean(x[sel][line_sel1], axis=0) + line_center[1] = (ymin + ymax) / 2.0 point_line_offset[ - ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + ind[sel][line_sel1]] = line_center - x[sel][line_sel1] point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 - linecenter = torch.mean(x[sel][line_sel2], axis=0) - linecenter[1] = (ymin + ymax) / 2.0 + line_center = torch.mean(x[sel][line_sel2], axis=0) + line_center[1] = (ymin + ymax) / 2.0 point_line_offset[ - ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + ind[sel][line_sel2]] = line_center - x[sel][line_sel2] point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) if torch.sum(line_sel3) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel3]] = 1.0 - linecenter = torch.mean(x[sel][line_sel3], axis=0) - linecenter[0] = (xmin + xmax) / 2.0 + line_center = torch.mean(x[sel][line_sel3], axis=0) + line_center[0] = (xmin + xmax) / 2.0 point_line_offset[ - ind[sel][line_sel3]] = linecenter - x[sel][line_sel3] + ind[sel][line_sel3]] = line_center - x[sel][line_sel3] point_line_sem[ind[sel][line_sel3]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) if torch.sum(line_sel4) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel4]] = 1.0 - linecenter = torch.mean(x[sel][line_sel4], axis=0) - linecenter[0] = (xmin + xmax) / 2.0 + line_center = torch.mean(x[sel][line_sel4], axis=0) + line_center[0] = (xmin + xmax) / 2.0 point_line_offset[ - ind[sel][line_sel4]] = linecenter - x[sel][line_sel4] + ind[sel][line_sel4]] = line_center - x[sel][line_sel4] point_line_sem[ind[sel][line_sel4]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) @@ -523,22 +552,22 @@ def get_targets_single(self, if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel1]] = 1.0 - linecenter = torch.mean(x[sel][line_sel1], axis=0) - linecenter[2] = (zmin + zmax) / 2.0 + line_center = torch.mean(x[sel][line_sel1], axis=0) + line_center[2] = (zmin + zmax) / 2.0 point_line_offset[ - ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + ind[sel][line_sel1]] = line_center - x[sel][line_sel1] point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 - linecenter = 
torch.mean(x[sel][line_sel2], axis=0) - linecenter[2] = (zmin + zmax) / 2.0 + line_center = torch.mean(x[sel][line_sel2], axis=0) + line_center[2] = (zmin + zmax) / 2.0 point_line_offset[ - ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + ind[sel][line_sel2]] = line_center - x[sel][line_sel2] point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) @@ -568,22 +597,22 @@ def get_targets_single(self, if self.primitive_mode == 'line': if torch.sum(line_sel1) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel1]] = 1.0 - linecenter = torch.mean(x[sel][line_sel1], axis=0) - linecenter[2] = (zmin + zmax) / 2.0 + line_center = torch.mean(x[sel][line_sel1], axis=0) + line_center[2] = (zmin + zmax) / 2.0 point_line_offset[ - ind[sel][line_sel1]] = linecenter - x[sel][line_sel1] + ind[sel][line_sel1]] = line_center - x[sel][line_sel1] point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) if torch.sum(line_sel2) > self.train_cfg['num_point_line']: point_line_mask[ind[sel][line_sel2]] = 1.0 - linecenter = torch.mean(x[sel][line_sel2], axis=0) - linecenter[2] = (zmin + zmax) / 2.0 + line_center = torch.mean(x[sel][line_sel2], axis=0) + line_center[2] = (zmin + zmax) / 2.0 point_line_offset[ - ind[sel][line_sel2]] = linecenter - x[sel][line_sel2] + ind[sel][line_sel2]] = line_center - x[sel][line_sel2] point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - linecenter[0], linecenter[1], linecenter[2], + line_center[0], line_center[1], line_center[2], (pts_semantic_mask[ind][0]) ]) @@ -672,27 +701,44 @@ def get_targets_single(self, elif self.primitive_mode == 'line': return (point_line_mask, point_line_sem, point_line_offset) else: - NotImplementedError + raise NotImplementedError - def primitive_decode_scores(self, net, end_points, num_class, mode=''): - # (batch_size, 1024, ..) - net_transposed = net.transpose(2, 1) - # (batch_size, num_proposal, 3) - base_xyz = end_points['aggregated_points_' + mode] - # (batch_size, num_proposal, 3) - center = base_xyz + net_transposed[:, :, 0:3] - end_points['center_' + mode] = center + def primitive_decode_scores(self, preds, aggregated_points, mode='z'): + """Decode the outputs of primitive module. + + Args: + preds (torch.Tensor): primitive pridictions of each batch. + aggregated_points (torch.Tensor): The aggregated points + of vote stage. + mode (string): The type of primitive module. + + Returns: + Dict: Targets of center, size and semantic. + """ + ret_dict = {} + net_transposed = preds.transpose(2, 1) + + center = aggregated_points + net_transposed[:, :, 0:3] + ret_dict['center_' + mode] = center if mode in ['z', 'xy']: - end_points['size_residuals_' + mode] = net_transposed[:, :, 3:3 + - self.num_dim] + ret_dict['size_residuals_' + mode] = net_transposed[:, :, 3:3 + + self.num_dim] - end_points['sem_cls_scores_' + mode] = net_transposed[:, :, 3 + - self.num_dim:] + ret_dict['sem_cls_scores_' + mode] = net_transposed[:, :, + 3 + self.num_dim:] - return center, end_points + return center, ret_dict def check_upright(self, para_points): + """Check whether is upright corrdinate. + + Args: + para_points (torch.Tensor): Points of input. + + Returns: + Bool: Flag of result. 
+ """ return (para_points[0][-1] == para_points[1][-1]) and ( para_points[1][-1] == para_points[2][-1]) and (para_points[2][-1] @@ -703,100 +749,73 @@ def check_z(self, plane_equ, para_points): plane_equ[-1]) / 4.0 < self.train_cfg['lower_thresh'] def match_point2line(self, points, xmin, xmax, ymin, ymax): + """Match points to corresponding edge. + + Args: + points (torch.Tensor): Points of input. + xmin (float): Min of X-axis. + xmax (float): Max of X-axis. + ymin (float): Min of Y-axis. + ymax (float): Max of Y-axis. + + Returns: + Tuple: Flag of matching correspondence. + """ sel1 = torch.abs(points[:, 0] - xmin) < self.train_cfg['line_thresh'] sel2 = torch.abs(points[:, 0] - xmax) < self.train_cfg['line_thresh'] sel3 = torch.abs(points[:, 1] - ymin) < self.train_cfg['line_thresh'] sel4 = torch.abs(points[:, 1] - ymax) < self.train_cfg['line_thresh'] return sel1, sel2, sel3, sel4 - def compute_flag_loss(self, end_points, point_mask, mode): - # Compute existence flag for face and edge centers - # Load ground truth votes and assign them to seed points - seed_inds = end_points['seed_indices'].long() - - seed_gt_votes_mask = torch.gather(point_mask, 1, seed_inds).float() - end_points['sem_mask'] = seed_gt_votes_mask - - sem_cls_label = torch.gather(point_mask, 1, seed_inds) - end_points['sub_point_sem_cls_label_' + mode] = sem_cls_label - - pred_flag = end_points['pred_flag_' + mode] - - sem_loss = self.objectness_loss(pred_flag, sem_cls_label.long()) - - return sem_loss + def compute_primitivesem_loss(self, primitive_center, primitive_semantic, + semantic_scores, num_proposal, + gt_primitive_center, gt_primitive_semantic, + gt_sem_cls_label, gt_primitive_mask): + """Compute loss of primitive module. - def compute_primitivesem_loss(self, - end_points, - point_mask, - point_offset, - point_sem, - mode=''): - """Compute final geometric primitive center and semantic.""" - # Load ground truth votes and assign them to seed points - batch_size = end_points['seed_points'].shape[0] - num_seed = end_points['seed_points'].shape[1] - vote_xyz = end_points['center_' + mode] - seed_inds = end_points['seed_indices'].long() - - num_proposal = end_points['aggregated_points_' + mode].shape[1] - - seed_gt_votes_mask = torch.gather(point_mask, 1, seed_inds) - seed_inds_expand = seed_inds.view(batch_size, num_seed, - 1).repeat(1, 1, 3) - - seed_inds_expand_sem = seed_inds.view(batch_size, num_seed, 1).repeat( - 1, 1, 4 + self.num_dim) + Args: + primitive_center (torch.Tensor): Pridictions of primitive center. + primitive_semantic (torch.Tensor): Pridictions of primitive + semantic. + semantic_scores (torch.Tensor): Pridictions of primitive + semantic scores. + num_proposal (int): The number of primitive proposal. + gt_primitive_center (torch.Tensor): Ground truth of + primitive center. + gt_votes_sem (torch.Tensor): Ground truth of primitive semantic. + gt_sem_cls_label (torch.Tensor): Ground truth of primitive + semantic class. + gt_primitive_mask (torch.Tensor): Ground truth of primitive mask. 
- seed_gt_votes = torch.gather(point_offset, 1, seed_inds_expand) - seed_gt_sem = torch.gather(point_sem, 1, seed_inds_expand_sem) - seed_gt_votes += end_points['seed_points'] - - end_points['surface_center_gt_' + mode] = seed_gt_votes - end_points['surface_sem_gt_' + mode] = seed_gt_sem - end_points['surface_mask_gt_' + mode] = seed_gt_votes_mask - - # Compute the min of min of distance - vote_xyz_reshape = vote_xyz.view(batch_size * num_proposal, -1, 3) - seed_gt_votes_reshape = seed_gt_votes.view(batch_size * num_proposal, - 1, 3) - # A predicted vote to no where is not penalized as long as there is a - # good vote near the GT vote. + Returns: + Tuple: Loss of primitive module. + """ + batch_size = primitive_center.shape[0] + vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1, + 3) center_loss = self.center_loss( vote_xyz_reshape, - seed_gt_votes_reshape, - dst_weight=seed_gt_votes_mask.view(batch_size * num_proposal, - 1))[1] + gt_primitive_center, + dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1] center_loss = center_loss.sum() / ( - torch.sum(seed_gt_votes_mask.float()) + 1e-6) - - # Compute the min of min of distance - # Need to remove this soon - if mode != 'line': - size_xyz = end_points['size_residuals_' + mode].contiguous() - size_xyz_reshape = size_xyz.view(batch_size * num_proposal, -1, - self.num_dim).contiguous() - seed_gt_votes_reshape = seed_gt_sem[:, :, 3:3 + self.num_dim].view( - batch_size * num_proposal, 1, self.num_dim).contiguous() - # A predicted vote to no where is not penalized as long as - # there is a good vote near the GT vote. + torch.sum(gt_primitive_mask.float()) + 1e-6) + + if self.primitive_mode != 'line': + size_xyz_reshape = primitive_semantic.view( + batch_size * num_proposal, -1, self.num_dim).contiguous() size_loss = self.center_loss( size_xyz_reshape, - seed_gt_votes_reshape, - dst_weight=seed_gt_votes_mask.view(batch_size * num_proposal, - 1))[1] + gt_primitive_semantic, + dst_weight=gt_primitive_mask.view(batch_size * num_proposal, + 1))[1] size_loss = size_loss.sum() / ( - torch.sum(seed_gt_votes_mask.float()) + 1e-6) + torch.sum(gt_primitive_mask.float()) + 1e-6) else: size_loss = torch.tensor(0).float().to(center_loss.device) - # 3.4 Semantic cls loss - sem_cls_label = seed_gt_sem[:, :, -1].long() - end_points['supp_sem_' + mode] = sem_cls_label - sem_cls_loss = self.semantic_loss( - end_points['sem_cls_scores_' + mode].transpose(2, 1), - sem_cls_label) - sem_cls_loss = torch.sum(sem_cls_loss * seed_gt_votes_mask.float()) / ( - torch.sum(seed_gt_votes_mask.float()) + 1e-6) + # Semantic cls loss + sem_cls_loss = self.semantic_loss(semantic_scores, gt_sem_cls_label) + sem_cls_loss = torch.sum(sem_cls_loss * gt_primitive_mask.float()) / ( + torch.sum(gt_primitive_mask.float()) + 1e-6) return center_loss, size_loss, sem_cls_loss diff --git a/tests/test_heads.py b/tests/test_heads.py index ca060a853e..87cc8b02cb 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -557,4 +557,3 @@ def test_primitive_head(): assert losses_dict['center_loss_z'] >= 0 assert losses_dict['size_loss_z'] >= 0 assert losses_dict['sem_loss_z'] >= 0 - assert losses_dict['surface_loss_z'] >= 0 From fa0454c8dd4d79ecebdf6e59b70c7e98b3549520 Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 19 Aug 2020 11:53:32 +0800 Subject: [PATCH 05/10] modify primitive head --- .../roi_heads/mask_heads/primitive_head.py | 490 ++++++++---------- tests/test_heads.py | 13 +- 2 files changed, 235 insertions(+), 268 deletions(-) diff --git 
a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 1966ca5498..cb0255949d 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -11,7 +11,7 @@ @HEADS.register_module() class PrimitiveHead(nn.Module): - r"""Bbox head of `H3dnet `_. + r"""Primitive head of `H3DNet `_. Args: num_dim (int): The dimension of primitive semantic information. @@ -26,6 +26,8 @@ class PrimitiveHead(nn.Module): vote_aggregation_cfg (dict): Config of vote aggregation layer. feat_channels (tuple[int]): Convolution channels of prediction layer. + upper_thresh (float): Threshold for line matching. + surface_thresh (float): Threshold for suface matching. conv_cfg (dict): Config of convolution in prediction layer. norm_cfg (dict): Config of BN in prediction layer. objectness_loss (dict): Config of objectness loss. @@ -42,11 +44,14 @@ def __init__(self, vote_moudule_cfg=None, vote_aggregation_cfg=None, feat_channels=(128, 128), + upper_thresh=100.0, + surface_thresh=0.5, conv_cfg=dict(type='Conv1d'), norm_cfg=dict(type='BN1d'), objectness_loss=None, center_loss=None, - semantic_loss=None): + semantic_reg_loss=None, + semantic_cls_loss=None): super(PrimitiveHead, self).__init__() assert primitive_mode in ['z', 'xy', 'line'] # The dimension of primitive semantic information. @@ -57,10 +62,13 @@ def __init__(self, self.test_cfg = test_cfg self.gt_per_seed = vote_moudule_cfg['gt_per_seed'] self.num_proposal = vote_aggregation_cfg['num_point'] + self.upper_thresh = upper_thresh + self.surface_thresh = surface_thresh self.objectness_loss = build_loss(objectness_loss) self.center_loss = build_loss(center_loss) - self.semantic_loss = build_loss(semantic_loss) + self.semantic_reg_loss = build_loss(semantic_reg_loss) + self.semantic_cls_loss = build_loss(semantic_cls_loss) assert vote_aggregation_cfg['mlp_channels'][0] == vote_moudule_cfg[ 'in_channels'] @@ -101,6 +109,8 @@ def __init__(self, self.conv_pred.add_module('conv_out', nn.Conv1d(prev_channel, conv_out_channel, 1)) + self.softmax_normal = nn.Softmax(dim=1) + def init_weights(self): """Initialize weights of VoteHead.""" pass @@ -122,10 +132,10 @@ def forward(self, feat_dict, sample_mod): seed_features = feat_dict['hd_feature'] results = {} - net_flag = self.flag_conv(seed_features) - net_flag = self.flag_pred(net_flag) + primitive_flag = self.flag_conv(seed_features) + primitive_flag = self.flag_pred(primitive_flag) - results['pred_flag_' + self.primitive_mode] = net_flag + results['pred_flag_' + self.primitive_mode] = primitive_flag # 1. generate vote_points from seed_points vote_points, vote_features = self.vote_module(seed_points, @@ -159,14 +169,19 @@ def forward(self, feat_dict, sample_mod): results['aggregated_indices_' + self.primitive_mode] = aggregated_indices - # 3. predict bbox and score + # 3. predict primitive offsets and semantic information predictions = self.conv_pred(features) # 4. 
decode predictions - newcenter, decode_res = self.primitive_decode_scores( - predictions, aggregated_points, mode=self.primitive_mode) + primitive_center, decode_res = self.primitive_decode_scores( + predictions, aggregated_points) results.update(decode_res) + center, pred_ind = self.get_primitive_center(primitive_flag, + primitive_center) + + results['pred_' + self.primitive_mode + '_ind'] = pred_ind + results['pred_' + self.primitive_mode + '_center'] = center return results def loss(self, @@ -227,6 +242,9 @@ def loss(self, primitive_semantic = None semancitc_scores = bbox_preds['sem_cls_scores_' + self.primitive_mode].transpose(2, 1) + + gt_primitive_mask = gt_primitive_mask / \ + (gt_primitive_mask.sum() + 1e-6) center_loss, size_loss, sem_cls_loss = self.compute_primitivesem_loss( primitive_center, primitive_semantic, semancitc_scores, num_proposal, gt_primitive_center, gt_primitive_semantic, @@ -337,22 +355,11 @@ def get_targets_single(self, gt_bboxes_3d = gt_bboxes_3d.to(points.device) num_points = points.shape[0] - if self.primitive_mode == 'z': - point_boundary_mask_z = points.new_zeros(num_points) - point_boundary_offset_z = points.new_zeros([num_points, 3]) - point_boundary_sem_z = points.new_zeros( - [num_points, 3 + self.num_dim + 1]) - elif self.primitive_mode == 'xy': - point_boundary_mask_xy = points.new_zeros(num_points) - point_boundary_offset_xy = points.new_zeros([num_points, 3]) - point_boundary_sem_xy = points.new_zeros( - [num_points, 3 + self.num_dim + 1]) - elif self.primitive_mode == 'line': - point_line_mask = points.new_zeros(num_points) - point_line_offset = points.new_zeros([num_points, 3]) - point_line_sem = points.new_zeros([num_points, 3 + 1]) - else: - NotImplementedError + point_mask = points.new_zeros(num_points) + # Offset to the primitive center + point_offset = points.new_zeros([num_points, 3]) + # Semantic information of primitive center + point_sem = points.new_zeros([num_points, 3 + self.num_dim + 1]) instance_flag = torch.nonzero( pts_semantic_mask != self.num_classes).squeeze(1) @@ -362,14 +369,14 @@ def get_targets_single(self, ind = instance_flag[pts_instance_mask[instance_flag] == i_instance] x = points[ind, :3] - # Corners - corners = gt_bboxes_3d.corners[i][[0, 1, 3, 2, 4, 5, 7, 6]] - xmin, ymin, zmin = corners.min(0)[0] - xmax, ymax, zmax = corners.max(0)[0] + # Bbox Corners + cur_corners = gt_bboxes_3d.corners[i] + xmin, ymin, zmin = cur_corners.min(0)[0] + xmax, ymax, zmax = cur_corners.max(0)[0] - # Get lower four lines - plane_lower_temp = points.new_tensor([0, 0, 1, -corners[6, -1]]) - para_points = corners[[1, 3, 5, 7]] + plane_lower_temp = points.new_tensor( + [0, 0, 1, -cur_corners[7, -1]]) + para_points = cur_corners[[1, 2, 5, 6]] newd = torch.sum(para_points * plane_lower_temp[:3], 1) if self.check_upright(para_points) and \ plane_lower_temp[0] + plane_lower_temp[1] < \ @@ -379,154 +386,107 @@ def get_targets_single(self, plane_upper = points.new_tensor([0, 0, 1, -torch.mean(newd)]) else: raise NotImplementedError - # print('error with upright') if self.check_z(plane_upper, para_points) is False: raise NotImplementedError # Get the boundary points here - alldist = torch.abs( + point2plane_dist = torch.abs( torch.sum(x * plane_lower[:3], 1) + plane_lower[-1]) - mind = alldist.min() - sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] # Get lower four lines - line_sel1, line_sel2, line_sel3, 
line_sel4 = self.match_point2line( - x[sel], xmin, xmax, ymin, ymax) if self.primitive_mode == 'line': - if torch.sum(line_sel1) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel1]] = 1.0 - line_center = torch.mean(x[sel][line_sel1], axis=0) - line_center[1] = (ymin + ymax) / 2.0 - point_line_offset[ - ind[sel][line_sel1]] = line_center - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0] - ]) - if torch.sum(line_sel2) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel2]] = 1.0 - line_center = torch.mean(x[sel][line_sel2], axis=0) - line_center[1] = (ymin + ymax) / 2.0 - point_line_offset[ - ind[sel][line_sel2]] = line_center - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0] - ]) - if torch.sum(line_sel3) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel3]] = 1.0 - line_center = torch.mean(x[sel][line_sel3], axis=0) - line_center[0] = (xmin + xmax) / 2.0 - point_line_offset[ - ind[sel][line_sel3]] = line_center - x[sel][line_sel3] - point_line_sem[ind[sel][line_sel3]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0] - ]) - if torch.sum(line_sel4) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel4]] = 1.0 - line_center = torch.mean(x[sel][line_sel4], axis=0) - line_center[0] = (xmin + xmax) / 2.0 - point_line_offset[ - ind[sel][line_sel4]] = line_center - x[sel][line_sel4] - point_line_sem[ind[sel][line_sel4]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) + point2_lines_matching = self.match_point2line( + x[selected], xmin, xmax, ymin, ymax) + for idx, line_select in enumerate(point2_lines_matching): + if torch.sum(line_select) > \ + self.train_cfg['num_point_line']: + point_mask[ind[selected][line_select]] = 1.0 + line_center = torch.mean( + x[selected][line_select], axis=0) + if idx < 2: + line_center[1] = (ymin + ymax) / 2.0 + else: + line_center[0] = (xmin + xmax) / 2.0 + point_offset[ind[selected][line_select]] = \ + line_center - x[selected][line_select] + point_sem[ind[selected][line_select]] = \ + points.new_tensor([line_center[0], line_center[1], + line_center[2], + pts_semantic_mask[ind][0]]) # Set the surface labels here if self.primitive_mode == 'z': - if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( - alldist[sel]) < self.train_cfg['var_thresh']: + if torch.sum(selected) > self.train_cfg['num_point'] and \ + torch.var(point2plane_dist[selected]) < \ + self.train_cfg['var_thresh']: center = points.new_tensor([(xmin + xmax) / 2.0, (ymin + ymax) / 2.0, - torch.mean(x[sel][:, 2])]) - sel_global = ind[sel] - point_boundary_mask_z[sel_global] = 1.0 - point_boundary_sem_z[sel_global] = points.new_tensor([ + torch.mean(x[selected][:, 2])]) + sel_global = ind[selected] + point_mask[sel_global] = 1.0 + point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], xmax - xmin, ymax - ymin, (pts_semantic_mask[ind][0]) ]) - point_boundary_offset_z[sel_global] = center - x[sel] + point_offset[sel_global] = center - x[selected] # Get the boundary points here - alldist = torch.abs( + point2plane_dist = torch.abs( torch.sum(x * plane_upper[:3], 1) + plane_upper[-1]) - mind = alldist.min() - sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + 
min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] # Get upper four lines - line_sel1, line_sel2, line_sel3, line_sel4 = self.match_point2line( - x[sel], xmin, xmax, ymin, ymax) - if self.primitive_mode == 'line': - if torch.sum(line_sel1) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel1]] = 1.0 - line_center = torch.mean(x[sel][line_sel1], axis=0) - line_center[1] = (ymin + ymax) / 2.0 - point_line_offset[ - ind[sel][line_sel1]] = line_center - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) - if torch.sum(line_sel2) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel2]] = 1.0 - line_center = torch.mean(x[sel][line_sel2], axis=0) - line_center[1] = (ymin + ymax) / 2.0 - point_line_offset[ - ind[sel][line_sel2]] = line_center - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) - if torch.sum(line_sel3) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel3]] = 1.0 - line_center = torch.mean(x[sel][line_sel3], axis=0) - line_center[0] = (xmin + xmax) / 2.0 - point_line_offset[ - ind[sel][line_sel3]] = line_center - x[sel][line_sel3] - point_line_sem[ind[sel][line_sel3]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) - if torch.sum(line_sel4) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel4]] = 1.0 - line_center = torch.mean(x[sel][line_sel4], axis=0) - line_center[0] = (xmin + xmax) / 2.0 - point_line_offset[ - ind[sel][line_sel4]] = line_center - x[sel][line_sel4] - point_line_sem[ind[sel][line_sel4]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) + point2_lines_matching = self.match_point2line( + x[selected], xmin, xmax, ymin, ymax) + for idx, line_select in enumerate(point2_lines_matching): + if torch.sum(line_select) > \ + self.train_cfg['num_point_line']: + point_mask[ind[selected][line_select]] = 1.0 + line_center = torch.mean( + x[selected][line_select], axis=0) + if idx < 2: + line_center[1] = (ymin + ymax) / 2.0 + else: + line_center[0] = (xmin + xmax) / 2.0 + point_offset[ind[selected][line_select]] = \ + line_center - x[selected][line_select] + point_sem[ind[selected][line_select]] = \ + points.new_tensor([line_center[0], line_center[1], + line_center[2], + pts_semantic_mask[ind][0]]) if self.primitive_mode == 'z': - if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( - alldist[sel]) < self.train_cfg['var_thresh']: + if torch.sum(selected) > self.train_cfg['num_point'] and \ + torch.var(point2plane_dist[selected]) < \ + self.train_cfg['var_thresh']: center = points.new_tensor([(xmin + xmax) / 2.0, (ymin + ymax) / 2.0, - torch.mean(x[sel][:, 2])]) - sel_global = ind[sel] - point_boundary_mask_z[sel_global] = 1.0 - point_boundary_sem_z[sel_global] = points.new_tensor([ + torch.mean(x[selected][:, 2])]) + sel_global = ind[selected] + point_mask[sel_global] = 1.0 + point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], xmax - xmin, ymax - ymin, (pts_semantic_mask[ind][0]) ]) - point_boundary_offset_z[sel_global] = center - x[sel] + point_offset[sel_global] = center - x[selected] # Get left two lines - vec1 = corners[3] - corners[2] - vec2 = corners[2] - 
corners[0] + vec1 = cur_corners[2] - cur_corners[3] + vec2 = cur_corners[3] - cur_corners[0] surface_norm = torch.cross(vec1, vec2) - surface_dis = -torch.dot(surface_norm, corners[0]) + surface_dis = -torch.dot(surface_norm, cur_corners[0]) plane_left_temp = points.new_tensor([ surface_norm[0], surface_norm[1], surface_norm[2], surface_dis ]) - para_points = corners[[4, 5, 6, 7]] - # Normalize xy here + para_points = cur_corners[[4, 5, 7, 6]] plane_left_temp /= torch.norm(plane_left_temp[:3]) newd = torch.sum(para_points * plane_left_temp[:3], 1) if plane_left_temp[2] < self.train_cfg['lower_thresh']: @@ -537,110 +497,98 @@ def get_targets_single(self, ]) else: raise NotImplementedError - # print('error with upright') # Get the boundary points here - alldist = torch.abs( + point2plane_dist = torch.abs( torch.sum(x * plane_left[:3], 1) + plane_left[-1]) - mind = alldist.min() - sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] # Get upper four lines - _, _, line_sel1, line_sel2 = self.match_point2line( - x[sel], xmin, xmax, ymin, ymax) - if self.primitive_mode == 'line': - if torch.sum(line_sel1) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel1]] = 1.0 - line_center = torch.mean(x[sel][line_sel1], axis=0) - line_center[2] = (zmin + zmax) / 2.0 - point_line_offset[ - ind[sel][line_sel1]] = line_center - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) - if torch.sum(line_sel2) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel2]] = 1.0 - line_center = torch.mean(x[sel][line_sel2], axis=0) - line_center[2] = (zmin + zmax) / 2.0 - point_line_offset[ - ind[sel][line_sel2]] = line_center - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) + _, _, line_sel1, line_sel2 = self.match_point2line( + x[selected], xmin, xmax, ymin, ymax) + for idx, line_select in enumerate([line_sel1, line_sel2]): + if torch.sum(line_select) > \ + self.train_cfg['num_point_line']: + point_mask[ind[selected][line_select]] = 1.0 + line_center = torch.mean( + x[selected][line_select], axis=0) + line_center[2] = (zmin + zmax) / 2.0 + point_offset[ind[selected][line_select]] = \ + line_center - x[selected][line_select] + point_sem[ind[selected][line_select]] = \ + points.new_tensor([line_center[0], line_center[1], + line_center[2], + pts_semantic_mask[ind][0]]) if self.primitive_mode == 'xy': - if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( - alldist[sel]) < self.train_cfg['var_thresh']: + if torch.sum(selected) > self.train_cfg['num_point'] and \ + torch.var(point2plane_dist[selected]) < \ + self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[sel][:, 0]), - torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + torch.mean(x[selected][:, 0]), + torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[sel] - point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = points.new_tensor([ + sel_global = ind[selected] + point_mask[sel_global] = 1.0 + point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) ]) - point_boundary_offset_xy[sel_global] = center - x[sel] + point_offset[sel_global] = center - 
x[selected] # Get the boundary points here - alldist = torch.abs( + point2plane_dist = torch.abs( torch.sum(x * plane_right[:3], 1) + plane_right[-1]) - mind = alldist.min() - sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] - _, _, line_sel1, line_sel2 = self.match_point2line( - x[sel], xmin, xmax, ymin, ymax) + min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] if self.primitive_mode == 'line': - if torch.sum(line_sel1) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel1]] = 1.0 - line_center = torch.mean(x[sel][line_sel1], axis=0) - line_center[2] = (zmin + zmax) / 2.0 - point_line_offset[ - ind[sel][line_sel1]] = line_center - x[sel][line_sel1] - point_line_sem[ind[sel][line_sel1]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) - if torch.sum(line_sel2) > self.train_cfg['num_point_line']: - point_line_mask[ind[sel][line_sel2]] = 1.0 - line_center = torch.mean(x[sel][line_sel2], axis=0) - line_center[2] = (zmin + zmax) / 2.0 - point_line_offset[ - ind[sel][line_sel2]] = line_center - x[sel][line_sel2] - point_line_sem[ind[sel][line_sel2]] = points.new_tensor([ - line_center[0], line_center[1], line_center[2], - (pts_semantic_mask[ind][0]) - ]) + _, _, line_sel1, line_sel2 = self.match_point2line( + x[selected], xmin, xmax, ymin, ymax) + for idx, line_select in enumerate([line_sel1, line_sel2]): + if torch.sum(line_select) > \ + self.train_cfg['num_point_line']: + point_mask[ind[selected][line_select]] = 1.0 + line_center = torch.mean( + x[selected][line_select], axis=0) + line_center[2] = (zmin + zmax) / 2.0 + point_offset[ind[selected][line_select]] = \ + line_center - x[selected][line_select] + point_sem[ind[selected][line_select]] = \ + points.new_tensor([line_center[0], line_center[1], + line_center[2], + pts_semantic_mask[ind][0]]) if self.primitive_mode == 'xy': - if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( - alldist[sel]) < self.train_cfg['var_thresh']: + if torch.sum(selected) > self.train_cfg['num_point'] and \ + torch.var(point2plane_dist[selected]) < \ + self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[sel][:, 0]), - torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + torch.mean(x[selected][:, 0]), + torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[sel] - point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = points.new_tensor([ + sel_global = ind[selected] + point_mask[sel_global] = 1.0 + point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) ]) - point_boundary_offset_xy[sel_global] = center - x[sel] + point_offset[sel_global] = center - x[selected] # Get the boundary points here - vec1 = corners[0] - corners[4] - vec2 = corners[4] - corners[5] + vec1 = cur_corners[0] - cur_corners[4] + vec2 = cur_corners[4] - cur_corners[5] surface_norm = torch.cross(vec1, vec2) - surface_dis = -torch.dot(surface_norm, corners[5]) + surface_dis = -torch.dot(surface_norm, cur_corners[5]) plane_front_temp = points.new_tensor([ surface_norm[0], surface_norm[1], surface_norm[2], surface_dis ]) - para_points = corners[[2, 3, 6, 7]] + para_points = cur_corners[[3, 2, 7, 6]] plane_front_temp /= torch.norm(plane_front_temp[:3]) newd = torch.sum(para_points * plane_front_temp[:3], 1) if plane_front_temp[2] < self.train_cfg['lower_thresh']: @@ -651,82 +599,77 @@ def 
get_targets_single(self, ]) else: raise NotImplementedError - # print('error with upright') # Get the boundary points here - alldist = torch.abs( + point2plane_dist = torch.abs( torch.sum(x * plane_front[:3], 1) + plane_front[-1]) - mind = alldist.min() - sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] if self.primitive_mode == 'xy': - if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( - alldist[sel]) < self.train_cfg['var_thresh']: + if torch.sum(selected) > self.train_cfg['num_point'] and \ + torch.var(point2plane_dist[selected]) < \ + self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[sel][:, 0]), - torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + torch.mean(x[selected][:, 0]), + torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[sel] - point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = points.new_tensor([ + sel_global = ind[selected] + point_mask[sel_global] = 1.0 + point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) ]) - point_boundary_offset_xy[sel_global] = center - x[sel] + point_offset[sel_global] = center - x[selected] + # Get the boundary points here - alldist = torch.abs( + point2plane_dist = torch.abs( torch.sum(x * plane_back[:3], 1) + plane_back[-1]) - mind = alldist.min() - sel = torch.abs(alldist - mind) < self.train_cfg['dist_thresh'] + min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] if self.primitive_mode == 'xy': - if torch.sum(sel) > self.train_cfg['num_point'] and torch.var( - alldist[sel]) < self.train_cfg['var_thresh']: + if torch.sum(selected) > self.train_cfg['num_point'] and \ + torch.var(point2plane_dist[selected]) < \ + self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[sel][:, 0]), - torch.mean(x[sel][:, 1]), (zmin + zmax) / 2.0 + torch.mean(x[selected][:, 0]), + torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[sel] - point_boundary_mask_xy[sel_global] = 1.0 - point_boundary_sem_xy[sel_global] = points.new_tensor([ + sel_global = ind[selected] + point_mask[sel_global] = 1.0 + point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, (pts_semantic_mask[ind][0]) ]) - point_boundary_offset_xy[sel_global] = center - x[sel] - - if self.primitive_mode == 'z': - return (point_boundary_mask_z, point_boundary_sem_z, - point_boundary_offset_z) - elif self.primitive_mode == 'xy': - return (point_boundary_mask_xy, point_boundary_sem_xy, - point_boundary_offset_xy) - elif self.primitive_mode == 'line': - return (point_line_mask, point_line_sem, point_line_offset) - else: - raise NotImplementedError + point_offset[sel_global] = center - x[selected] + + return (point_mask, point_sem, point_offset) - def primitive_decode_scores(self, preds, aggregated_points, mode='z'): + def primitive_decode_scores(self, preds, aggregated_points): """Decode the outputs of primitive module. Args: preds (torch.Tensor): primitive pridictions of each batch. aggregated_points (torch.Tensor): The aggregated points of vote stage. - mode (string): The type of primitive module. Returns: Dict: Targets of center, size and semantic. 
""" + ret_dict = {} net_transposed = preds.transpose(2, 1) center = aggregated_points + net_transposed[:, :, 0:3] - ret_dict['center_' + mode] = center + ret_dict['center_' + self.primitive_mode] = center - if mode in ['z', 'xy']: - ret_dict['size_residuals_' + mode] = net_transposed[:, :, 3:3 + - self.num_dim] + if self.primitive_mode in ['z', 'xy']: + ret_dict['size_residuals_' + self.primitive_mode] = \ + net_transposed[:, :, 3:3 + self.num_dim] - ret_dict['sem_cls_scores_' + mode] = net_transposed[:, :, - 3 + self.num_dim:] + ret_dict['sem_cls_scores_' + self.primitive_mode] = \ + net_transposed[:, :, 3 + self.num_dim:] return center, ret_dict @@ -793,29 +736,46 @@ def compute_primitivesem_loss(self, primitive_center, primitive_semantic, batch_size = primitive_center.shape[0] vote_xyz_reshape = primitive_center.view(batch_size * num_proposal, -1, 3) + center_loss = self.center_loss( vote_xyz_reshape, gt_primitive_center, dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1] - center_loss = center_loss.sum() / ( - torch.sum(gt_primitive_mask.float()) + 1e-6) if self.primitive_mode != 'line': size_xyz_reshape = primitive_semantic.view( batch_size * num_proposal, -1, self.num_dim).contiguous() - size_loss = self.center_loss( + size_loss = self.semantic_reg_loss( size_xyz_reshape, gt_primitive_semantic, dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1] - size_loss = size_loss.sum() / ( - torch.sum(gt_primitive_mask.float()) + 1e-6) else: size_loss = torch.tensor(0).float().to(center_loss.device) # Semantic cls loss - sem_cls_loss = self.semantic_loss(semantic_scores, gt_sem_cls_label) - sem_cls_loss = torch.sum(sem_cls_loss * gt_primitive_mask.float()) / ( - torch.sum(gt_primitive_mask.float()) + 1e-6) + sem_cls_loss = self.semantic_cls_loss( + semantic_scores, + gt_sem_cls_label, + weight=gt_primitive_mask.float()) return center_loss, size_loss, sem_cls_loss + + def get_primitive_center(self, pred_flag, center): + """Generate primitive center from primitive head predictions. + + Args: + pred_flag (torch.Tensor): Scores of primitive center. + center (torch.Tensor): Pridictions of primitive center. + + Returns: + Tuple: Primitive center and the prediction indices. 
+ """ + ind_normal = self.softmax_normal(pred_flag) + pred_indices = (ind_normal[:, 1, :] > + self.surface_thresh).detach().float() + selected = (ind_normal[:, 1, :] <= + self.surface_thresh).detach().float() + offset = torch.ones_like(center) * self.upper_thresh + center = center + offset * selected.unsqueeze(-1) + return center, pred_indices diff --git a/tests/test_heads.py b/tests/test_heads.py index 87cc8b02cb..410725d7c9 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -500,11 +500,17 @@ def test_primitive_head(): center_loss=dict( type='ChamferDistance', mode='l1', - reduction='none', + reduction='sum', loss_src_weight=1.0, loss_dst_weight=1.0), - semantic_loss=dict( - type='CrossEntropyLoss', reduction='none', loss_weight=1.0), + semantic_reg_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=1.0, + loss_dst_weight=1.0), + semantic_cls_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), train_cfg=dict( dist_thresh=0.2, var_thresh=1e-2, @@ -512,6 +518,7 @@ def test_primitive_head(): num_point=100, num_point_line=10, line_thresh=0.2)) + self = build_head(primitive_head_cfg).cuda() fp_xyz = [torch.rand([2, 64, 3], dtype=torch.float32).cuda()] hd_features = torch.rand([2, 256, 64], dtype=torch.float32).cuda() From ca996c173f60e5436fd4044bc7294a515cedfb72 Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 19 Aug 2020 20:26:34 +0800 Subject: [PATCH 06/10] modify primitive head --- .../roi_heads/mask_heads/primitive_head.py | 287 ++++++++++-------- 1 file changed, 153 insertions(+), 134 deletions(-) diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index cb0255949d..9e70759c68 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -1,6 +1,7 @@ import torch from mmcv.cnn import ConvModule from torch import nn as nn +from torch.nn import functional as F from mmdet3d.models.builder import build_loss from mmdet3d.models.model_utils import VoteModule @@ -14,7 +15,7 @@ class PrimitiveHead(nn.Module): r"""Primitive head of `H3DNet `_. Args: - num_dim (int): The dimension of primitive semantic information. + num_dims (int): The dimension of primitive semantic information. num_classes (int): The number of class. primitive_mode (str): The mode of primitive module, avaliable mode ['z', 'xy', 'line']. @@ -36,7 +37,7 @@ class PrimitiveHead(nn.Module): """ def __init__(self, - num_dim, + num_dims, num_classes, primitive_mode, train_cfg=None, @@ -55,7 +56,7 @@ def __init__(self, super(PrimitiveHead, self).__init__() assert primitive_mode in ['z', 'xy', 'line'] # The dimension of primitive semantic information. - self.num_dim = num_dim + self.num_dims = num_dims self.num_classes = num_classes self.primitive_mode = primitive_mode self.train_cfg = train_cfg @@ -105,21 +106,19 @@ def __init__(self, prev_channel = feat_channels[k] self.conv_pred = nn.Sequential(*conv_pred_list) - conv_out_channel = 3 + num_dim + num_classes + conv_out_channel = 3 + num_dims + num_classes self.conv_pred.add_module('conv_out', nn.Conv1d(prev_channel, conv_out_channel, 1)) - self.softmax_normal = nn.Softmax(dim=1) - def init_weights(self): """Initialize weights of VoteHead.""" pass - def forward(self, feat_dict, sample_mod): + def forward(self, feats_dict, sample_mod): """Forward pass. Args: - feat_dict (dict): Feature dict from backbone. + feats_dict (dict): Feature dict from backbone. 
sample_mod (str): Sample mode for vote aggregation layer. valid modes are "vote", "seed" and "random". @@ -128,8 +127,8 @@ def forward(self, feat_dict, sample_mod): """ assert sample_mod in ['vote', 'seed', 'random'] - seed_points = feat_dict['fp_xyz_net0'][-1] - seed_features = feat_dict['hd_feature'] + seed_points = feats_dict['fp_xyz_net0'][-1] + seed_features = feats_dict['hd_feature'] results = {} primitive_flag = self.flag_conv(seed_features) @@ -154,9 +153,11 @@ def forward(self, feat_dict, sample_mod): elif sample_mod == 'random': # Random sampling from the votes batch_size, num_seed = seed_points.shape[:2] - sample_indices = seed_points.new_tensor( - torch.randint(0, num_seed, (batch_size, self.num_proposal)), - dtype=torch.int32) + sample_indices = torch.randint( + 0, + num_seed, (batch_size, self.num_proposal), + dtype=torch.int32, + device=seed_points.device) else: raise NotImplementedError @@ -173,12 +174,12 @@ def forward(self, feat_dict, sample_mod): predictions = self.conv_pred(features) # 4. decode predictions - primitive_center, decode_res = self.primitive_decode_scores( - predictions, aggregated_points) - results.update(decode_res) + decode_ret = self.primitive_decode_scores(predictions, + aggregated_points) + results.update(decode_ret) - center, pred_ind = self.get_primitive_center(primitive_flag, - primitive_center) + center, pred_ind = self.get_primitive_center( + primitive_flag, decode_ret['center_' + self.primitive_mode]) results['pred_' + self.primitive_mode + '_ind'] = pred_ind results['pred_' + self.primitive_mode + '_center'] = center @@ -210,7 +211,7 @@ def loss(self, which bounding. Returns: - dict: Losses of Votenet. + dict: Losses of Primitive Head. """ targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask, pts_instance_mask, @@ -245,7 +246,7 @@ def loss(self, gt_primitive_mask = gt_primitive_mask / \ (gt_primitive_mask.sum() + 1e-6) - center_loss, size_loss, sem_cls_loss = self.compute_primitivesem_loss( + center_loss, size_loss, sem_cls_loss = self.compute_primitive_loss( primitive_center, primitive_semantic, semancitc_scores, num_proposal, gt_primitive_center, gt_primitive_semantic, gt_sem_cls_label, gt_primitive_mask) @@ -262,7 +263,7 @@ def get_targets(self, pts_semantic_mask=None, pts_instance_mask=None, bbox_preds=None): - """Generate targets of vote head. + """Generate targets of primitive head. Args: points (list[torch.Tensor]): Points of each batch. @@ -319,10 +320,10 @@ def get_targets(self, 3) seed_inds_expand_sem = seed_inds.view(batch_size, num_seed, 1).repeat( - 1, 1, 4 + self.num_dim) + 1, 1, 4 + self.num_dims) seed_gt_sem = torch.gather(point_sem, 1, seed_inds_expand_sem) - gt_primitive_semantic = seed_gt_sem[:, :, 3:3 + self.num_dim].view( - batch_size * num_proposal, 1, self.num_dim).contiguous() + gt_primitive_semantic = seed_gt_sem[:, :, 3:3 + self.num_dims].view( + batch_size * num_proposal, 1, self.num_dims).contiguous() gt_sem_cls_label = seed_gt_sem[:, :, -1].long() @@ -337,7 +338,7 @@ def get_targets_single(self, gt_labels_3d, pts_semantic_mask=None, pts_instance_mask=None): - """Generate targets of vote head for single batch. + """Generate targets of primitive head for single batch. Args: points (torch.Tensor): Points of each batch. 
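# A standalone sketch of the device-aware random sampling introduced in the
# 'random' branch of this patch (replacing the old
# new_tensor(torch.randint(...)) round trip); the shapes and variable names
# here are made up for illustration only.
import torch

seed_points = torch.rand(2, 1024, 3)  # (batch_size, num_seed, 3)
num_proposal = 256

batch_size, num_seed = seed_points.shape[:2]
# The proposal indices are created with the desired dtype directly on the
# same device as the seed points, avoiding an extra host-to-device copy.
sample_indices = torch.randint(
    0,
    num_seed, (batch_size, num_proposal),
    dtype=torch.int32,
    device=seed_points.device)
print(sample_indices.shape, sample_indices.dtype)  # (2, 256), torch.int32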
@@ -359,15 +360,16 @@ def get_targets_single(self, # Offset to the primitive center point_offset = points.new_zeros([num_points, 3]) # Semantic information of primitive center - point_sem = points.new_zeros([num_points, 3 + self.num_dim + 1]) + point_sem = points.new_zeros([num_points, 3 + self.num_dims + 1]) instance_flag = torch.nonzero( pts_semantic_mask != self.num_classes).squeeze(1) instance_labels = pts_instance_mask[instance_flag].unique() for i, i_instance in enumerate(instance_labels): - ind = instance_flag[pts_instance_mask[instance_flag] == i_instance] - x = points[ind, :3] + indices = instance_flag[pts_instance_mask[instance_flag] == + i_instance] + coords = points[indices, :3] # Bbox Corners cur_corners = gt_bboxes_3d.corners[i] @@ -376,23 +378,27 @@ def get_targets_single(self, plane_lower_temp = points.new_tensor( [0, 0, 1, -cur_corners[7, -1]]) - para_points = cur_corners[[1, 2, 5, 6]] - newd = torch.sum(para_points * plane_lower_temp[:3], 1) - if self.check_upright(para_points) and \ + upper_points = cur_corners[[1, 2, 5, 6]] + refined_distance = torch.sum(upper_points * plane_lower_temp[:3], + 1) + + if self.check_horizon(upper_points) and \ plane_lower_temp[0] + plane_lower_temp[1] < \ self.train_cfg['lower_thresh']: plane_lower = points.new_tensor( [0, 0, 1, plane_lower_temp[-1]]) - plane_upper = points.new_tensor([0, 0, 1, -torch.mean(newd)]) + plane_upper = points.new_tensor( + [0, 0, 1, -torch.mean(refined_distance)]) else: - raise NotImplementedError + raise NotImplementedError('Only horizontal plane is support!') - if self.check_z(plane_upper, para_points) is False: - raise NotImplementedError + if self.check_dist(plane_upper, upper_points) is False: + raise NotImplementedError( + 'Mean distance to plane should be lower than thresh!') # Get the boundary points here point2plane_dist = torch.abs( - torch.sum(x * plane_lower[:3], 1) + plane_lower[-1]) + torch.sum(coords * plane_lower[:3], 1) + plane_lower[-1]) min_dist = point2plane_dist.min() selected = torch.abs(point2plane_dist - min_dist) < self.train_cfg['dist_thresh'] @@ -400,43 +406,44 @@ def get_targets_single(self, # Get lower four lines if self.primitive_mode == 'line': point2_lines_matching = self.match_point2line( - x[selected], xmin, xmax, ymin, ymax) + coords[selected], xmin, xmax, ymin, ymax) for idx, line_select in enumerate(point2_lines_matching): if torch.sum(line_select) > \ self.train_cfg['num_point_line']: - point_mask[ind[selected][line_select]] = 1.0 + point_mask[indices[selected][line_select]] = 1.0 line_center = torch.mean( - x[selected][line_select], axis=0) + coords[selected][line_select], axis=0) if idx < 2: line_center[1] = (ymin + ymax) / 2.0 else: line_center[0] = (xmin + xmax) / 2.0 - point_offset[ind[selected][line_select]] = \ - line_center - x[selected][line_select] - point_sem[ind[selected][line_select]] = \ + point_offset[indices[selected][line_select]] = \ + line_center - coords[selected][line_select] + point_sem[indices[selected][line_select]] = \ points.new_tensor([line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0]]) + pts_semantic_mask[indices][0]]) # Set the surface labels here if self.primitive_mode == 'z': if torch.sum(selected) > self.train_cfg['num_point'] and \ torch.var(point2plane_dist[selected]) < \ self.train_cfg['var_thresh']: - center = points.new_tensor([(xmin + xmax) / 2.0, - (ymin + ymax) / 2.0, - torch.mean(x[selected][:, 2])]) - sel_global = ind[selected] + center = points.new_tensor([ + (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, + 
torch.mean(coords[selected][:, 2]) + ]) + sel_global = indices[selected] point_mask[sel_global] = 1.0 point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], xmax - xmin, - ymax - ymin, (pts_semantic_mask[ind][0]) + ymax - ymin, (pts_semantic_mask[indices][0]) ]) - point_offset[sel_global] = center - x[selected] + point_offset[sel_global] = center - coords[selected] # Get the boundary points here point2plane_dist = torch.abs( - torch.sum(x * plane_upper[:3], 1) + plane_upper[-1]) + torch.sum(coords * plane_upper[:3], 1) + plane_upper[-1]) min_dist = point2plane_dist.min() selected = torch.abs(point2plane_dist - min_dist) < self.train_cfg['dist_thresh'] @@ -444,38 +451,39 @@ def get_targets_single(self, # Get upper four lines if self.primitive_mode == 'line': point2_lines_matching = self.match_point2line( - x[selected], xmin, xmax, ymin, ymax) + coords[selected], xmin, xmax, ymin, ymax) for idx, line_select in enumerate(point2_lines_matching): if torch.sum(line_select) > \ self.train_cfg['num_point_line']: - point_mask[ind[selected][line_select]] = 1.0 + point_mask[indices[selected][line_select]] = 1.0 line_center = torch.mean( - x[selected][line_select], axis=0) + coords[selected][line_select], axis=0) if idx < 2: line_center[1] = (ymin + ymax) / 2.0 else: line_center[0] = (xmin + xmax) / 2.0 - point_offset[ind[selected][line_select]] = \ - line_center - x[selected][line_select] - point_sem[ind[selected][line_select]] = \ + point_offset[indices[selected][line_select]] = \ + line_center - coords[selected][line_select] + point_sem[indices[selected][line_select]] = \ points.new_tensor([line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0]]) + pts_semantic_mask[indices][0]]) if self.primitive_mode == 'z': if torch.sum(selected) > self.train_cfg['num_point'] and \ torch.var(point2plane_dist[selected]) < \ self.train_cfg['var_thresh']: - center = points.new_tensor([(xmin + xmax) / 2.0, - (ymin + ymax) / 2.0, - torch.mean(x[selected][:, 2])]) - sel_global = ind[selected] + center = points.new_tensor([ + (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, + torch.mean(coords[selected][:, 2]) + ]) + sel_global = indices[selected] point_mask[sel_global] = 1.0 point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], xmax - xmin, - ymax - ymin, (pts_semantic_mask[ind][0]) + ymax - ymin, (pts_semantic_mask[indices][0]) ]) - point_offset[sel_global] = center - x[selected] + point_offset[sel_global] = center - coords[selected] # Get left two lines vec1 = cur_corners[2] - cur_corners[3] @@ -486,21 +494,22 @@ def get_targets_single(self, surface_norm[0], surface_norm[1], surface_norm[2], surface_dis ]) - para_points = cur_corners[[4, 5, 7, 6]] + right_points = cur_corners[[4, 5, 7, 6]] plane_left_temp /= torch.norm(plane_left_temp[:3]) - newd = torch.sum(para_points * plane_left_temp[:3], 1) + refined_distance = torch.sum(right_points * plane_left_temp[:3], 1) if plane_left_temp[2] < self.train_cfg['lower_thresh']: plane_left = plane_left_temp plane_right = points.new_tensor([ plane_left_temp[0], plane_left_temp[1], plane_left_temp[2], - -torch.mean(newd) + -torch.mean(refined_distance) ]) else: - raise NotImplementedError + raise NotImplementedError( + 'Normal vector of the plane should be horizontal!') # Get the boundary points here point2plane_dist = torch.abs( - torch.sum(x * plane_left[:3], 1) + plane_left[-1]) + torch.sum(coords * plane_left[:3], 1) + plane_left[-1]) min_dist = point2plane_dist.min() selected = torch.abs(point2plane_dist - 
min_dist) < self.train_cfg['dist_thresh'] @@ -508,76 +517,76 @@ def get_targets_single(self, # Get upper four lines if self.primitive_mode == 'line': _, _, line_sel1, line_sel2 = self.match_point2line( - x[selected], xmin, xmax, ymin, ymax) + coords[selected], xmin, xmax, ymin, ymax) for idx, line_select in enumerate([line_sel1, line_sel2]): if torch.sum(line_select) > \ self.train_cfg['num_point_line']: - point_mask[ind[selected][line_select]] = 1.0 + point_mask[indices[selected][line_select]] = 1.0 line_center = torch.mean( - x[selected][line_select], axis=0) + coords[selected][line_select], axis=0) line_center[2] = (zmin + zmax) / 2.0 - point_offset[ind[selected][line_select]] = \ - line_center - x[selected][line_select] - point_sem[ind[selected][line_select]] = \ + point_offset[indices[selected][line_select]] = \ + line_center - coords[selected][line_select] + point_sem[indices[selected][line_select]] = \ points.new_tensor([line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0]]) + pts_semantic_mask[indices][0]]) if self.primitive_mode == 'xy': if torch.sum(selected) > self.train_cfg['num_point'] and \ torch.var(point2plane_dist[selected]) < \ self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[selected][:, 0]), - torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 + torch.mean(coords[selected][:, 0]), + torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[selected] + sel_global = indices[selected] point_mask[sel_global] = 1.0 point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[ind][0]) + (pts_semantic_mask[indices][0]) ]) - point_offset[sel_global] = center - x[selected] + point_offset[sel_global] = center - coords[selected] # Get the boundary points here point2plane_dist = torch.abs( - torch.sum(x * plane_right[:3], 1) + plane_right[-1]) + torch.sum(coords * plane_right[:3], 1) + plane_right[-1]) min_dist = point2plane_dist.min() selected = torch.abs(point2plane_dist - min_dist) < self.train_cfg['dist_thresh'] if self.primitive_mode == 'line': _, _, line_sel1, line_sel2 = self.match_point2line( - x[selected], xmin, xmax, ymin, ymax) + coords[selected], xmin, xmax, ymin, ymax) for idx, line_select in enumerate([line_sel1, line_sel2]): if torch.sum(line_select) > \ self.train_cfg['num_point_line']: - point_mask[ind[selected][line_select]] = 1.0 + point_mask[indices[selected][line_select]] = 1.0 line_center = torch.mean( - x[selected][line_select], axis=0) + coords[selected][line_select], axis=0) line_center[2] = (zmin + zmax) / 2.0 - point_offset[ind[selected][line_select]] = \ - line_center - x[selected][line_select] - point_sem[ind[selected][line_select]] = \ + point_offset[indices[selected][line_select]] = \ + line_center - coords[selected][line_select] + point_sem[indices[selected][line_select]] = \ points.new_tensor([line_center[0], line_center[1], line_center[2], - pts_semantic_mask[ind][0]]) + pts_semantic_mask[indices][0]]) if self.primitive_mode == 'xy': if torch.sum(selected) > self.train_cfg['num_point'] and \ torch.var(point2plane_dist[selected]) < \ self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[selected][:, 0]), - torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 + torch.mean(coords[selected][:, 0]), + torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[selected] + sel_global = indices[selected] point_mask[sel_global] = 1.0 point_sem[sel_global] = points.new_tensor([ center[0], center[1], 
center[2], zmax - zmin, - (pts_semantic_mask[ind][0]) + (pts_semantic_mask[indices][0]) ]) - point_offset[sel_global] = center - x[selected] + point_offset[sel_global] = center - coords[selected] # Get the boundary points here vec1 = cur_corners[0] - cur_corners[4] @@ -588,21 +597,22 @@ def get_targets_single(self, surface_norm[0], surface_norm[1], surface_norm[2], surface_dis ]) - para_points = cur_corners[[3, 2, 7, 6]] + back_points = cur_corners[[3, 2, 7, 6]] plane_front_temp /= torch.norm(plane_front_temp[:3]) - newd = torch.sum(para_points * plane_front_temp[:3], 1) + refined_distance = torch.sum(back_points * plane_front_temp[:3], 1) if plane_front_temp[2] < self.train_cfg['lower_thresh']: plane_front = plane_front_temp plane_back = points.new_tensor([ plane_front_temp[0], plane_front_temp[1], - plane_front_temp[2], -torch.mean(newd) + plane_front_temp[2], -torch.mean(refined_distance) ]) else: - raise NotImplementedError + raise NotImplementedError( + 'Normal vector of the plane should be horizontal!') # Get the boundary points here point2plane_dist = torch.abs( - torch.sum(x * plane_front[:3], 1) + plane_front[-1]) + torch.sum(coords * plane_front[:3], 1) + plane_front[-1]) min_dist = point2plane_dist.min() selected = torch.abs(point2plane_dist - min_dist) < self.train_cfg['dist_thresh'] @@ -611,20 +621,20 @@ def get_targets_single(self, torch.var(point2plane_dist[selected]) < \ self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[selected][:, 0]), - torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 + torch.mean(coords[selected][:, 0]), + torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[selected] + sel_global = indices[selected] point_mask[sel_global] = 1.0 point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[ind][0]) + (pts_semantic_mask[indices][0]) ]) - point_offset[sel_global] = center - x[selected] + point_offset[sel_global] = center - coords[selected] # Get the boundary points here point2plane_dist = torch.abs( - torch.sum(x * plane_back[:3], 1) + plane_back[-1]) + torch.sum(coords * plane_back[:3], 1) + plane_back[-1]) min_dist = point2plane_dist.min() selected = torch.abs(point2plane_dist - min_dist) < self.train_cfg['dist_thresh'] @@ -633,66 +643,75 @@ def get_targets_single(self, torch.var(point2plane_dist[selected]) < \ self.train_cfg['var_thresh']: center = points.new_tensor([ - torch.mean(x[selected][:, 0]), - torch.mean(x[selected][:, 1]), (zmin + zmax) / 2.0 + torch.mean(coords[selected][:, 0]), + torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 ]) - sel_global = ind[selected] + sel_global = indices[selected] point_mask[sel_global] = 1.0 point_sem[sel_global] = points.new_tensor([ center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[ind][0]) + (pts_semantic_mask[indices][0]) ]) - point_offset[sel_global] = center - x[selected] + point_offset[sel_global] = center - coords[selected] return (point_mask, point_sem, point_offset) - def primitive_decode_scores(self, preds, aggregated_points): - """Decode the outputs of primitive module. + def primitive_decode_scores(self, predictions, aggregated_points): + """Decode predicted parts to primitive head. Args: - preds (torch.Tensor): primitive pridictions of each batch. + predictions (torch.Tensor): primitive pridictions of each batch. aggregated_points (torch.Tensor): The aggregated points of vote stage. Returns: - Dict: Targets of center, size and semantic. 
+ Dict: Predictions of primitive head, including center, + semantic size and semantic scores. """ ret_dict = {} - net_transposed = preds.transpose(2, 1) + net_transposed = predictions.transpose(2, 1) center = aggregated_points + net_transposed[:, :, 0:3] ret_dict['center_' + self.primitive_mode] = center if self.primitive_mode in ['z', 'xy']: ret_dict['size_residuals_' + self.primitive_mode] = \ - net_transposed[:, :, 3:3 + self.num_dim] + net_transposed[:, :, 3:3 + self.num_dims] ret_dict['sem_cls_scores_' + self.primitive_mode] = \ - net_transposed[:, :, 3 + self.num_dim:] + net_transposed[:, :, 3 + self.num_dims:] - return center, ret_dict + return ret_dict - def check_upright(self, para_points): - """Check whether is upright corrdinate. + def check_horizon(self, points): + """Check whether is a horizontal plane. Args: - para_points (torch.Tensor): Points of input. + points (torch.Tensor): Points of input. Returns: Bool: Flag of result. """ - return (para_points[0][-1] == para_points[1][-1]) and ( - para_points[1][-1] - == para_points[2][-1]) and (para_points[2][-1] - == para_points[3][-1]) + return (points[0][-1] == points[1][-1]) and \ + (points[1][-1] == points[2][-1]) and \ + (points[2][-1] == points[3][-1]) + + def check_dist(self, plane_equ, points): + """Whether the mean of points to plane distance is lower than thresh. + + Args: + plane_equ (torch.Tensor): Plane to be checked. + points (torch.Tensor): Points to be checked. - def check_z(self, plane_equ, para_points): - return torch.sum(para_points[:, 2] + + Returns: + Tuple: Flag of result. + """ + return torch.sum(points[:, 2] + plane_equ[-1]) / 4.0 < self.train_cfg['lower_thresh'] def match_point2line(self, points, xmin, xmax, ymin, ymax): - """Match points to corresponding edge. + """Match points to corresponding line. Args: points (torch.Tensor): Points of input. @@ -710,10 +729,10 @@ def match_point2line(self, points, xmin, xmax, ymin, ymax): sel4 = torch.abs(points[:, 1] - ymax) < self.train_cfg['line_thresh'] return sel1, sel2, sel3, sel4 - def compute_primitivesem_loss(self, primitive_center, primitive_semantic, - semantic_scores, num_proposal, - gt_primitive_center, gt_primitive_semantic, - gt_sem_cls_label, gt_primitive_mask): + def compute_primitive_loss(self, primitive_center, primitive_semantic, + semantic_scores, num_proposal, + gt_primitive_center, gt_primitive_semantic, + gt_sem_cls_label, gt_primitive_mask): """Compute loss of primitive module. Args: @@ -744,7 +763,7 @@ def compute_primitivesem_loss(self, primitive_center, primitive_semantic, if self.primitive_mode != 'line': size_xyz_reshape = primitive_semantic.view( - batch_size * num_proposal, -1, self.num_dim).contiguous() + batch_size * num_proposal, -1, self.num_dims).contiguous() size_loss = self.semantic_reg_loss( size_xyz_reshape, gt_primitive_semantic, @@ -762,7 +781,7 @@ def compute_primitivesem_loss(self, primitive_center, primitive_semantic, return center_loss, size_loss, sem_cls_loss def get_primitive_center(self, pred_flag, center): - """Generate primitive center from primitive head predictions. + """Generate primitive center from predictions. Args: pred_flag (torch.Tensor): Scores of primitive center. @@ -771,7 +790,7 @@ def get_primitive_center(self, pred_flag, center): Returns: Tuple: Primitive center and the prediction indices. 
""" - ind_normal = self.softmax_normal(pred_flag) + ind_normal = F.softmax(pred_flag, dim=1) pred_indices = (ind_normal[:, 1, :] > self.surface_thresh).detach().float() selected = (ind_normal[:, 1, :] <= From 256a29d920d3b5bd9688f859d85897e6d7c834ee Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 19 Aug 2020 20:38:17 +0800 Subject: [PATCH 07/10] update primitive head unittest --- tests/test_heads.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_heads.py b/tests/test_heads.py index 410725d7c9..637ca5751f 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -466,7 +466,7 @@ def test_primitive_head(): primitive_head_cfg = dict( type='PrimitiveHead', - num_dim=2, + num_dims=2, num_classes=18, primitive_mode='z', vote_moudule_cfg=dict( @@ -564,3 +564,8 @@ def test_primitive_head(): assert losses_dict['center_loss_z'] >= 0 assert losses_dict['size_loss_z'] >= 0 assert losses_dict['sem_loss_z'] >= 0 + + # 'Primitive_mode' should be one of ['z', 'xy', 'line'] + with pytest.raises(AssertionError): + primitive_head_cfg['vote_moudule_cfg']['in_channels'] = 'xyz' + build_head(primitive_head_cfg) From cbf31389079cf6df2a5be7ba53b42be5f4885524 Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Thu, 20 Aug 2020 13:47:26 +0800 Subject: [PATCH 08/10] modify primitive had --- .../roi_heads/mask_heads/primitive_head.py | 488 ++++++++++-------- 1 file changed, 285 insertions(+), 203 deletions(-) diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 9e70759c68..6d04ff9e05 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -370,6 +370,7 @@ def get_targets_single(self, indices = instance_flag[pts_instance_mask[instance_flag] == i_instance] coords = points[indices, :3] + cur_cls_label = pts_semantic_mask[indices][0] # Bbox Corners cur_corners = gt_bboxes_3d.corners[i] @@ -379,8 +380,7 @@ def get_targets_single(self, plane_lower_temp = points.new_tensor( [0, 0, 1, -cur_corners[7, -1]]) upper_points = cur_corners[[1, 2, 5, 6]] - refined_distance = torch.sum(upper_points * plane_lower_temp[:3], - 1) + refined_distance = (upper_points * plane_lower_temp[:3]).sum(dim=1) if self.check_horizon(upper_points) and \ plane_lower_temp[0] + plane_lower_temp[1] < \ @@ -397,209 +397,167 @@ def get_targets_single(self, 'Mean distance to plane should be lower than thresh!') # Get the boundary points here - point2plane_dist = torch.abs( - torch.sum(coords * plane_lower[:3], 1) + plane_lower[-1]) - min_dist = point2plane_dist.min() - selected = torch.abs(point2plane_dist - - min_dist) < self.train_cfg['dist_thresh'] + point2plane_dist, selected = self.match_point2plane( + plane_lower, coords) # Get lower four lines if self.primitive_mode == 'line': point2_lines_matching = self.match_point2line( coords[selected], xmin, xmax, ymin, ymax) - for idx, line_select in enumerate(point2_lines_matching): - if torch.sum(line_select) > \ - self.train_cfg['num_point_line']: - point_mask[indices[selected][line_select]] = 1.0 - line_center = torch.mean( - coords[selected][line_select], axis=0) - if idx < 2: - line_center[1] = (ymin + ymax) / 2.0 - else: - line_center[0] = (xmin + xmax) / 2.0 - point_offset[indices[selected][line_select]] = \ - line_center - coords[selected][line_select] - point_sem[indices[selected][line_select]] = \ - points.new_tensor([line_center[0], line_center[1], - line_center[2], - pts_semantic_mask[indices][0]]) + + 
point_mask, point_offset, point_sem = \ + self._assign_primitive_line_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + point2_lines_matching, + cur_corners, + [0, 0, 1, 1]) # Set the surface labels here - if self.primitive_mode == 'z': - if torch.sum(selected) > self.train_cfg['num_point'] and \ - torch.var(point2plane_dist[selected]) < \ - self.train_cfg['var_thresh']: - center = points.new_tensor([ - (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, - torch.mean(coords[selected][:, 2]) - ]) - sel_global = indices[selected] - point_mask[sel_global] = 1.0 - point_sem[sel_global] = points.new_tensor([ - center[0], center[1], center[2], xmax - xmin, - ymax - ymin, (pts_semantic_mask[indices][0]) - ]) - point_offset[sel_global] = center - coords[selected] + if self.primitive_mode == 'z' and \ + selected.sum() > self.train_cfg['num_point'] and \ + point2plane_dist[selected].var() < \ + self.train_cfg['var_thresh']: + + point_mask, point_offset, point_sem = \ + self._assign_primitive_z_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get the boundary points here - point2plane_dist = torch.abs( - torch.sum(coords * plane_upper[:3], 1) + plane_upper[-1]) - min_dist = point2plane_dist.min() - selected = torch.abs(point2plane_dist - - min_dist) < self.train_cfg['dist_thresh'] + point2plane_dist, selected = self.match_point2plane( + plane_upper, coords) # Get upper four lines if self.primitive_mode == 'line': point2_lines_matching = self.match_point2line( coords[selected], xmin, xmax, ymin, ymax) - for idx, line_select in enumerate(point2_lines_matching): - if torch.sum(line_select) > \ - self.train_cfg['num_point_line']: - point_mask[indices[selected][line_select]] = 1.0 - line_center = torch.mean( - coords[selected][line_select], axis=0) - if idx < 2: - line_center[1] = (ymin + ymax) / 2.0 - else: - line_center[0] = (xmin + xmax) / 2.0 - point_offset[indices[selected][line_select]] = \ - line_center - coords[selected][line_select] - point_sem[indices[selected][line_select]] = \ - points.new_tensor([line_center[0], line_center[1], - line_center[2], - pts_semantic_mask[indices][0]]) - - if self.primitive_mode == 'z': - if torch.sum(selected) > self.train_cfg['num_point'] and \ - torch.var(point2plane_dist[selected]) < \ - self.train_cfg['var_thresh']: - center = points.new_tensor([ - (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, - torch.mean(coords[selected][:, 2]) - ]) - sel_global = indices[selected] - point_mask[sel_global] = 1.0 - point_sem[sel_global] = points.new_tensor([ - center[0], center[1], center[2], xmax - xmin, - ymax - ymin, (pts_semantic_mask[indices][0]) - ]) - point_offset[sel_global] = center - coords[selected] + + point_mask, point_offset, point_sem = \ + self._assign_primitive_line_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + point2_lines_matching, + cur_corners, + [1, 1, 0, 0]) + + if self.primitive_mode == 'z' and \ + selected.sum() > self.train_cfg['num_point'] and \ + point2plane_dist[selected].var() < \ + self.train_cfg['var_thresh']: + + point_mask, point_offset, point_sem = \ + self._assign_primitive_z_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get left two lines - vec1 = cur_corners[2] - cur_corners[3] - vec2 = cur_corners[3] - cur_corners[0] - surface_norm = torch.cross(vec1, vec2) - surface_dis = 
-torch.dot(surface_norm, cur_corners[0]) - plane_left_temp = points.new_tensor([ - surface_norm[0], surface_norm[1], surface_norm[2], surface_dis - ]) + plane_left_temp = self._get_plane_fomulation( + cur_corners[2] - cur_corners[3], + cur_corners[3] - cur_corners[0], cur_corners[0]) right_points = cur_corners[[4, 5, 7, 6]] plane_left_temp /= torch.norm(plane_left_temp[:3]) - refined_distance = torch.sum(right_points * plane_left_temp[:3], 1) + refined_distance = (right_points * plane_left_temp[:3]).sum(dim=1) + if plane_left_temp[2] < self.train_cfg['lower_thresh']: plane_left = plane_left_temp plane_right = points.new_tensor([ plane_left_temp[0], plane_left_temp[1], plane_left_temp[2], - -torch.mean(refined_distance) + -refined_distance.mean() ]) else: raise NotImplementedError( 'Normal vector of the plane should be horizontal!') # Get the boundary points here - point2plane_dist = torch.abs( - torch.sum(coords * plane_left[:3], 1) + plane_left[-1]) - min_dist = point2plane_dist.min() - selected = torch.abs(point2plane_dist - - min_dist) < self.train_cfg['dist_thresh'] + point2plane_dist, selected = self.match_point2plane( + plane_left, coords) # Get upper four lines if self.primitive_mode == 'line': _, _, line_sel1, line_sel2 = self.match_point2line( coords[selected], xmin, xmax, ymin, ymax) - for idx, line_select in enumerate([line_sel1, line_sel2]): - if torch.sum(line_select) > \ - self.train_cfg['num_point_line']: - point_mask[indices[selected][line_select]] = 1.0 - line_center = torch.mean( - coords[selected][line_select], axis=0) - line_center[2] = (zmin + zmax) / 2.0 - point_offset[indices[selected][line_select]] = \ - line_center - coords[selected][line_select] - point_sem[indices[selected][line_select]] = \ - points.new_tensor([line_center[0], line_center[1], - line_center[2], - pts_semantic_mask[indices][0]]) - - if self.primitive_mode == 'xy': - if torch.sum(selected) > self.train_cfg['num_point'] and \ - torch.var(point2plane_dist[selected]) < \ - self.train_cfg['var_thresh']: - center = points.new_tensor([ - torch.mean(coords[selected][:, 0]), - torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 - ]) - sel_global = indices[selected] - point_mask[sel_global] = 1.0 - point_sem[sel_global] = points.new_tensor([ - center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[indices][0]) - ]) - point_offset[sel_global] = center - coords[selected] + point_mask, point_offset, point_sem = \ + self._assign_primitive_line_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + [line_sel1, line_sel2], + cur_corners, + [2, 2]) + + if self.primitive_mode == 'xy' and \ + selected.sum() > self.train_cfg['num_point'] and \ + point2plane_dist[selected].var() < \ + self.train_cfg['var_thresh']: + + point_mask, point_offset, point_sem = \ + self._assign_primitive_xy_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get the boundary points here - point2plane_dist = torch.abs( - torch.sum(coords * plane_right[:3], 1) + plane_right[-1]) - min_dist = point2plane_dist.min() - selected = torch.abs(point2plane_dist - - min_dist) < self.train_cfg['dist_thresh'] + point2plane_dist, selected = self.match_point2plane( + plane_right, coords) if self.primitive_mode == 'line': _, _, line_sel1, line_sel2 = self.match_point2line( coords[selected], xmin, xmax, ymin, ymax) - for idx, line_select in enumerate([line_sel1, line_sel2]): - if torch.sum(line_select) > \ - 
self.train_cfg['num_point_line']: - point_mask[indices[selected][line_select]] = 1.0 - line_center = torch.mean( - coords[selected][line_select], axis=0) - line_center[2] = (zmin + zmax) / 2.0 - point_offset[indices[selected][line_select]] = \ - line_center - coords[selected][line_select] - point_sem[indices[selected][line_select]] = \ - points.new_tensor([line_center[0], line_center[1], - line_center[2], - pts_semantic_mask[indices][0]]) - - if self.primitive_mode == 'xy': - if torch.sum(selected) > self.train_cfg['num_point'] and \ - torch.var(point2plane_dist[selected]) < \ - self.train_cfg['var_thresh']: - center = points.new_tensor([ - torch.mean(coords[selected][:, 0]), - torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 - ]) - sel_global = indices[selected] - point_mask[sel_global] = 1.0 - point_sem[sel_global] = points.new_tensor([ - center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[indices][0]) - ]) - point_offset[sel_global] = center - coords[selected] - # Get the boundary points here - vec1 = cur_corners[0] - cur_corners[4] - vec2 = cur_corners[4] - cur_corners[5] - surface_norm = torch.cross(vec1, vec2) - surface_dis = -torch.dot(surface_norm, cur_corners[5]) - plane_front_temp = points.new_tensor([ - surface_norm[0], surface_norm[1], surface_norm[2], surface_dis - ]) + point_mask, point_offset, point_sem = \ + self._assign_primitive_line_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + [line_sel1, line_sel2], + cur_corners, + [2, 2]) + + if self.primitive_mode == 'xy' and \ + selected.sum() > self.train_cfg['num_point'] and \ + point2plane_dist[selected].var() < \ + self.train_cfg['var_thresh']: + + point_mask, point_offset, point_sem = \ + self._assign_primitive_xy_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) + + plane_front_temp = self._get_plane_fomulation( + cur_corners[0] - cur_corners[4], + cur_corners[4] - cur_corners[5], cur_corners[5]) back_points = cur_corners[[3, 2, 7, 6]] plane_front_temp /= torch.norm(plane_front_temp[:3]) - refined_distance = torch.sum(back_points * plane_front_temp[:3], 1) + refined_distance = (back_points * plane_front_temp[:3]).sum(dim=1) + if plane_front_temp[2] < self.train_cfg['lower_thresh']: plane_front = plane_front_temp plane_back = points.new_tensor([ @@ -611,48 +569,40 @@ def get_targets_single(self, 'Normal vector of the plane should be horizontal!') # Get the boundary points here - point2plane_dist = torch.abs( - torch.sum(coords * plane_front[:3], 1) + plane_front[-1]) - min_dist = point2plane_dist.min() - selected = torch.abs(point2plane_dist - - min_dist) < self.train_cfg['dist_thresh'] - if self.primitive_mode == 'xy': - if torch.sum(selected) > self.train_cfg['num_point'] and \ - torch.var(point2plane_dist[selected]) < \ - self.train_cfg['var_thresh']: - center = points.new_tensor([ - torch.mean(coords[selected][:, 0]), - torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 - ]) - sel_global = indices[selected] - point_mask[sel_global] = 1.0 - point_sem[sel_global] = points.new_tensor([ - center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[indices][0]) - ]) - point_offset[sel_global] = center - coords[selected] + point2plane_dist, selected = self.match_point2plane( + plane_front, coords) + + if self.primitive_mode == 'xy' and \ + selected.sum() > self.train_cfg['num_point'] and \ + (point2plane_dist[selected]).var() < \ + self.train_cfg['var_thresh']: + 
+ point_mask, point_offset, point_sem = \ + self._assign_primitive_xy_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get the boundary points here - point2plane_dist = torch.abs( - torch.sum(coords * plane_back[:3], 1) + plane_back[-1]) - min_dist = point2plane_dist.min() - selected = torch.abs(point2plane_dist - - min_dist) < self.train_cfg['dist_thresh'] - if self.primitive_mode == 'xy': - if torch.sum(selected) > self.train_cfg['num_point'] and \ - torch.var(point2plane_dist[selected]) < \ - self.train_cfg['var_thresh']: - center = points.new_tensor([ - torch.mean(coords[selected][:, 0]), - torch.mean(coords[selected][:, 1]), (zmin + zmax) / 2.0 - ]) - sel_global = indices[selected] - point_mask[sel_global] = 1.0 - point_sem[sel_global] = points.new_tensor([ - center[0], center[1], center[2], zmax - zmin, - (pts_semantic_mask[indices][0]) - ]) - point_offset[sel_global] = center - coords[selected] + point2plane_dist, selected = self.match_point2plane( + plane_back, coords) + + if self.primitive_mode == 'xy' and \ + selected.sum() > self.train_cfg['num_point'] and \ + point2plane_dist[selected].var() < \ + self.train_cfg['var_thresh']: + + point_mask, point_offset, point_sem = \ + self._assign_primitive_xy_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) return (point_mask, point_sem, point_offset) @@ -707,8 +657,8 @@ def check_dist(self, plane_equ, points): Returns: Tuple: Flag of result. """ - return torch.sum(points[:, 2] + - plane_equ[-1]) / 4.0 < self.train_cfg['lower_thresh'] + return (points[:, 2] + + plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh'] def match_point2line(self, points, xmin, xmax, ymin, ymax): """Match points to corresponding line. @@ -729,6 +679,24 @@ def match_point2line(self, points, xmin, xmax, ymin, ymax): sel4 = torch.abs(points[:, 1] - ymax) < self.train_cfg['line_thresh'] return sel1, sel2, sel3, sel4 + def match_point2plane(self, plane, points): + """Match points to plane. + + Args: + plane (torch.Tensor): Equation of the plane. + points (torch.Tensor): Points of input. + + Returns: + Tuple: Distance of each point to the plane and + flag of matching correspondence. + """ + point2plane_dist = torch.abs((points * plane[:3]).sum(dim=1) + + plane[-1]) + min_dist = point2plane_dist.min() + selected = torch.abs(point2plane_dist - + min_dist) < self.train_cfg['dist_thresh'] + return point2plane_dist, selected + def compute_primitive_loss(self, primitive_center, primitive_semantic, semantic_scores, num_proposal, gt_primitive_center, gt_primitive_semantic, @@ -798,3 +766,117 @@ def get_primitive_center(self, pred_flag, center): offset = torch.ones_like(center) * self.upper_thresh center = center + offset * selected.unsqueeze(-1) return center, pred_indices + + def _assign_primitive_line_targets(self, point_mask, point_offset, + point_sem, coords, indices, cls_label, + point2_lines_matching, corners, + center_axises): + """Generate targets of line primitive. + + Args: + point_mask (torch.Tensor): Tensor to store the ground + truth of mask. + point_offset (torch.Tensor): Tensor to store the ground + truth of offset. + point_sem (torch.Tensor): Tensor to store the ground + truth of semantic. + coords (torch.Tensor): The selected points. + indices (torch.Tensor): Indices of the selected points. + cls_label (int): Class label of the ground truth bounding box. 
+ point2_lines_matching (torch.Tensor): Flag indicate that + matching line of each point. + corners (torch.Tensor): Corners of the ground truth bounding box. + center_axises (list[int]): Indicate in which axis the line center + should be refined. + + Returns: + Tuple: Targets of the line primitive. + """ + for line_select, center_axis in zip(point2_lines_matching, + center_axises): + if line_select.sum() > self.train_cfg['num_point_line']: + point_mask[indices[line_select]] = 1.0 + line_center = coords[line_select].mean(dim=0) + line_center[center_axis] = corners[:, center_axis].mean() + point_offset[indices[line_select]] = \ + line_center - coords[line_select] + point_sem[indices[line_select]] = \ + point_sem.new_tensor([line_center[0], line_center[1], + line_center[2], cls_label]) + return point_mask, point_offset, point_sem + + def _assign_primitive_z_targets(self, point_mask, point_offset, point_sem, + coords, indices, cls_label, corners): + """Generate targets of center primitive. + + Args: + point_mask (torch.Tensor): Tensor to store the ground + truth of mask. + point_offset (torch.Tensor): Tensor to store the ground + truth of offset. + point_sem (torch.Tensor): Tensor to store the ground + truth of semantic. + coords (torch.Tensor): The selected points. + indices (torch.Tensor): Indices of the selected points. + cls_label (int): Class label of the ground truth bounding box. + corners (torch.Tensor): Corners of the ground truth bounding box. + + Returns: + Tuple: Targets of the center primitive. + """ + center = point_mask.new_tensor( + [corners[:, 0].mean(), corners[:, 1].mean(), coords[:, 2].mean()]) + point_mask[indices] = 1.0 + point_sem[indices] = point_sem.new_tensor([ + center[0], center[1], center[2], + corners[:, 0].max() - corners[:, 0].min(), + corners[:, 1].max() - corners[:, 1].min(), cls_label + ]) + point_offset[indices] = center - coords + return point_mask, point_offset, point_sem + + def _assign_primitive_xy_targets(self, point_mask, point_offset, point_sem, + coords, indices, cls_label, corners): + """Generate targets of surface primitive. + + Args: + point_mask (torch.Tensor): Tensor to store the ground + truth of mask. + point_offset (torch.Tensor): Tensor to store the ground + truth of offset. + point_sem (torch.Tensor): Tensor to store the ground + truth of semantic. + coords (torch.Tensor): The selected points. + indices (torch.Tensor): Indices of the selected points. + cls_label (int): Class label of the ground truth bounding box. + corners (torch.Tensor): Corners of the ground truth bounding box. + + Returns: + Tuple: Targets of the surface primitive. + """ + center = point_mask.new_tensor( + [coords[:, 0].mean(), coords[:, 1].mean(), corners[:, 2].mean()]) + point_mask[indices] = 1.0 + point_sem[indices] = point_sem.new_tensor([ + center[0], center[1], center[2], + corners[:, 2].max() - corners[:, 2].min(), cls_label + ]) + point_offset[indices] = center - coords + return point_mask, point_offset, point_sem + + def _get_plane_fomulation(self, vector1, vector2, point): + """Compute the equation of the plane. + + Args: + vector1 (torch.Tensor): Parallel vector of the plane. + vector2 (torch.Tensor): Parallel vector of the plane. + point (torch.Tensor): Point on the plane. + + Returns: + torch.Tensor: Equation of the plane. 
+ """ + surface_norm = torch.cross(vector1, vector2) + surface_dis = -torch.dot(surface_norm, point) + plane = point.new_tensor( + [surface_norm[0], surface_norm[1], surface_norm[2], surface_dis]) + return plane From edf2959bf98cf31304c7a11a20e3ab1581d57fbe Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Mon, 24 Aug 2020 10:39:45 +0800 Subject: [PATCH 09/10] fix bugs for primitive head --- mmdet3d/models/roi_heads/mask_heads/primitive_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 6d04ff9e05..7f6ab68b5d 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -414,7 +414,7 @@ def get_targets_single(self, cur_cls_label, point2_lines_matching, cur_corners, - [0, 0, 1, 1]) + [1, 1, 0, 0]) # Set the surface labels here if self.primitive_mode == 'z' and \ From e5dc5824dc7e3d88dc896fc9a208755c0819aecc Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Tue, 25 Aug 2020 20:33:53 +0800 Subject: [PATCH 10/10] update primitive head --- .../roi_heads/mask_heads/primitive_head.py | 183 ++++++++---------- 1 file changed, 78 insertions(+), 105 deletions(-) diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 7f6ab68b5d..d15a845bb5 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -159,7 +159,7 @@ def forward(self, feats_dict, sample_mod): dtype=torch.int32, device=seed_points.device) else: - raise NotImplementedError + raise NotImplementedError('Unsupported sample mod!') vote_aggregation_ret = self.vote_aggregation(vote_points, vote_features, @@ -279,20 +279,12 @@ def get_targets(self, Returns: tuple[torch.Tensor]: Targets of primitive head. 
""" - valid_gt_masks = list() - gt_num = list() for index in range(len(gt_labels_3d)): if len(gt_labels_3d[index]) == 0: fake_box = gt_bboxes_3d[index].tensor.new_zeros( 1, gt_bboxes_3d[index].tensor.shape[-1]) gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) - valid_gt_masks.append(gt_labels_3d[index].new_zeros(1)) - gt_num.append(1) - else: - valid_gt_masks.append(gt_labels_3d[index].new_ones( - gt_labels_3d[index].shape)) - gt_num.append(gt_labels_3d[index].shape[0]) if pts_semantic_mask is None: pts_semantic_mask = [None for i in range(len(gt_labels_3d))] @@ -402,7 +394,7 @@ def get_targets_single(self, # Get lower four lines if self.primitive_mode == 'line': - point2_lines_matching = self.match_point2line( + point2line_matching = self.match_point2line( coords[selected], xmin, xmax, ymin, ymax) point_mask, point_offset, point_sem = \ @@ -412,7 +404,7 @@ def get_targets_single(self, coords[selected], indices[selected], cur_cls_label, - point2_lines_matching, + point2line_matching, cur_corners, [1, 1, 0, 0]) @@ -423,13 +415,13 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_z_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get the boundary points here point2plane_dist, selected = self.match_point2plane( @@ -437,7 +429,7 @@ def get_targets_single(self, # Get upper four lines if self.primitive_mode == 'line': - point2_lines_matching = self.match_point2line( + point2line_matching = self.match_point2line( coords[selected], xmin, xmax, ymin, ymax) point_mask, point_offset, point_sem = \ @@ -447,7 +439,7 @@ def get_targets_single(self, coords[selected], indices[selected], cur_cls_label, - point2_lines_matching, + point2line_matching, cur_corners, [1, 1, 0, 0]) @@ -457,13 +449,13 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_z_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get left two lines plane_left_temp = self._get_plane_fomulation( @@ -509,13 +501,13 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_xy_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get the boundary points here point2plane_dist, selected = self.match_point2plane( @@ -542,13 +534,13 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_xy_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) plane_front_temp = self._get_plane_fomulation( cur_corners[0] - 
cur_corners[4], @@ -578,13 +570,13 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_xy_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) # Get the boundary points here point2plane_dist, selected = self.match_point2plane( @@ -596,13 +588,13 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_xy_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets(point_mask, + point_offset, + point_sem, + coords[selected], + indices[selected], + cur_cls_label, + cur_corners) return (point_mask, point_sem, point_offset) @@ -620,17 +612,17 @@ def primitive_decode_scores(self, predictions, aggregated_points): """ ret_dict = {} - net_transposed = predictions.transpose(2, 1) + pred_transposed = predictions.transpose(2, 1) - center = aggregated_points + net_transposed[:, :, 0:3] + center = aggregated_points + pred_transposed[:, :, 0:3] ret_dict['center_' + self.primitive_mode] = center if self.primitive_mode in ['z', 'xy']: ret_dict['size_residuals_' + self.primitive_mode] = \ - net_transposed[:, :, 3:3 + self.num_dims] + pred_transposed[:, :, 3:3 + self.num_dims] ret_dict['sem_cls_scores_' + self.primitive_mode] = \ - net_transposed[:, :, 3 + self.num_dims:] + pred_transposed[:, :, 3 + self.num_dims:] return ret_dict @@ -738,13 +730,11 @@ def compute_primitive_loss(self, primitive_center, primitive_semantic, dst_weight=gt_primitive_mask.view(batch_size * num_proposal, 1))[1] else: - size_loss = torch.tensor(0).float().to(center_loss.device) + size_loss = center_loss.new_tensor(0.0) # Semantic cls loss sem_cls_loss = self.semantic_cls_loss( - semantic_scores, - gt_sem_cls_label, - weight=gt_primitive_mask.float()) + semantic_scores, gt_sem_cls_label, weight=gt_primitive_mask) return center_loss, size_loss, sem_cls_loss @@ -769,7 +759,7 @@ def get_primitive_center(self, pred_flag, center): def _assign_primitive_line_targets(self, point_mask, point_offset, point_sem, coords, indices, cls_label, - point2_lines_matching, corners, + point2line_matching, corners, center_axises): """Generate targets of line primitive. @@ -783,7 +773,7 @@ def _assign_primitive_line_targets(self, point_mask, point_offset, coords (torch.Tensor): The selected points. indices (torch.Tensor): Indices of the selected points. cls_label (int): Class label of the ground truth bounding box. - point2_lines_matching (torch.Tensor): Flag indicate that + point2line_matching (torch.Tensor): Flag indicate that matching line of each point. corners (torch.Tensor): Corners of the ground truth bounding box. center_axises (list[int]): Indicate in which axis the line center @@ -792,7 +782,7 @@ def _assign_primitive_line_targets(self, point_mask, point_offset, Returns: Tuple: Targets of the line primitive. 
""" - for line_select, center_axis in zip(point2_lines_matching, + for line_select, center_axis in zip(point2line_matching, center_axises): if line_select.sum() > self.train_cfg['num_point_line']: point_mask[indices[line_select]] = 1.0 @@ -805,9 +795,10 @@ def _assign_primitive_line_targets(self, point_mask, point_offset, line_center[2], cls_label]) return point_mask, point_offset, point_sem - def _assign_primitive_z_targets(self, point_mask, point_offset, point_sem, - coords, indices, cls_label, corners): - """Generate targets of center primitive. + def _assign_primitive_surface_targets(self, point_mask, point_offset, + point_sem, coords, indices, + cls_label, corners): + """Generate targets for primitive z and primitive xy. Args: point_mask (torch.Tensor): Tensor to store the ground @@ -824,43 +815,25 @@ def _assign_primitive_z_targets(self, point_mask, point_offset, point_sem, Returns: Tuple: Targets of the center primitive. """ - center = point_mask.new_tensor( - [corners[:, 0].mean(), corners[:, 1].mean(), coords[:, 2].mean()]) - point_mask[indices] = 1.0 - point_sem[indices] = point_sem.new_tensor([ - center[0], center[1], center[2], - corners[:, 0].max() - corners[:, 0].min(), - corners[:, 1].max() - corners[:, 1].min(), cls_label - ]) - point_offset[indices] = center - coords - return point_mask, point_offset, point_sem - - def _assign_primitive_xy_targets(self, point_mask, point_offset, point_sem, - coords, indices, cls_label, corners): - """Generate targets of surface primitive. - - Args: - point_mask (torch.Tensor): Tensor to store the ground - truth of mask. - point_offset (torch.Tensor): Tensor to store the ground - truth of offset. - point_sem (torch.Tensor): Tensor to store the ground - truth of semantic. - coords (torch.Tensor): The selected points. - indices (torch.Tensor): Indices of the selected points. - cls_label (int): Class label of the ground truth bounding box. - corners (torch.Tensor): Corners of the ground truth bounding box. - - Returns: - Tuple: Targets of the surface primitive. - """ - center = point_mask.new_tensor( - [coords[:, 0].mean(), coords[:, 1].mean(), corners[:, 2].mean()]) point_mask[indices] = 1.0 - point_sem[indices] = point_sem.new_tensor([ - center[0], center[1], center[2], - corners[:, 2].max() - corners[:, 2].min(), cls_label - ]) + if self.primitive_mode == 'z': + center = point_mask.new_tensor([ + corners[:, 0].mean(), corners[:, 1].mean(), coords[:, + 2].mean() + ]) + point_sem[indices] = point_sem.new_tensor([ + center[0], center[1], center[2], + corners[:, 0].max() - corners[:, 0].min(), + corners[:, 1].max() - corners[:, 1].min(), cls_label + ]) + elif self.primitive_mode == 'xy': + center = point_mask.new_tensor([ + coords[:, 0].mean(), coords[:, 1].mean(), corners[:, 2].mean() + ]) + point_sem[indices] = point_sem.new_tensor([ + center[0], center[1], center[2], + corners[:, 2].max() - corners[:, 2].min(), cls_label + ]) point_offset[indices] = center - coords return point_mask, point_offset, point_sem