-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathh3d_bbox_head.py
925 lines (807 loc) · 39.1 KB
/
h3d_bbox_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
@HEADS.register_module()
class H3DBboxHead(BaseModule):
    r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.

    The head computes the 6 surface centers and 12 line centers of every
    RPN proposal, matches them against the surface / line primitives found
    in ``feats_dict``, aggregates the matched primitive features and
    refines the proposals. Refined predictions are returned under keys
    carrying an ``'_optimized'`` suffix.

    Args:
        num_classes (int): The number of classes.
        suface_matching_cfg (dict): Config for surface primitive matching,
            built with ``build_sa_module``. (The misspelled name is part of
            the public interface and is kept for config compatibility.)
        line_matching_cfg (dict): Config for line primitive matching,
            built with ``build_sa_module``. Its last ``mlp_channels`` entry
            must equal the last entry of ``suface_matching_cfg``.
        bbox_coder (dict): Config of the bbox coder used for encoding and
            decoding boxes. Must contain ``with_rot``, ``num_dir_bins`` and
            ``num_sizes``.
        train_cfg (dict, optional): Config for training. Keys read here:
            ``near_threshold``, ``far_threshold``,
            ``label_surface_threshold``, ``mask_surface_threshold``,
            ``label_line_threshold`` and ``mask_line_threshold``.
        test_cfg (dict, optional): Config for testing. Attributes read
            here: ``nms_thr``, ``score_thr`` and ``per_class_proposal``.
        gt_per_seed (int): Number of ground truth votes generated
            from each seed point. Stored only. Default: 1.
        num_proposal (int): Number of proposal votes generated.
            Stored only. Default: 256.
        feat_channels (tuple[int]): Accepted for config compatibility;
            not used by this class. Default: (128, 128).
        primitive_feat_refine_streams (int): Number of 1x1 conv layers
            refining the matched surface features and, separately, the
            matched line features. Default: 2.
        primitive_refine_channels (Sequence[int]): Output channels of the
            1x1 conv layers applied to the concatenated primitive features
            before the final prediction conv. The first entry must equal
            the channel count of ``aggregated_features`` because the two
            are summed in ``forward``. Default: (128, 128, 128).
        upper_thresh (float): Stored as ``self.upper_thresh``; not used
            inside this class. Default: 100.0.
        surface_thresh (float): Stored as ``self.surface_thresh``; not
            used inside this class. Default: 0.5.
        line_thresh (float): Stored as ``self.line_thresh``; not used
            inside this class. Default: 0.5.
        conv_cfg (dict): Config of convolution in prediction layer.
        norm_cfg (dict): Config of BN in prediction layer.
        objectness_loss (dict): Config of objectness loss.
        center_loss (dict): Config of center loss.
        dir_class_loss (dict): Config of direction classification loss.
        dir_res_loss (dict): Config of direction residual regression loss.
        size_class_loss (dict): Config of size classification loss.
        size_res_loss (dict): Config of size residual regression loss.
        semantic_loss (dict): Config of point-wise semantic segmentation loss.
        cues_objectness_loss (dict): Config of cues objectness loss.
        cues_semantic_loss (dict): Config of cues semantic loss.
        proposal_objectness_loss (dict): Config of proposal objectness
            loss.
        primitive_center_loss (dict): Config of primitive center regression
            loss.
        init_cfg (dict, optional): Initialization config passed to
            ``BaseModule``. Default: None.
    """

    def __init__(self,
                 num_classes,
                 suface_matching_cfg,
                 line_matching_cfg,
                 bbox_coder,
                 train_cfg=None,
                 test_cfg=None,
                 gt_per_seed=1,
                 num_proposal=256,
                 feat_channels=(128, 128),
                 primitive_feat_refine_streams=2,
                 primitive_refine_channels=(128, 128, 128),
                 upper_thresh=100.0,
                 surface_thresh=0.5,
                 line_thresh=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 objectness_loss=None,
                 center_loss=None,
                 dir_class_loss=None,
                 dir_res_loss=None,
                 size_class_loss=None,
                 size_res_loss=None,
                 semantic_loss=None,
                 cues_objectness_loss=None,
                 cues_semantic_loss=None,
                 proposal_objectness_loss=None,
                 primitive_center_loss=None,
                 init_cfg=None):
        super(H3DBboxHead, self).__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.gt_per_seed = gt_per_seed
        self.num_proposal = num_proposal
        self.with_angle = bbox_coder['with_rot']
        self.upper_thresh = upper_thresh
        self.surface_thresh = surface_thresh
        self.line_thresh = line_thresh

        self.objectness_loss = build_loss(objectness_loss)
        self.center_loss = build_loss(center_loss)
        self.dir_class_loss = build_loss(dir_class_loss)
        self.dir_res_loss = build_loss(dir_res_loss)
        self.size_class_loss = build_loss(size_class_loss)
        self.size_res_loss = build_loss(size_res_loss)
        self.semantic_loss = build_loss(semantic_loss)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.num_sizes = self.bbox_coder.num_sizes
        self.num_dir_bins = self.bbox_coder.num_dir_bins

        self.cues_objectness_loss = build_loss(cues_objectness_loss)
        self.cues_semantic_loss = build_loss(cues_semantic_loss)
        self.proposal_objectness_loss = build_loss(proposal_objectness_loss)
        self.primitive_center_loss = build_loss(primitive_center_loss)

        # Surface and line features are concatenated along the point axis
        # in ``forward``, so both matchers must emit the same channel count.
        assert suface_matching_cfg['mlp_channels'][-1] == \
            line_matching_cfg['mlp_channels'][-1]

        def _conv1x1(in_channels, out_channels, inplace=True):
            """Build a 1x1 ConvModule using this head's conv/norm configs."""
            return ConvModule(
                in_channels,
                out_channels,
                1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                bias=True,
                inplace=inplace)

        # surface center matching
        self.surface_center_matcher = build_sa_module(suface_matching_cfg)
        # line center matching
        self.line_center_matcher = build_sa_module(line_matching_cfg)

        # Compute the matching scores
        matching_feat_dims = suface_matching_cfg['mlp_channels'][-1]
        self.matching_conv = _conv1x1(matching_feat_dims, matching_feat_dims)
        self.matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)

        # Compute the semantic matching scores
        self.semantic_matching_conv = _conv1x1(matching_feat_dims,
                                               matching_feat_dims)
        self.semantic_matching_pred = nn.Conv1d(matching_feat_dims, 2, 1)

        # Surface feature aggregation
        self.surface_feats_aggregation = nn.Sequential(*[
            _conv1x1(matching_feat_dims, matching_feat_dims)
            for _ in range(primitive_feat_refine_streams)
        ])

        # Line feature aggregation
        self.line_feats_aggregation = nn.Sequential(*[
            _conv1x1(matching_feat_dims, matching_feat_dims)
            for _ in range(primitive_feat_refine_streams)
        ])

        # surface center(6) + line center(12)
        prev_channel = 18 * matching_feat_dims
        self.bbox_pred = nn.ModuleList()
        for out_channel in primitive_refine_channels:
            self.bbox_pred.append(
                _conv1x1(prev_channel, out_channel, inplace=False))
            prev_channel = out_channel

        # Final object detection
        # Objectness scores (2), center residual (3),
        # heading class+residual (num_heading_bin*2), size class +
        # residual(num_size_cluster*4), semantic scores (num_classes)
        conv_out_channel = (2 + 3 + bbox_coder['num_dir_bins'] * 2 +
                            bbox_coder['num_sizes'] * 4 + self.num_classes)
        self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))

    def _get_batch_surface_line_center(self, bboxes):
        """Compute surface and line centers of batched box tensors.

        Args:
            bboxes (torch.Tensor): Box tensor of shape (B, N, box_dim),
                wrapped as ``DepthInstance3DBoxes`` with origin
                (0.5, 0.5, 0.5).

        Returns:
            tuple[torch.Tensor]: Surface centers of shape (B, 6 * N, 3) and
                line centers of shape (B, 12 * N, 3). Both are ordered
                primitive-major: row ``k * N + n`` is primitive ``k`` of
                box ``n``.
        """
        batch_size = bboxes.shape[0]
        box_dim = bboxes.shape[-1]
        # Flatten with the real box width instead of a hard-coded 7 so the
        # reshape always agrees with ``box_dim``.
        boxes = DepthInstance3DBoxes(
            bboxes.reshape(-1, box_dim).clone(),
            box_dim=box_dim,
            with_yaw=self.with_angle,
            origin=(0.5, 0.5, 0.5))
        surface_center, line_center = boxes.get_surface_line_center()
        surface_center = surface_center.reshape(
            batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3)
        line_center = line_center.reshape(
            batch_size, -1, 12, 3).transpose(1, 2).reshape(batch_size, -1, 3)
        return surface_center, line_center

    def forward(self, feats_dict, sample_mod):
        """Forward pass.

        Args:
            feats_dict (dict): Feature dict from backbone and primitive
                heads. Keys read: ``aggregated_points``,
                ``aggregated_features``, ``pred_z_center``,
                ``pred_xy_center``, ``sem_cls_scores_z``,
                ``sem_cls_scores_xy``, ``aggregated_features_z``,
                ``aggregated_features_xy``, ``pred_line_center``,
                ``aggregated_features_line`` and ``proposal_list``.
            sample_mod (str): Sample mode for vote aggregation layer.
                valid modes are "vote", "seed" and "random". Not used by
                this head; kept for interface compatibility.

        Returns:
            dict: Primitive centers / semantics of the proposals, the
                (semantic) matching scores and the refined predictions,
                whose keys carry an ``'_optimized'`` suffix.
        """
        ret_dict = {}
        aggregated_points = feats_dict['aggregated_points']
        original_feature = feats_dict['aggregated_features']
        batch_size = original_feature.shape[0]
        object_proposal = original_feature.shape[2]

        # Extract surface center, features and semantic predictions
        z_center = feats_dict['pred_z_center']
        xy_center = feats_dict['pred_xy_center']
        z_semantic = feats_dict['sem_cls_scores_z']
        xy_semantic = feats_dict['sem_cls_scores_xy']
        z_feature = feats_dict['aggregated_features_z']
        xy_feature = feats_dict['aggregated_features_xy']
        # Extract line points and features
        line_center = feats_dict['pred_line_center']
        line_feature = feats_dict['aggregated_features_line']

        surface_center_pred = torch.cat((z_center, xy_center), dim=1)
        ret_dict['surface_center_pred'] = surface_center_pred
        ret_dict['surface_sem_pred'] = torch.cat((z_semantic, xy_semantic),
                                                 dim=1)

        # Extract the surface and line centers of rpn proposals
        obj_surface_center, obj_line_center = \
            self._get_batch_surface_line_center(feats_dict['proposal_list'])
        ret_dict['surface_center_object'] = obj_surface_center
        ret_dict['line_center_object'] = obj_line_center

        # aggregate primitive z and xy features to rpn proposals
        surface_center_feature_pred = torch.cat((z_feature, xy_feature), dim=2)
        # Prepend 6 zero channels; presumably this matches the first
        # ``mlp_channels`` entry of the surface matcher config (TODO confirm).
        surface_center_feature_pred = torch.cat(
            (surface_center_feature_pred.new_zeros(
                (batch_size, 6, surface_center_feature_pred.shape[2])),
             surface_center_feature_pred),
            dim=1)

        surface_xyz, surface_features, _ = self.surface_center_matcher(
            surface_center_pred,
            surface_center_feature_pred,
            target_xyz=obj_surface_center)

        # aggregate primitive line features to rpn proposals
        # (12 zero channels prepended, analogous to the surface branch)
        line_feature = torch.cat((line_feature.new_zeros(
            (batch_size, 12, line_feature.shape[2])), line_feature),
                                 dim=1)
        line_xyz, line_features, _ = self.line_center_matcher(
            line_center, line_feature, target_xyz=obj_line_center)

        # combine the surface and line features
        combine_features = torch.cat((surface_features, line_features), dim=2)

        matching_features = self.matching_conv(combine_features)
        matching_score = self.matching_pred(matching_features)
        ret_dict['matching_score'] = matching_score.transpose(2, 1)

        semantic_matching_features = self.semantic_matching_conv(
            combine_features)
        semantic_matching_score = self.semantic_matching_pred(
            semantic_matching_features)
        ret_dict['semantic_matching_score'] = \
            semantic_matching_score.transpose(2, 1)

        surface_features = self.surface_feats_aggregation(surface_features)
        line_features = self.line_feats_aggregation(line_features)

        # Combine all surface and line features
        surface_features = surface_features.view(batch_size, -1,
                                                 object_proposal)
        line_features = line_features.view(batch_size, -1, object_proposal)
        combine_feature = torch.cat((surface_features, line_features), dim=1)

        # Final bbox predictions; the first refine layer's output is summed
        # with the proposal features as a residual connection.
        bbox_predictions = self.bbox_pred[0](combine_feature)
        bbox_predictions += original_feature
        for conv_module in self.bbox_pred[1:]:
            bbox_predictions = conv_module(bbox_predictions)

        refine_decode_res = self.bbox_coder.split_pred(
            bbox_predictions[:, :self.num_classes + 2],
            bbox_predictions[:, self.num_classes + 2:], aggregated_points)
        for key, value in refine_decode_res.items():
            ret_dict[key + '_optimized'] = value
        return ret_dict

    def loss(self,
             bbox_preds,
             points,
             gt_bboxes_3d,
             gt_labels_3d,
             pts_semantic_mask=None,
             pts_instance_mask=None,
             img_metas=None,
             rpn_targets=None,
             gt_bboxes_ignore=None):
        """Compute loss.

        Args:
            bbox_preds (dict): Predictions from forward of h3d bbox head.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (list[torch.Tensor]): Point-wise
                instance mask.
            img_metas (list[dict]): Contain pcd and img's meta info.
                Not used here.
            rpn_targets (tuple): The 14-element target tuple generated by
                the rpn head.
            gt_bboxes_ignore (list[torch.Tensor]): Not used here.

        Returns:
            dict: Losses of H3dnet.
        """
        (_, _, size_class_targets, size_res_targets, dir_class_targets,
         dir_res_targets, center_targets, _, mask_targets, _,
         objectness_targets, objectness_weights, box_loss_weights,
         valid_gt_weights) = rpn_targets

        # calculate refined proposal loss
        refined_proposal_loss = self.get_proposal_stage_loss(
            bbox_preds,
            size_class_targets,
            size_res_targets,
            dir_class_targets,
            dir_res_targets,
            center_targets,
            mask_targets,
            objectness_targets,
            objectness_weights,
            box_loss_weights,
            valid_gt_weights,
            suffix='_optimized')
        losses = {
            key + '_optimized': value
            for key, value in refined_proposal_loss.items()
        }

        bbox3d_optimized = self.bbox_coder.decode(
            bbox_preds, suffix='_optimized')

        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
                                   pts_semantic_mask, pts_instance_mask,
                                   bbox_preds)

        (cues_objectness_label, cues_sem_label, proposal_objectness_label,
         cues_mask, cues_match_mask, proposal_objectness_mask,
         cues_matching_label, obj_surface_line_center) = targets

        # match scores for each geometric primitive
        objectness_scores = bbox_preds['matching_score']
        # match scores for the semantics of primitives
        objectness_scores_sem = bbox_preds['semantic_matching_score']

        primitive_objectness_loss = self.cues_objectness_loss(
            objectness_scores.transpose(2, 1),
            cues_objectness_label,
            weight=cues_mask,
            avg_factor=cues_mask.sum() + 1e-6)

        primitive_sem_loss = self.cues_semantic_loss(
            objectness_scores_sem.transpose(2, 1),
            cues_sem_label,
            weight=cues_mask,
            avg_factor=cues_mask.sum() + 1e-6)

        refined_obj_scores = bbox_preds['obj_scores_optimized']
        objectness_loss_refine = self.proposal_objectness_loss(
            refined_obj_scores.transpose(2, 1), proposal_objectness_label)
        primitive_matching_loss = (objectness_loss_refine *
                                   cues_match_mask).sum() / (
                                       cues_match_mask.sum() + 1e-6) * 0.5
        primitive_sem_matching_loss = (
            objectness_loss_refine * proposal_objectness_mask).sum() / (
                proposal_objectness_mask.sum() + 1e-6) * 0.5

        # Surface / line centers of the refined boxes, in the same
        # primitive-major layout as the targets.
        pred_obj_surface_center, pred_obj_line_center = \
            self._get_batch_surface_line_center(bbox3d_optimized)
        pred_surface_line_center = torch.cat(
            (pred_obj_surface_center, pred_obj_line_center), 1)

        square_dist = self.primitive_center_loss(pred_surface_line_center,
                                                 obj_surface_line_center)

        match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6)
        primitive_centroid_reg_loss = torch.sum(
            match_dist * cues_matching_label) / (
                cues_matching_label.sum() + 1e-6)

        losses.update(
            dict(
                primitive_objectness_loss=primitive_objectness_loss,
                primitive_sem_loss=primitive_sem_loss,
                primitive_matching_loss=primitive_matching_loss,
                primitive_sem_matching_loss=primitive_sem_matching_loss,
                primitive_centroid_reg_loss=primitive_centroid_reg_loss))
        return losses

    def get_bboxes(self,
                   points,
                   bbox_preds,
                   input_metas,
                   rescale=False,
                   suffix=''):
        """Generate bboxes from vote head predictions.

        Args:
            points (torch.Tensor): Input points, indexed per sample; only
                the first three columns are used.
            bbox_preds (dict): Predictions from vote head.
            input_metas (list[dict]): Point cloud and image's meta info.
            rescale (bool): Whether to rescale bboxes. Not used here.
            suffix (str): Key suffix selecting which objectness, center,
                dir_res and size_res predictions to decode (e.g.
                ``'_optimized'``). ``sem_scores``, ``dir_class`` and
                ``size_class`` are always read without suffix.

        Returns:
            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
        """
        # decode boxes
        obj_scores = F.softmax(
            bbox_preds['obj_scores' + suffix], dim=-1)[..., -1]

        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)

        prediction_collection = {
            'center': bbox_preds['center' + suffix],
            'dir_class': bbox_preds['dir_class'],
            'dir_res': bbox_preds['dir_res' + suffix],
            'size_class': bbox_preds['size_class'],
            'size_res': bbox_preds['size_res' + suffix],
        }

        bbox3d = self.bbox_coder.decode(prediction_collection)

        results = list()
        for b in range(bbox3d.shape[0]):
            bbox_selected, score_selected, labels = self.multiclass_nms_single(
                obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
                input_metas[b])
            bbox = input_metas[b]['box_type_3d'](
                bbox_selected,
                box_dim=bbox_selected.shape[-1],
                with_yaw=self.bbox_coder.with_rot)
            results.append((bbox, score_selected, labels))

        return results

    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
        """Multi-class nms in single batch.

        Boxes containing 5 or fewer points are discarded, class-aware
        axis-aligned NMS is applied to the remaining boxes, and boxes whose
        objectness score does not exceed ``test_cfg.score_thr`` are
        dropped.

        Args:
            obj_scores (torch.Tensor): Objectness score of bounding boxes.
            sem_scores (torch.Tensor): semantic class score of bounding boxes.
            bbox (torch.Tensor): Predicted bounding boxes.
            points (torch.Tensor): Input points.
            input_meta (dict): Point cloud and image's meta info.

        Returns:
            tuple[torch.Tensor]: Bounding boxes, scores and labels.
        """
        bbox = input_meta['box_type_3d'](
            bbox,
            box_dim=bbox.shape[-1],
            with_yaw=self.bbox_coder.with_rot,
            origin=(0.5, 0.5, 0.5))
        box_indices = bbox.points_in_boxes_all(points)

        # Axis-aligned min/max extents of every box for aligned 3D NMS.
        corner3d = bbox.corners
        minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6)))
        minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0]
        minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0]

        nonempty_box_mask = box_indices.T.sum(1) > 5

        bbox_classes = torch.argmax(sem_scores, -1)
        nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask],
                                      obj_scores[nonempty_box_mask],
                                      bbox_classes[nonempty_box_mask],
                                      self.test_cfg.nms_thr)

        # filter empty boxes and boxes with low score
        scores_mask = (obj_scores > self.test_cfg.score_thr)
        nonempty_box_inds = torch.nonzero(
            nonempty_box_mask, as_tuple=False).flatten()
        # NMS indices are relative to the non-empty subset; map them back.
        nonempty_mask = torch.zeros_like(bbox_classes).scatter(
            0, nonempty_box_inds[nms_selected], 1)
        selected = (nonempty_mask.bool() & scores_mask.bool())

        if self.test_cfg.per_class_proposal:
            # Emit every selected box once per class, scored by
            # objectness * class probability.
            bbox_selected, score_selected, labels = [], [], []
            for k in range(sem_scores.shape[-1]):
                bbox_selected.append(bbox[selected].tensor)
                score_selected.append(obj_scores[selected] *
                                      sem_scores[selected][:, k])
                labels.append(
                    torch.zeros_like(bbox_classes[selected]).fill_(k))
            bbox_selected = torch.cat(bbox_selected, 0)
            score_selected = torch.cat(score_selected, 0)
            labels = torch.cat(labels, 0)
        else:
            bbox_selected = bbox[selected].tensor
            score_selected = obj_scores[selected]
            labels = bbox_classes[selected]

        return bbox_selected, score_selected, labels

    def get_proposal_stage_loss(self,
                                bbox_preds,
                                size_class_targets,
                                size_res_targets,
                                dir_class_targets,
                                dir_res_targets,
                                center_targets,
                                mask_targets,
                                objectness_targets,
                                objectness_weights,
                                box_loss_weights,
                                valid_gt_weights,
                                suffix=''):
        """Compute loss for the aggregation module.

        Args:
            bbox_preds (dict): Predictions from forward of vote head.
            size_class_targets (torch.Tensor): Ground truth
                size class of each prediction bounding box.
            size_res_targets (torch.Tensor): Ground truth
                size residual of each prediction bounding box.
            dir_class_targets (torch.Tensor): Ground truth
                direction class of each prediction bounding box.
            dir_res_targets (torch.Tensor): Ground truth
                direction residual of each prediction bounding box.
            center_targets (torch.Tensor): Ground truth center
                of each prediction bounding box.
            mask_targets (torch.Tensor): Validation of each
                prediction bounding box.
            objectness_targets (torch.Tensor): Ground truth
                objectness label of each prediction bounding box.
            objectness_weights (torch.Tensor): Weights of objectness
                loss for each prediction bounding box.
            box_loss_weights (torch.Tensor): Weights of regression
                loss for each prediction bounding box.
            valid_gt_weights (torch.Tensor): Validation of each
                ground truth bounding box.
            suffix (str): Suffix appended to the prediction keys read
                from ``bbox_preds``. Default: ''.

        Returns:
            dict: Losses of aggregation module.
        """
        # calculate objectness loss
        objectness_loss = self.objectness_loss(
            bbox_preds['obj_scores' + suffix].transpose(2, 1),
            objectness_targets,
            weight=objectness_weights)

        # calculate center loss
        source2target_loss, target2source_loss = self.center_loss(
            bbox_preds['center' + suffix],
            center_targets,
            src_weight=box_loss_weights,
            dst_weight=valid_gt_weights)
        center_loss = source2target_loss + target2source_loss

        # calculate direction class loss
        dir_class_loss = self.dir_class_loss(
            bbox_preds['dir_class' + suffix].transpose(2, 1),
            dir_class_targets,
            weight=box_loss_weights)

        # calculate direction residual loss: keep only the residual of the
        # ground-truth direction bin via a one-hot mask.
        batch_size, proposal_num = size_class_targets.shape[:2]
        heading_label_one_hot = dir_class_targets.new_zeros(
            (batch_size, proposal_num, self.num_dir_bins))
        heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
        dir_res_norm = (bbox_preds['dir_res_norm' + suffix] *
                        heading_label_one_hot).sum(dim=-1)
        dir_res_loss = self.dir_res_loss(
            dir_res_norm, dir_res_targets, weight=box_loss_weights)

        # calculate size class loss
        size_class_loss = self.size_class_loss(
            bbox_preds['size_class' + suffix].transpose(2, 1),
            size_class_targets,
            weight=box_loss_weights)

        # calculate size residual loss: keep only the 3-d residual of the
        # ground-truth size class via a one-hot mask.
        one_hot_size_targets = box_loss_weights.new_zeros(
            (batch_size, proposal_num, self.num_sizes))
        one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1)
        one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
            -1).repeat(1, 1, 1, 3)
        size_residual_norm = (bbox_preds['size_res_norm' + suffix] *
                              one_hot_size_targets_expand).sum(dim=2)
        box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
            1, 1, 3)
        size_res_loss = self.size_res_loss(
            size_residual_norm,
            size_res_targets,
            weight=box_loss_weights_expand)

        # calculate semantic loss
        semantic_loss = self.semantic_loss(
            bbox_preds['sem_scores' + suffix].transpose(2, 1),
            mask_targets,
            weight=box_loss_weights)

        losses = dict(
            objectness_loss=objectness_loss,
            semantic_loss=semantic_loss,
            center_loss=center_loss,
            dir_class_loss=dir_class_loss,
            dir_res_loss=dir_res_loss,
            size_class_loss=size_class_loss,
            size_res_loss=size_res_loss)

        return losses

    def get_targets(self,
                    points,
                    gt_bboxes_3d,
                    gt_labels_3d,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    bbox_preds=None):
        """Generate targets of proposal module.

        Samples without ground truth are padded with one all-zero box
        labelled 0. The padding is applied to shallow copies, so the
        caller's lists are left unchanged.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
                label of each batch.
            pts_instance_mask (list[torch.Tensor]): Point-wise instance
                label of each batch.
            bbox_preds (dict): Predictions of the head (output of
                ``forward`` merged with primitive predictions).

        Returns:
            tuple[torch.Tensor]: Targets of proposal module, each stacked
                over the batch.
        """
        num_samples = len(gt_labels_3d)
        gt_bboxes_3d = list(gt_bboxes_3d)
        gt_labels_3d = list(gt_labels_3d)

        # pad empty examples with a fake box so per-sample matching works
        for index in range(num_samples):
            if len(gt_labels_3d[index]) == 0:
                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
                    1, gt_bboxes_3d[index].tensor.shape[-1])
                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)

        # Each mask is defaulted on its own, so a provided instance mask is
        # not discarded and a missing one does not break ``multi_apply``.
        if pts_semantic_mask is None:
            pts_semantic_mask = [None] * num_samples
        if pts_instance_mask is None:
            pts_instance_mask = [None] * num_samples

        def _per_sample(key):
            """Split the batched prediction ``bbox_preds[key]``."""
            return [bbox_preds[key][i] for i in range(num_samples)]

        targets = multi_apply(self.get_targets_single, points, gt_bboxes_3d,
                              gt_labels_3d, pts_semantic_mask,
                              pts_instance_mask,
                              _per_sample('aggregated_points'),
                              _per_sample('surface_center_pred'),
                              _per_sample('pred_line_center'),
                              _per_sample('surface_center_object'),
                              _per_sample('line_center_object'),
                              _per_sample('surface_sem_pred'),
                              _per_sample('sem_cls_scores_line'))

        # (cues_objectness_label, cues_sem_label, proposal_objectness_label,
        #  cues_mask, cues_match_mask, proposal_objectness_mask,
        #  cues_matching_label, obj_surface_line_center)
        return tuple(torch.stack(target) for target in targets)

    def get_targets_single(self,
                           points,
                           gt_bboxes_3d,
                           gt_labels_3d,
                           pts_semantic_mask=None,
                           pts_instance_mask=None,
                           aggregated_points=None,
                           pred_surface_center=None,
                           pred_line_center=None,
                           pred_obj_surface_center=None,
                           pred_obj_line_center=None,
                           pred_surface_sem=None,
                           pred_line_sem=None):
        """Generate targets for primitive cues for single batch.

        Args:
            points (torch.Tensor): Points of each batch.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
                boxes of each batch.
            gt_labels_3d (torch.Tensor): Labels of each batch.
            pts_semantic_mask (torch.Tensor): Point-wise semantic
                label of each batch. Not used here.
            pts_instance_mask (torch.Tensor): Point-wise instance
                label of each batch. Not used here.
            aggregated_points (torch.Tensor): Aggregated points from
                vote aggregation layer.
            pred_surface_center (torch.Tensor): Prediction of surface center.
            pred_line_center (torch.Tensor): Prediction of line center.
            pred_obj_surface_center (torch.Tensor): Surface centers of the
                proposal boxes.
            pred_obj_line_center (torch.Tensor): Line centers of the
                proposal boxes.
            pred_surface_sem (torch.Tensor): Semantic prediction of
                surface center.
            pred_line_sem (torch.Tensor): Semantic prediction of line center.

        Returns:
            tuple[torch.Tensor]: ``cues_objectness_label``,
                ``cues_sem_label``, ``proposal_objectness_label``,
                ``cues_mask``, ``cues_match_mask``,
                ``proposal_objectness_mask``, ``cues_matching_label`` and
                ``obj_surface_line_center``. Cue tensors have
                18 * num_proposals entries (6 surfaces then 12 lines,
                primitive-major).
        """
        device = points.device
        gt_bboxes_3d = gt_bboxes_3d.to(device)
        num_proposals = aggregated_points.shape[0]
        gt_center = gt_bboxes_3d.gravity_center

        dist1, _, ind1, _ = chamfer_distance(
            aggregated_points.unsqueeze(0),
            gt_center.unsqueeze(0),
            reduction='none')
        # Set assignment: nearest GT center for every proposal
        object_assignment = ind1.squeeze(0)

        # Generate objectness label and mask
        # objectness_label: 1 if pred object center is within
        # self.train_cfg['near_threshold'] of any GT object
        # objectness_mask: 0 if pred object center is in gray
        # zone (DONOTCARE), 1 otherwise
        euclidean_dist1 = torch.sqrt(dist1.squeeze(0) + 1e-6)
        proposal_objectness_label = euclidean_dist1.new_zeros(
            num_proposals, dtype=torch.long)
        proposal_objectness_mask = euclidean_dist1.new_zeros(num_proposals)

        gt_sem = gt_labels_3d[object_assignment]

        # Surface / line centers of the assigned GT box of each proposal,
        # laid out primitive-major: (1, 6 * P, 3) and (1, 12 * P, 3).
        obj_surface_center, obj_line_center = \
            gt_bboxes_3d.get_surface_line_center()
        obj_surface_center = obj_surface_center.reshape(-1, 6,
                                                        3).transpose(0, 1)
        obj_line_center = obj_line_center.reshape(-1, 12, 3).transpose(0, 1)
        obj_surface_center = obj_surface_center[:, object_assignment].reshape(
            1, -1, 3)
        obj_line_center = obj_line_center[:,
                                          object_assignment].reshape(1, -1, 3)

        surface_sem = torch.argmax(pred_surface_sem, dim=1).float()
        line_sem = torch.argmax(pred_line_sem, dim=1).float()

        # Nearest predicted primitive for every GT primitive center
        dist_surface, _, surface_ind, _ = chamfer_distance(
            obj_surface_center,
            pred_surface_center.unsqueeze(0),
            reduction='none')
        dist_line, _, line_ind, _ = chamfer_distance(
            obj_line_center, pred_line_center.unsqueeze(0), reduction='none')

        surface_sel = pred_surface_center[surface_ind.squeeze(0)]
        line_sel = pred_line_center[line_ind.squeeze(0)]
        surface_sel_sem = surface_sem[surface_ind.squeeze(0)]
        line_sel_sem = line_sem[line_ind.squeeze(0)]

        # repeat() tiles gt_sem so index k * P + p holds gt_sem[p],
        # matching the primitive-major layout above.
        surface_sel_sem_gt = gt_sem.repeat(6).float()
        line_sel_sem_gt = gt_sem.repeat(12).float()

        euclidean_dist_surface = torch.sqrt(dist_surface.squeeze(0) + 1e-6)
        euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6)
        objectness_label_surface = euclidean_dist_line.new_zeros(
            num_proposals * 6, dtype=torch.long)
        objectness_label_line = euclidean_dist_line.new_zeros(
            num_proposals * 12, dtype=torch.long)
        objectness_label_surface_sem = euclidean_dist_line.new_zeros(
            num_proposals * 6, dtype=torch.long)
        objectness_label_line_sem = euclidean_dist_line.new_zeros(
            num_proposals * 12, dtype=torch.long)

        euclidean_dist_obj_surface = torch.sqrt((
            (pred_obj_surface_center - surface_sel)**2).sum(dim=-1) + 1e-6)
        euclidean_dist_obj_line = torch.sqrt(
            torch.sum((pred_obj_line_center - line_sel)**2, dim=-1) + 1e-6)

        # Objectness score just with centers
        near_mask = euclidean_dist1 < self.train_cfg['near_threshold']
        proposal_objectness_label[near_mask] = 1
        proposal_objectness_mask[near_mask] = 1
        proposal_objectness_mask[
            euclidean_dist1 > self.train_cfg['far_threshold']] = 1

        # A primitive cue is positive when the matched predicted primitive
        # is close to both the proposal's primitive and the GT primitive.
        surface_match = (
            (euclidean_dist_obj_surface <
             self.train_cfg['label_surface_threshold']) &
            (euclidean_dist_surface <
             self.train_cfg['mask_surface_threshold']))
        line_match = (
            (euclidean_dist_obj_line < self.train_cfg['label_line_threshold'])
            & (euclidean_dist_line < self.train_cfg['mask_line_threshold']))

        objectness_label_surface[surface_match] = 1
        objectness_label_surface_sem[
            surface_match & (surface_sel_sem == surface_sel_sem_gt)] = 1
        objectness_label_line[line_match] = 1
        objectness_label_line_sem[
            line_match & (line_sel_sem == line_sel_sem_gt)] = 1

        objectness_label_surface_obj = proposal_objectness_label.repeat(6)
        objectness_mask_surface = proposal_objectness_mask.repeat(6)
        objectness_label_line_obj = proposal_objectness_label.repeat(12)
        objectness_mask_line = proposal_objectness_mask.repeat(12)

        cues_objectness_label = torch.cat(
            (objectness_label_surface, objectness_label_line), 0)
        cues_sem_label = torch.cat(
            (objectness_label_surface_sem, objectness_label_line_sem), 0)
        cues_mask = torch.cat((objectness_mask_surface, objectness_mask_line),
                              0)

        # Matching labels additionally require the proposal itself to be
        # positive.
        cues_matching_label = torch.cat(
            (objectness_label_surface * objectness_label_surface_obj,
             objectness_label_line * objectness_label_line_obj), 0)

        # A proposal is matched if any of its 18 primitive cues is positive.
        cues_match_mask = (torch.sum(
            cues_objectness_label.view(18, num_proposals), dim=0) >=
                           1).float()

        obj_surface_line_center = torch.cat(
            (obj_surface_center, obj_line_center), 1).squeeze(0)

        return (cues_objectness_label, cues_sem_label,
                proposal_objectness_label, cues_mask, cues_match_mask,
                proposal_objectness_mask, cues_matching_label,
                obj_surface_line_center)