moco.py
import torch
from torch import nn
import torch.nn.functional as F


class AdaMoCo(nn.Module):
    """MoCo-style model with a momentum (key) encoder, a queue of negative
    features, and a temporal memory of per-sample pseudo-labels."""
    def __init__(self, src_model, momentum_model, features_length, num_classes, dataset_length, temporal_length):
        super().__init__()
        self.m = 0.999  # EMA momentum for the key-encoder update
        self.first_update = True
        self.src_model = src_model
        self.momentum_model = momentum_model
        self.momentum_model.requires_grad_(False)  # keys are never backpropagated
        self.queue_ptr = 0
        self.mem_ptr = 0
        self.T_moco = 0.07  # softmax temperature for the instance logits

        # Queue length: at most 16384 negatives, capped by the dataset size.
        self.K = min(16384, dataset_length)
        self.memory_length = temporal_length

        # Feature queue (one column per key) plus the pseudo-labels and
        # dataset indices that accompany each queued key.
        self.register_buffer("features", torch.randn(features_length, self.K))
        self.register_buffer("labels", torch.randint(0, num_classes, (self.K,)))
        self.register_buffer("idxs", torch.randint(0, dataset_length, (self.K,)))
        # Temporal memory: the last `temporal_length` pseudo-labels per sample.
        self.register_buffer(
            "mem_labels", torch.randint(0, num_classes, (dataset_length, self.memory_length))
        )
        self.register_buffer("real_labels", torch.randint(0, num_classes, (dataset_length,)))

        self.features = F.normalize(self.features, dim=0)
        self.features = self.features.cuda()
        self.labels = self.labels.cuda()
        self.mem_labels = self.mem_labels.cuda()
        self.real_labels = self.real_labels.cuda()
        self.idxs = self.idxs.cuda()

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        # encoder_q -> encoder_k
        for param_q, param_k in zip(
            self.src_model.parameters(), self.momentum_model.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def update_memory(self, epoch, idxs, keys, pseudo_labels, real_label):
        # Enqueue the new keys, overwriting the oldest entries (circular queue).
        start = self.queue_ptr
        end = start + len(keys)
        idxs_replace = torch.arange(start, end).cuda() % self.K
        self.features[:, idxs_replace] = keys.T
        self.labels[idxs_replace] = pseudo_labels
        self.idxs[idxs_replace] = idxs
        self.real_labels[idxs_replace] = real_label
        self.queue_ptr = end % self.K

        # Record this batch's pseudo-labels in the temporal memory, then point
        # the slot at the current epoch (slots cycle every `memory_length` epochs).
        self.mem_labels[idxs, self.mem_ptr] = pseudo_labels
        self.mem_ptr = epoch % self.memory_length

    @torch.no_grad()
    def get_memory(self):
        return self.features, self.labels

    def forward(self, im_q, im_k=None, cls_only=False):
        # compute query features
        feats_q, logits_q = self.src_model(im_q)

        if cls_only:
            return feats_q, logits_q

        q = F.normalize(feats_q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            k, _ = self.momentum_model(im_k)
            k = F.normalize(k, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.features.clone().detach()])

        # logits: Nx(1+K)
        logits_ins = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits_ins /= self.T_moco

        # dequeue and enqueue will happen outside
        return feats_q, logits_q, logits_ins, k
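

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The toy encoder,
# tensor shapes, and hyper-parameters below are illustrative assumptions; a
# real setup would pass a source-pretrained backbone and its EMA copy. Note
# that __init__ moves the buffers to CUDA, so a GPU is required.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import copy

    class ToyEncoder(nn.Module):
        """Hypothetical backbone returning (features, class logits)."""

        def __init__(self, in_dim=32, feat_dim=16, num_classes=4):
            super().__init__()
            self.backbone = nn.Linear(in_dim, feat_dim)
            self.head = nn.Linear(feat_dim, num_classes)

        def forward(self, x):
            feats = self.backbone(x)
            return feats, self.head(feats)

    src = ToyEncoder().cuda()
    moco = AdaMoCo(
        src_model=src,
        momentum_model=copy.deepcopy(src),
        features_length=16,
        num_classes=4,
        dataset_length=128,
        temporal_length=5,
    )

    im_q = torch.randn(8, 32).cuda()  # query batch (e.g. weak augmentation)
    im_k = torch.randn(8, 32).cuda()  # key batch (e.g. strong augmentation)
    idxs = torch.arange(8).cuda()     # dataset indices of the batch samples

    feats_q, logits_q, logits_ins, k = moco(im_q, im_k)
    pseudo = logits_q.argmax(dim=1)
    # Ground-truth labels are unknown in source-free adaptation; the sketch
    # reuses the pseudo-labels just to exercise the interface.
    moco.update_memory(epoch=0, idxs=idxs, keys=k, pseudo_labels=pseudo,
                       real_label=pseudo)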