-
Notifications
You must be signed in to change notification settings - Fork 8
/
losses.py
103 lines (85 loc) · 3.54 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import math
class CombinedMarginLoss(torch.nn.Module):
    """Margin-based logit head combining ArcFace- and CosFace-style margins.

    Each row of ``logits`` holds cosine similarities (expected in [-1, 1]).
    For every row whose label is not -1, the target-class logit is modified:

    * ``m1 == 1.0 and m3 == 0.0`` -> ArcFace: ``cos(arccos(x) + m2)``
    * ``m3 > 0``                  -> CosFace: ``x - m3``

    After the margin step, all logits are multiplied by ``s``. Any other
    margin combination raises ``NotImplementedError``.

    Args:
        s: scale applied to every logit after the margin.
        m1: only 1.0 is accepted (presumably the multiplicative,
            SphereFace-style margin, which is not implemented here).
        m2: additive angular margin in radians (ArcFace branch).
        m3: additive cosine margin (CosFace branch).
        interclass_filtering_threshold: if > 0, every logit above this value
            is zeroed, except the target logit of labelled rows.
    """

    def __init__(self,
                 s,
                 m1,
                 m2,
                 m3,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # ArcFace constants. forward() below does not read them; they are
        # kept so that existing attribute access keeps working.
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Apply the configured margin to ``logits`` and scale by ``s``.

        Args:
            logits: (B, C) cosine logits. When no filtering is applied, the
                caller's tensor is modified in place.
            labels: target class per row, shape (B,) or (B, 1); -1 marks rows
                that receive no margin.

        Returns:
            The margin-adjusted logits multiplied by ``s``.

        Raises:
            NotImplementedError: if the (m1, m2, m3) combination is unsupported.
        """
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                # scatter_ needs a 2-D index; the column view accepts (B,) and (B, 1) labels.
                mask.scatter_(1, labels[index_positive].view(-1, 1), 0)
                # Never filter out the target logit of a labelled row.
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, labels[index_positive].view(-1)]

        if self.m1 == 1.0 and self.m3 == 0.0:
            # ArcFace. The angular transform runs outside autograd, so the
            # backward pass sees only the final ``* self.s`` on the incoming
            # logits (NOTE(review): presumably intentional).
            # Unlike ArcFace below, there is no fallback when angle + m2 > pi.
            with torch.no_grad():
                # Clamp: normalized dot products can overshoot +/-1 by rounding,
                # and arccos of such values is NaN.
                target_logit.clamp_(-1.0, 1.0).arccos_()
                logits.clamp_(-1.0, 1.0).arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
                logits.cos_()
            logits = logits * self.s
        elif self.m3 > 0:
            # CosFace: subtract the cosine margin from the target logit.
            final_target_logit = target_logit - self.m3
            logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
            logits = logits * self.s
        else:
            # NotImplementedError subclasses RuntimeError, which the former
            # bare ``raise`` produced, so existing handlers still catch it.
            raise NotImplementedError(
                f"unsupported margin combination: m1={self.m1}, m2={self.m2}, m3={self.m3}"
            )

        return logits
class ArcFace(torch.nn.Module):
    """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):

    For every row whose label is not -1, the target cosine ``cos(t)`` is
    replaced by ``cos(t + margin)``, and all logits are multiplied by ``s``.
    """

    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.scale = s
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Target cosines above this value (angle < pi - margin) get the exact
        # cos(t + margin). Below it, cos(t + margin) would stop decreasing in
        # t, so a linear penalty of ``sinmm`` is subtracted instead.
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        # When True, the margin is applied only where the target cosine is > 0.
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Apply the angular margin to the target logits and scale.

        Args:
            logits: (B, C) cosine logits. The target entries are overwritten
                in place.
            labels: target class per row, shape (B,) or (B, 1); -1 marks rows
                that get no margin.

        Returns:
            The margin-adjusted logits multiplied by ``self.scale``.
        """
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        # Clamp: cosines may overshoot +/-1 through float rounding, and sqrt
        # of the resulting negative radicand would be NaN.
        sin_theta = torch.sqrt((1.0 - torch.pow(target_logit, 2)).clamp(min=0.0))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
        if self.easy_margin:
            final_target_logit = torch.where(
                target_logit > 0, cos_theta_m, target_logit)
        else:
            final_target_logit = torch.where(
                target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.scale
        return logits
class CosFace(torch.nn.Module):
    """CosFace large-margin cosine head (https://arxiv.org/abs/1801.09414).

    Subtracts the additive margin ``m`` from the target-class cosine logit of
    every row whose label is not -1, then multiplies all logits by ``s``.
    """

    def __init__(self, s=64.0, m=0.40):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return ``s * logits`` after applying the cosine margin.

        The target entries of ``logits`` are overwritten in place. Rows whose
        label is -1 are only scaled.
        """
        rows = torch.where(labels != -1)[0]
        cols = labels[rows].view(-1)
        margined = logits[rows, cols] - self.m
        logits[rows, cols] = margined
        return self.s * logits