random_search.py
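"""Random learning-rate search for a MobileNetV2 score-distribution model on TID.

For each of `n_search` trials, the backbone and classifier learning rates are
sampled log-uniformly from [1e-6, 1e-3], the model is trained for 10 epochs with
an EMD loss, and the mean validation loss is stored in `cache`.
"""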
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import matplotlib.pyplot as plt
import os
import copy
from PIL import Image
import json
from load_data import build_dataset

def emd_loss(target, output):
    """EMD loss: batch mean of the RMSE between the CDFs of `target` and `output`."""
    p_sum = torch.cumsum(output, dim=1)
    q_sum = torch.cumsum(target, dim=1)
    loss = torch.mean(torch.sqrt(torch.mean((q_sum - p_sum) ** 2, 1)))
    return loss
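
# A minimal sanity check for emd_loss (illustrative only; these made-up inputs
# are not part of the training flow):
#
#     p = torch.softmax(torch.randn(4, 10), dim=1)
#     q = torch.softmax(torch.randn(4, 10), dim=1)
#     emd_loss(q, p)           # small positive scalar tensor
#     emd_loss(p, p).item()    # 0.0 -- identical distributions have zero EMD
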
if __name__ == '__main__':
    pth_path = r'/home/mancmanomyst/tid/saved_models/'
    path_to_json_train = r'/home/mancmanomyst/tid/tid_labels_train.json'
    path_to_json_test = r'/home/mancmanomyst/tid/tid_labels_test.json'
    path_to_imgs = r'/home/mancmanomyst/tid/distorted_images/'

    TID_train, TID_val, TID_test = build_dataset(path_to_imgs, path_to_json_train, path_to_json_test)
    data_loader_train = DataLoader(TID_train, batch_size=16, num_workers=2, shuffle=True)
    data_loader_val = DataLoader(TID_val, batch_size=16, num_workers=2)

    device = torch.device('cuda')
    n_search = 20
    cache = {}

    for n in range(n_search):
        # Fresh ImageNet-pretrained backbone with a 10-way softmax head for each trial.
        model = models.mobilenet_v2(pretrained=True)
        model.classifier = nn.Sequential(nn.Dropout(0.75),
                                         nn.Linear(1280, 10),
                                         nn.Softmax(dim=1))
        model = model.to(device)

        # Sample separate backbone and head learning rates log-uniformly from [1e-6, 1e-3].
        conv_base_lr = 10 ** np.random.uniform(low=-6, high=-3)
        dense_lr = 10 ** np.random.uniform(low=-6, high=-3)
        optimizer = optim.Adam([
            {'params': model.features.parameters(), 'lr': conv_base_lr},
            {'params': model.classifier.parameters(), 'lr': dense_lr}],
        )

        num_epochs = 10
        val_losses = []
        for epoch in range(num_epochs):
            # Training pass
            model.train()
            for i, batch in enumerate(data_loader_train):
                imgs = batch['img'].to(device)
                labels = batch['label'].to(device)
                outputs = model(imgs)
                outputs = outputs.view(-1, 10)
                optimizer.zero_grad()
                loss = emd_loss(labels, outputs)
                loss.backward()
                optimizer.step()
                # print('Epoch: {}/{} | Step: {}/{} | Training EMD loss: {:.4f}'.format(
                #     epoch + 1, num_epochs, i + 1, len(TID_train) // 16 + 1, loss.item()))

            # Validation pass: accumulate the loss weighted by batch size,
            # then average over the whole validation set.
            batch_val_losses = []
            model.eval()
            for i, batch in enumerate(data_loader_val):
                images = batch['img'].to(device)
                labels = batch['label'].to(device)
                with torch.no_grad():
                    outputs = model(images)
                    outputs = outputs.view(-1, 10)
                    val_loss = emd_loss(labels, outputs)
                batch_val_losses.append(val_loss.item() * images.size(0))
            avg_val_loss = sum(batch_val_losses) / len(TID_val)
            val_losses.append(avg_val_loss)

        # Record the mean validation loss over all epochs for this learning-rate pair.
        cache[n] = [sum(val_losses) / len(val_losses), conv_base_lr, dense_lr]
        print('conv_base_lr : {} dense_lr : {} val_loss : {:.4f}'.format(
            conv_base_lr, dense_lr, sum(val_losses) / len(val_losses)))
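
    # Example (not part of the original script): report the learning-rate pair
    # with the lowest mean validation EMD loss across all trials.
    best_trial = min(cache, key=lambda k: cache[k][0])
    print('best trial {} -> val_loss: {:.4f}, conv_base_lr: {:.2e}, dense_lr: {:.2e}'.format(
        best_trial, *cache[best_trial]))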