-
Notifications
You must be signed in to change notification settings - Fork 10
/
train_target.py
375 lines (281 loc) · 13.6 KB
/
train_target.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import argparse
import numpy as np
import torchvision.transforms as transforms
import wandb
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import *
from utils import *
from os.path import join
from datasets import *
from model import *
from moco import *
parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--data_dir', type=str, default='data')
parser.add_argument('--dataset', default='visdac/target', type=str)
parser.add_argument('--source', default='visdac_source', type=str)
parser.add_argument('--weights', type=str)
parser.add_argument('--noisy_path', type=str, default=None)
parser.add_argument('--num_neighbors', default=10, type=int)
parser.add_argument('--num_class', default=10, type=int)
parser.add_argument('--temporal_length', default=5, type=int)
parser.add_argument('--batch_size', default=256, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--num_epochs', default=300, type=int)
parser.add_argument('--temperature', default=0.07, type=float, help='softmax temperature (default: 0.07)')
# type=int: without it a command-line value arrives as a str, so random.seed
# would seed from the string rather than the number and np.random.seed would
# reject it outright.
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid', default=0, type=int)
# The four switches below use store_false: each component is ON by default and
# passing the flag turns it OFF.
parser.add_argument('--ctr', action='store_false', help="disable contrastive loss (on by default)")
parser.add_argument('--label_refinement', action='store_false', help="disable label refinement (on by default)")
parser.add_argument('--neg_l', action='store_false', help="disable negative learning (on by default)")
parser.add_argument('--reweighting', action='store_false', help="disable reweighting (on by default)")
parser.add_argument('--run_name', type=str)
parser.add_argument('--wandb', action='store_true', help="Use wandb")
args = parser.parse_args()

torch.cuda.set_device(args.gpuid)
# Seed every RNG the script can touch (python, numpy, torch CPU and CUDA).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

if args.wandb:
    wandb.init(project="Guiding Pseudo-labels with Uncertainty Estimation for Test-Time Adaptation", name=args.run_name)
def entropy(p, axis=1):
    """Return the base-2 Shannon entropy of ``p`` reduced along ``axis``.

    A constant of 1e-5 is added inside the logarithm so that zero
    probabilities contribute (approximately) zero instead of ``nan``.
    """
    log_p = (p + 1e-5).log2()
    return (p * log_p).sum(dim=axis).neg()
def get_distances(X, Y, dist_type="cosine"):
    """Return the pairwise distance matrix between the rows of ``X`` and ``Y``.

    Args:
        X: 2-D tensor, one sample per row.
        Y: 2-D tensor, one sample per row.
        dist_type: ``"cosine"`` (1 - cosine similarity of the L2-normalised
            rows) or ``"euclidean"`` (``torch.cdist``).

    Raises:
        NotImplementedError: for any other ``dist_type``.
    """
    if dist_type == "cosine":
        x_unit = F.normalize(X, dim=1)
        y_unit = F.normalize(Y, dim=1)
        return 1 - x_unit @ y_unit.T
    if dist_type == "euclidean":
        return torch.cdist(X, Y)
    raise NotImplementedError(f"{dist_type} distance not implemented.")
@torch.no_grad()
def soft_k_nearest_neighbors(features, features_bank, probs_bank, num_neighbors=None, chunk_size=64):
    """Predict labels by soft voting over each query's nearest bank entries.

    Distances are cosine distances (``get_distances`` default) between each
    query row and every row of ``features_bank``.

    Args:
        features: query features, one row per sample (N rows).
        features_bank: stored features, one row per bank entry.
        probs_bank: class-probability rows aligned with ``features_bank``.
        num_neighbors: neighbours per query; defaults to ``args.num_neighbors``.
            Clipped to the bank size, as the original slicing did.
        chunk_size: number of queries per distance computation, which bounds
            the size of the intermediate distance matrix.

    Returns:
        pred_labels: argmax of the neighbour-averaged probabilities, (N,).
        pred_probs: neighbour-averaged probabilities, (N, C).
        pred_labels_all: argmax label of each individual neighbour, (N, k).
        pred_labels_hard: argmax over classes of the per-class maximum taken
            across neighbours, (N,).
    """
    if num_neighbors is None:
        num_neighbors = args.num_neighbors
    # topk raises if k exceeds the bank size, whereas slicing a sort clipped it.
    k = min(num_neighbors, len(features_bank))
    pred_probs = []
    pred_probs_all = []
    for feats in features.split(chunk_size):
        distances = get_distances(feats, features_bank)
        # Only the k closest entries are needed, so avoid sorting every row fully.
        _, idxs = distances.topk(k, dim=1, largest=False)
        # (chunk, k, C): probabilities of each neighbour, gathered once.
        neighbor_probs = probs_bank[idxs, :]
        pred_probs_all.append(neighbor_probs)
        pred_probs.append(neighbor_probs.mean(1))
    pred_probs_all = torch.cat(pred_probs_all)  # (N, k, C)
    pred_probs = torch.cat(pred_probs)  # (N, C)
    _, pred_labels = pred_probs.max(dim=1)
    # Label each individual neighbour votes for: max over the class dimension.
    _, pred_labels_all = pred_probs_all.max(dim=2)
    # First keep the maximum per class across neighbours, then argmax over classes.
    _, pred_labels_hard = pred_probs_all.max(dim=1)[0].max(dim=1)
    return pred_labels, pred_probs, pred_labels_all, pred_labels_hard
def refine_predictions(
        features,
        probs,
        banks):
    """Refine predictions for ``features`` by soft k-NN voting over ``banks``.

    ``banks`` must provide ``"features"`` and ``"probs"`` entries. The
    ``probs`` argument is accepted for call-site compatibility but is not
    used: the refined probabilities come only from the bank.

    Returns the 4-tuple produced by ``soft_k_nearest_neighbors``:
    ``(pred_labels, pred_probs, pred_labels_all, pred_labels_hard)``.
    """
    return soft_k_nearest_neighbors(features, banks["features"], banks["probs"])
def contrastive_loss(logits_ins, pseudo_labels, mem_labels):
    """Instance-contrastive loss that masks out negatives sharing a pseudo-label.

    Column 0 of ``logits_ins`` is treated as the positive logit (target index
    0 for every row). Columns ``1:`` are logits against memory entries, aligned
    with the rows of ``mem_labels``. A memory entry stays a negative only if
    its labels differ from the query's labels at every position of the last
    dimension. Otherwise its logit is set to ``-inf`` so it is excluded from
    the softmax.

    Args:
        logits_ins: (B, 1 + K) logits, positive first.
        pseudo_labels: per-query labels; the ``dim=2`` reduction after
            broadcasting requires at least 2-D, presumably (B, T) — TODO confirm.
        mem_labels: labels of the K memory entries, same trailing shape.

    Returns:
        Scalar mean cross-entropy loss.
    """
    # Build targets on the logits' device instead of assuming CUDA.
    labels_ins = torch.zeros(logits_ins.shape[0], dtype=torch.long, device=logits_ins.device)
    mask = torch.ones_like(logits_ins, dtype=torch.bool)
    mask[:, 1:] = torch.all(pseudo_labels.unsqueeze(1) != mem_labels.unsqueeze(0), dim=2)
    # masked_fill avoids allocating and copying a -inf tensor on every call.
    logits_ins = logits_ins.masked_fill(~mask, float("-inf"))
    return F.cross_entropy(logits_ins, labels_ins)
@torch.no_grad()
def update_labels(banks, idxs, features, logits):
    """Write a batch of features and softmax probabilities into the ring-buffer bank.

    Rows are written starting at ``banks["ptr"]``, wrapping around the end of
    the bank, and the pointer is then advanced past the written rows. Only
    ``len(idxs)`` is used; the index values themselves are not stored.

    Args:
        banks: dict with ``"features"`` and ``"probs"`` tensors of equal length
            and an integer ``"ptr"``; mutated in place.
        idxs: batch indices (only their count matters).
        features: feature rows to store.
        logits: logits whose softmax is stored in ``banks["probs"]``.
    """
    probs = F.softmax(logits, dim=1)
    bank_len = len(banks["features"])
    start = banks["ptr"]
    end = start + len(idxs)
    # Build slot indices on the bank's own device rather than assuming CUDA.
    slots = torch.arange(start, end, device=banks["features"].device) % bank_len
    banks["features"][slots, :] = features
    banks["probs"][slots, :] = probs
    banks["ptr"] = end % bank_len
def div(logits, epsilon=1e-8):
    """Diversity loss: the negative entropy of the batch-averaged prediction.

    The softmax of ``logits`` is averaged over the batch (dim 0), and
    ``sum(p * log(p + epsilon))`` is returned. Minimising this value maximises
    the entropy of the mean prediction.
    """
    mean_probs = F.softmax(logits, dim=1).mean(dim=0)
    return torch.sum(mean_probs * torch.log(mean_probs + epsilon))
def nl_criterion(output, y, num_class=None):
    """Per-sample negative-learning loss against a random complementary label.

    For each sample, a label different from ``y`` is drawn uniformly from the
    other ``num_class - 1`` classes. The loss is ``-log(1 - p)``, where ``p``
    is the softmax probability of that label; ``1 - p`` is clamped to
    ``[1e-5, 1]`` to avoid ``log(0)``.

    Args:
        output: (N, C) logits.
        y: (N,) integer labels to steer away from.
        num_class: number of classes; defaults to ``args.num_class``.

    Returns:
        (N,) unreduced loss.
    """
    if num_class is None:
        num_class = args.num_class
    log_not_probs = torch.log(torch.clamp(1. - F.softmax(output, dim=1), min=1e-5, max=1.))
    # An offset in [1, num_class - 1] guarantees the drawn label never equals y.
    offsets = torch.randint(1, num_class, (len(y),), device=y.device, dtype=torch.long)
    labels_neg = (y + offsets) % num_class
    return F.nll_loss(log_not_probs, labels_neg, reduction='none')
# Training
def train(epoch, net, moco_model, optimizer, trainloader, banks):
    """Run one adaptation epoch over ``trainloader``.

    Per batch:
      1. pseudo-labels come from the weak view, optionally refined by soft
         k-NN voting over ``banks`` (``args.label_refinement``);
      2. the two strong views go through ``moco_model`` to obtain
         ``logits_q``, contrastive logits and keys;
      3. the loss is a classification term on ``logits_q`` (negative learning
         if ``args.neg_l``, else cross-entropy), optionally weighted per sample
         by ``exp(-H / H_max)`` of the pseudo-label distribution
         (``args.reweighting``). An optional contrastive term (``args.ctr``)
         and diversity terms on both logit sets are added;
      4. the MoCo memory and ``banks`` are updated and one optimizer step is
         taken.

    Args:
        epoch: current epoch index (memory update and logging step).
        net: online network, switched to train mode here.
        moco_model: wrapper around ``net`` and its momentum copy.
        optimizer: optimizer stepped once per batch.
        trainloader: yields batches indexed as 0 weak view, 1 strong view,
            2 label, 3 dataset index, 5 second strong view.
        banks: feature/probability bank, updated in place by ``update_labels``.
    """
    loss = 0
    acc = 0
    net.train()
    moco_model.train()
    # Loop-invariant normaliser: base-2 entropy of a uniform distribution.
    max_entropy = torch.log2(torch.tensor(args.num_class))
    for batch_idx, batch in enumerate(trainloader):
        weak_x = batch[0].cuda()
        strong_x = batch[1].cuda()
        y = batch[2].cuda()
        idxs = batch[3].cuda()
        strong_x2 = batch[5].cuda()

        feats_w, logits_w = moco_model(weak_x, cls_only=True)
        if args.label_refinement:
            with torch.no_grad():
                probs_w = F.softmax(logits_w, dim=1)
                pseudo_labels_w, probs_w, _, _ = refine_predictions(feats_w, probs_w, banks)
        else:
            probs_w = F.softmax(logits_w, dim=1)
            pseudo_labels_w = probs_w.max(1)[1]

        _, logits_q, logits_ctr, keys = moco_model(strong_x, strong_x2)

        if args.ctr:
            # Labels are read before update_memory overwrites them for this batch.
            loss_ctr = contrastive_loss(
                logits_ins=logits_ctr,
                pseudo_labels=moco_model.mem_labels[idxs],
                mem_labels=moco_model.mem_labels[moco_model.idxs]
            )
        else:
            loss_ctr = 0

        # update key features and corresponding pseudo labels
        moco_model.update_memory(epoch, idxs, keys, pseudo_labels_w, y)

        with torch.no_grad():
            # Per-sample weights: exp(-normalised entropy) of the pseudo-label
            # distribution, so confident pseudo-labels weigh more.
            w = torch.exp(-(entropy(probs_w) / max_entropy))

        if args.neg_l:
            # Negative learning against a random complementary label.
            per_sample_cls = nl_criterion(logits_q, pseudo_labels_w)
        else:
            # Standard positive learning on the pseudo-labels.
            per_sample_cls = CE(logits_q, pseudo_labels_w)
        if args.reweighting:
            per_sample_cls = w * per_sample_cls
        loss_cls = per_sample_cls.mean()

        loss_div = div(logits_w) + div(logits_q)
        l = loss_cls + loss_ctr + loss_div

        update_labels(banks, idxs, feats_w, logits_w)

        l.backward()
        optimizer.step()
        optimizer.zero_grad()

        accuracy = 100.*accuracy_score(y.to('cpu'), logits_w.to('cpu').max(1)[1])
        loss += l.item()
        acc += accuracy
        if batch_idx % 100 == 0:
            print('Epoch [%3d/%3d] Iter[%3d/%3d]\t '
                  % (epoch, args.num_epochs, batch_idx+1, len(trainloader)))
            print("Acc ", acc/(batch_idx+1))

    print("Training acc = ", acc/len(trainloader))
    if args.wandb:
        wandb.log({
            # Mean total loss over the epoch (previously the last batch's
            # classification loss divided by the number of batches).
            'train_loss': loss/len(trainloader),
            'train_acc': acc/len(trainloader),
        }, step=epoch)
@torch.no_grad()
def eval_and_label_dataset(epoch, model, banks, bank_size=16384, loader=None):
    """Evaluate ``model`` on the test split and rebuild the feature/probability bank.

    Features and logits are collected for every batch of ``loader``. A fresh
    bank is built from a random subset of at most ``bank_size`` rows, and
    accuracy is reported for the k-NN-refined predictions.

    The incoming ``banks`` argument is not read; it is kept only for interface
    compatibility.

    Args:
        epoch: epoch index used for printing and as the wandb step.
        model: model called as ``model(inputs, cls_only=True)`` returning
            ``(features, logits)``.
        banks: ignored.
        bank_size: maximum number of rows kept in the rebuilt bank.
        loader: evaluation loader; defaults to the module-level ``test_loader``.

    Returns:
        ``(acc, banks, gt_labels, pred_labels)``: accuracy in percent, the new
        bank dict (``"features"``, ``"probs"``, ``"ptr"``), the ground-truth
        labels and the refined predicted labels.
    """
    if loader is None:
        loader = test_loader
    print("Evaluating Dataset!")
    model.eval()
    features, logits, gt_labels = [], [], []
    for batch in loader:
        inputs, targets = batch[0].cuda(), batch[2].cuda()
        feats, logits_cls = model(inputs, cls_only=True)
        features.append(feats)
        gt_labels.append(targets)
        logits.append(logits_cls)
    features = torch.cat(features)
    gt_labels = torch.cat(gt_labels)
    logits = torch.cat(logits)
    probs = F.softmax(logits, dim=1)

    # Shuffle before truncating so that, when there are more than bank_size
    # samples, the kept rows are a random subset rather than the first batches.
    rand_idxs = torch.randperm(len(features)).to(features.device)
    banks = {
        "features": features[rand_idxs][:bank_size],
        "probs": probs[rand_idxs][:bank_size],
        "ptr": 0,
    }

    # refine predicted labels
    pred_labels, _, _, _ = refine_predictions(features, probs, banks)

    acc = 100.*accuracy_score(gt_labels.to('cpu'), pred_labels.to('cpu'))
    print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" %(epoch,acc))
    if args.wandb:
        wandb.log({
            'val_accuracy': acc,
        }, step=epoch)
    return acc, banks, gt_labels, pred_labels
def create_model(arch, args):
    """Build a ``Resnet`` for ``arch`` (configured from ``args``) on the current CUDA device."""
    return Resnet(arch, args).cuda()
# Per dataset family (prefix of --dataset before '/'): root sub-directory under
# --data_dir, train split mode, test split mode, backbone architecture.
DATASET_SETTINGS = {
    'pacs': ('PACS', 'all', 'all', 'resnet18'),
    'visdac': ('VISDA-C', 'train', 'test', 'resnet101'),
    'domainnet': ('domainnet-126', 'all', 'all', 'resnet50'),
}
# Normalisation statistics used by both the train and test pipelines.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

dataset_family = args.dataset.split('/')[0]
if dataset_family not in DATASET_SETTINGS:
    # Fail fast: otherwise train_dataset is never defined and the script
    # crashes later with a NameError.
    raise ValueError(
        f"Unsupported dataset '{args.dataset}'; expected a prefix in {sorted(DATASET_SETTINGS)}")
root_subdir, train_mode, test_mode, arch = DATASET_SETTINGS[dataset_family]
dataset_root = join(args.data_dir, root_subdir)

train_dataset = dataset(dataset=args.dataset, root=dataset_root, noisy_path=None,
                        mode=train_mode,
                        transform=transforms.Compose([transforms.Resize(256),
                                                      transforms.RandomCrop(224),
                                                      transforms.RandomHorizontalFlip(),
                                                      transforms.ToTensor(),
                                                      transforms.Normalize(NORM_MEAN, NORM_STD)])
                        )
test_dataset = dataset(dataset=args.dataset, root=dataset_root, noisy_path=None,
                       mode=test_mode,
                       transform=transforms.Compose([transforms.Resize(256),
                                                     transforms.CenterCrop(224),
                                                     transforms.ToTensor(),
                                                     transforms.Normalize(NORM_MEAN, NORM_STD)])
                       )
logdir = 'logs/' + args.run_name
# Create the checkpoint directory up front so that saving the best weights
# cannot fail on a missing directory after a full epoch of training.
os.makedirs(logdir, exist_ok=True)

net = create_model(arch, args)
momentum_net = create_model(arch, args)
# Both the online and the momentum network start from the source checkpoint.
source_ckpt = 'logs/' + args.source + '/weights_best.tar'
load_weights(net, source_ckpt)
load_weights(momentum_net, source_ckpt)
optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=5e-4)
moco_model = AdaMoCo(src_model=net, momentum_model=momentum_net, features_length=net.bottleneck_dim,
                     num_classes=args.num_class, dataset_length=len(train_dataset),
                     temporal_length=args.temporal_length)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cudnn.benchmark = True

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           num_workers=4,
                                           drop_last=True,
                                           shuffle=True)
# drop_last=False: evaluation (and the rebuilt bank) must cover every test
# sample, not silently skip the final partial batch.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          num_workers=4,
                                          drop_last=False,
                                          shuffle=False)
# Per-sample CE (reduction='none') is required by the reweighted loss in train().
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
best = 0
# Evaluate the source weights before any adaptation; this also builds the
# initial feature/probability bank consumed by the first training epoch.
acc, banks, _, _ = eval_and_label_dataset(0, moco_model, None)
for epoch in range(args.num_epochs + 1):
    print("Training started!")
    train(epoch, net, moco_model, optimizer, train_loader, banks)
    acc, banks, gt_labels, pred_labels = eval_and_label_dataset(epoch, moco_model, banks)
    # Only checkpoint on a strict improvement of the evaluation accuracy.
    if acc <= best:
        continue
    save_weights(net, epoch, logdir + '/weights_best.tar')
    best = acc
    print("Saving best!")
    if args.wandb:
        wandb.run.summary['best_acc'] = best