-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss.py
138 lines (115 loc) · 4.8 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
def gather_normalize_2d3d(fm: torch.Tensor,
                          fm_mask: torch.Tensor,
                          seeds_features: torch.Tensor,
                          ind2d: torch.Tensor,
                          ind3d: torch.Tensor):
    """Gather matched 2D/3D feature pairs and L2-normalise them over channels.

    Args:
        fm: (B, C, H, W) 2D feature map (must be 4-D; any memory layout).
        fm_mask: (B, num_samples) integer indices into the flattened H*W
            positions of ``fm``. Despite the name, it is used as gather
            indices, not as a boolean mask.
        seeds_features: (B, C', N3d) per-point features, gathered along the
            last dim. Presumably C' == C so the pairs can be compared
            downstream (not checked here).
        ind2d: (B, num_match) indices into the ``num_samples`` sampled pixels.
        ind3d: (B, num_match) indices into the ``N3d`` points.

    Returns:
        (f2d, f3d): each (B, C, num_match), L2-normalised along dim 1 with
        ``torch.nn.functional.normalize``.
    """
    B, C, _, _ = fm.size()
    # reshape() instead of view(): view() raises on non-contiguous feature
    # maps (e.g. permuted tensors). It still returns a view when possible.
    fm_lin = fm.reshape(B, C, -1)                 # (B, C, H*W)
    fm_lin_sample = _gather(fm_lin, fm_mask)      # (B, C, num_samples)
    f2d = _gather(fm_lin_sample, ind2d)           # (B, C, num_match)
    f3d = _gather(seeds_features, ind3d)          # (B, C', num_match)
    f2d = torch.nn.functional.normalize(f2d, dim=1)
    f3d = torch.nn.functional.normalize(f3d, dim=1)
    return f2d, f3d
def contrast_in_bacth(f1: torch.Tensor, f2: torch.Tensor, t: float, symmetric=True):
    """InfoNCE loss that uses all other points in the batch as negatives.

    All scenes are pooled together. Column ``n`` of scene ``b`` in one
    modality is contrasted against all ``B*N`` columns of the other modality,
    not only the ``N`` columns of its own scene. Its partner is the same
    (b, n) column.

    Args:
        f1: (B, C, N) features of the first modality.
        f2: (B, C, N) features of the second modality, paired column-wise
            with ``f1``.
        t: temperature; similarity logits are divided by it before the
            cross-entropy.
        symmetric: if True, also evaluate the reverse direction and average
            loss and accuracy over both directions.

    Returns:
        dict with
            "acc": scalar tensor, fraction of queries whose top-scoring
                candidate is the true partner (averaged over both directions
                when ``symmetric``),
            "loss": scalar cross-entropy loss,
            "preds": (1, B*N) argmax candidate index per query, taken from
                the last direction evaluated (the reverse one when
                ``symmetric``).
    """
    B, C, N = f1.size()
    # (B, C, N) -> (C, B, N) -> (1, C, B*N): merge every scene into one pool.
    # Column index becomes b*N + n.
    f1 = f1.transpose(0, 1).reshape(1, C, B * N)
    f2 = f2.transpose(0, 1).reshape(1, C, B * N)
    # The positive for query column j is candidate column j.
    label = torch.arange(B * N, dtype=torch.long, device=f1.device).unsqueeze(0)  # (1, B*N)

    def _score(cand: torch.Tensor, query: torch.Tensor):
        # logits[0, i, j] = <cand_i, query_j>. cross_entropy treats dim 1 as
        # the class dim, so each query column j is classified over all
        # candidate columns i. Note this is a (B*N) x (B*N) matrix.
        logits = torch.bmm(cand.transpose(1, 2), query)  # (1, B*N, B*N)
        loss = torch.nn.functional.cross_entropy(logits / t, label)
        pred = torch.argmax(logits, dim=1)  # (1, B*N)
        acc = (pred == label).float().mean()
        return loss, pred, acc

    loss, pred, acc = _score(f1, f2)
    if symmetric:
        loss2, pred, acc2 = _score(f2, f1)
        # Out-of-place averaging keeps the autograd graph free of in-place ops.
        loss = (loss + loss2) / 2.
        acc = (acc + acc2) / 2.
    return {"acc": acc, "loss": loss, "preds": pred}
def contrast_in_scene(f1: torch.Tensor, f2: torch.Tensor, t: float, symmetric=True):
    """InfoNCE loss that only uses points of the same scene as negatives.

    Column ``n`` of ``f1[b]`` is the positive for column ``n`` of ``f2[b]``.
    Every other column of the same scene ``b`` is a negative. Scenes never
    see each other's points.

    Args:
        f1: (B, C, N) features of the first modality.
        f2: (B, C, N) features of the second modality, paired column-wise
            with ``f1``.
        t: temperature; similarity logits are divided by it before the
            cross-entropy.
        symmetric: if True, also evaluate the reverse direction and average
            loss and accuracy over both directions.

    Returns:
        dict with
            "acc": scalar tensor, fraction of queries whose top-scoring
                candidate is the true partner (averaged over both directions
                when ``symmetric``),
            "loss": scalar cross-entropy loss,
            "preds": (B, N) argmax candidate index per query, taken from the
                last direction evaluated (the reverse one when ``symmetric``).
    """
    B, _, N = f1.size()
    # Target for query column n is candidate n, replicated per scene: (B, N).
    label = torch.arange(N, dtype=torch.long, device=f1.device).unsqueeze(0).repeat(B, 1)

    def _score(cand: torch.Tensor, query: torch.Tensor):
        # logits[b, i, j] = <cand[b, :, i], query[b, :, j]>. cross_entropy
        # treats dim 1 as the class dim, so each query column j is classified
        # over the candidate columns i of its own scene.
        logits = torch.bmm(cand.transpose(1, 2), query)  # (B, N, N)
        loss = torch.nn.functional.cross_entropy(logits / t, label)
        pred = torch.argmax(logits, dim=1)  # (B, N)
        acc = (pred == label).float().mean()
        return loss, pred, acc

    loss, pred, acc = _score(f1, f2)
    if symmetric:
        loss2, pred, acc2 = _score(f2, f1)
        # Out-of-place averaging keeps the autograd graph free of in-place ops.
        loss = (loss + loss2) / 2.
        acc = (acc + acc2) / 2.
    return {"acc": acc, "loss": loss, "preds": pred}
def point_info_nce_loss_2d3d(fm: torch.Tensor,
                             fm_mask: torch.Tensor,
                             seeds_features: torch.Tensor,
                             ind2d: torch.Tensor,
                             ind3d: torch.Tensor,
                             t: float,
                             symmetric=True,
                             in_batch=False):
    """Contrastive (InfoNCE) loss for the depth- and point-based networks.

    Matched pixel/point features are gathered and L2-normalised by
    ``gather_normalize_2d3d``. They are then scored with either the in-batch
    variant (negatives from every scene, when ``in_batch`` is truthy) or the
    in-scene variant (negatives from the same scene only).

    Returns:
        The result dict of the selected contrast function
        (keys "acc", "loss", "preds").
    """
    contrast = contrast_in_bacth if in_batch else contrast_in_scene
    pixel_feats, point_feats = gather_normalize_2d3d(
        fm, fm_mask, seeds_features, ind2d, ind3d)
    return contrast(pixel_feats, point_feats, t, symmetric)
def point_info_nce_loss(f1, f2, t,
                        symmetric=True,
                        in_batch=False):
    """InfoNCE loss on already-paired feature tensors.

    Dispatches to the in-batch variant when ``in_batch`` is truthy, otherwise
    to the in-scene variant. ``f1``, ``f2``, ``t`` and ``symmetric`` are
    forwarded unchanged, and that function's result dict is returned.
    """
    contrast = contrast_in_bacth if in_batch else contrast_in_scene
    return contrast(f1, f2, t, symmetric)
def l1_loss(fm: torch.Tensor,
            fm_mask: torch.Tensor,
            seeds_features: torch.Tensor,
            ind2d: torch.Tensor,
            ind3d: torch.Tensor,
            scale: float = 10.0):
    """Smooth-L1 regression loss between matched 2D and 3D features.

    The features are gathered the same way as in ``gather_normalize_2d3d``
    but are NOT L2-normalised before comparison.

    Args:
        fm: (B, C, H, W) 2D feature map (any memory layout).
        fm_mask: (B, num_samples) integer indices into the flattened H*W
            positions of ``fm`` (used as gather indices).
        seeds_features: (B, C, N3d) per-point features.
        ind2d: (B, num_match) indices into the sampled pixels.
        ind3d: (B, num_match) indices into the points.
        scale: multiplier applied to the mean smooth-L1 loss. Defaults to 10,
            the value previously hard-coded.

    Returns:
        dict with
            "acc": placeholder zeros tensor of shape (1,) on the loss's
                device (no accuracy is defined for a regression loss),
            "loss": ``scale`` * mean smooth-L1 loss.
    """
    B, C, _, _ = fm.size()
    # reshape() instead of view(): view() raises on non-contiguous inputs.
    fm_lin = fm.reshape(B, C, -1)                 # (B, C, H*W)
    fm_lin_sample = _gather(fm_lin, fm_mask)      # (B, C, num_samples)
    f2d = _gather(fm_lin_sample, ind2d)           # (B, C, num_match)
    f3d = _gather(seeds_features, ind3d)          # (B, C, num_match)
    loss = torch.nn.functional.smooth_l1_loss(f2d, f3d, reduction="mean")
    return {
        # Keep the placeholder on the same device as the loss so callers can
        # mix it with other metrics without device errors.
        "acc": torch.zeros(1, device=loss.device),
        "loss": loss * scale,
    }
def _gather(feats:torch.Tensor, ind:torch.Tensor) -> torch.Tensor:
"""expand index gather
Args:
feats (torch.Tensor): (B, C, N)
ind (torch.Tensor): (B, M)
Returns:
torch.Tensor: (B, C, M)
"""
C = feats.size(1)
ind_ext = ind.unsqueeze(1).repeat(1,C,1)
return torch.gather(feats, 2, ind_ext)