-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtest.py
117 lines (98 loc) · 2.85 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import os
import random
from scipy.io import savemat
import shutil
import torch
import torch.optim as optim
import torch.nn as nn
from dataset import Dataset
from templates import get_templates
# --- Run configuration -------------------------------------------------------
MODEL_DIR = './models/'  # root folder; the per-model subfolder is appended later
BACKBONE = 'xcp'         # selects a CONFIGS entry and the network module: 'xcp' or 'vgg'
MAPTYPE = 'tmp'          # forwarded to Model; 'tmp'/'pca_tmp' additionally load templates
BATCH_SIZE = 200
MAX_EPOCHS = 100         # exclusive stop of the checkpoint-epoch range that gets evaluated

# Per-backbone geometry / preprocessing settings consumed by get_dataset().
# NOTE(review): 'norms' looks like [per-channel mean, per-channel std] -- confirm in Dataset.
CONFIGS = dict(
    xcp=dict(
        img_size=(299, 299),
        map_size=(19, 19),
        norms=[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
    ),
    vgg=dict(
        img_size=(299, 299),
        map_size=(19, 19),
        norms=[[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
    ),
)
CONFIG = CONFIGS[BACKBONE]  # settings of the active backbone (same object as the CONFIGS entry)
# Pick the network implementation matching BACKBONE; each module provides `Model`.
if BACKBONE == 'xcp':
    from xception import Model
elif BACKBONE == 'vgg':
    from vgg import Model
else:
    # Guards against a CONFIGS entry being added without a matching model import:
    # fail here rather than with a NameError when MODEL is constructed.
    raise ValueError('Unsupported BACKBONE: {0}'.format(BACKBONE))

# --- Reproducibility ---------------------------------------------------------
# `torch.backends.cudnn.deterministic` is the real cuDNN switch; a bare
# `torch.backends.deterministic` attribute is not a recognised flag and does nothing.
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may choose different kernels between runs, so keep it off
# alongside the deterministic flag.
torch.backends.cudnn.benchmark = False
SEED = 1
random.seed(SEED)
# numpy is seeded too, so any numpy-based randomness is repeatable like `random` and torch.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
def get_dataset():
    """Build the 'test' split Dataset for the active backbone.

    Returns:
        A project ``Dataset`` constructed with BATCH_SIZE, the image size,
        map size and normalization taken from CONFIG, and SEED.
    """
    cfg = CONFIG
    return Dataset(
        'test',
        BATCH_SIZE,
        cfg['img_size'],
        cfg['map_size'],
        cfg['norms'],
        SEED,
    )
# Current test Dataset; reassigned for every checkpoint by the evaluation loop
# and read by run_step through this module-level name.
DATA_TEST = None

# Only the template-based map types need templates; all others pass None to Model.
TEMPLATES = get_templates() if MAPTYPE in ['tmp', 'pca_tmp'] else None

# Checkpoints for this backbone / map-type pair live in MODEL_DIR/<backbone>_<maptype>/.
MODEL_NAME = '{0}_{1}'.format(BACKBONE, MAPTYPE)
MODEL_DIR = '{0}{1}/'.format(MODEL_DIR, MODEL_NAME)

# NOTE(review): 2 looks like the number of output classes and False some
# training/pretrained switch -- confirm against Model's signature.
MODEL = Model(MAPTYPE, TEMPLATES, 2, False)
MODEL.model.cuda()

# Loss modules used by calculate_losses: cross-entropy on the logits vs. labels,
# L1 on the predicted mask vs. the target mask.
LOSS_CSE = nn.CrossEntropyLoss().cuda()
LOSS_L1 = nn.L1Loss().cuda()
# NOTE(review): MAXPOOL is never referenced in this script; kernel 19 matches CONFIG['map_size'].
MAXPOOL = nn.MaxPool2d(19).cuda()
def calculate_losses(batch):
    """Run MODEL.model on one batch and compute losses, accuracy and outputs.

    Args:
        batch: mapping with 'img' (model input), 'msk' (mask target for the
            L1 loss) and 'lab' (class targets for the cross-entropy loss).

    Returns:
        ``(losses, results)``. ``losses`` maps 'loss' (L1 + cross-entropy),
        'loss_l1', 'loss_cse' and 'acc' (fraction of argmax predictions equal
        to the labels) to tensors. ``results`` maps 'lab', 'msk', 'score',
        'pred' and 'mask' to NumPy arrays moved to the CPU. Note that
        ``squeeze()`` removes every size-1 dimension, including the batch
        dimension when a batch holds a single sample.
    """
    img, target_mask, labels = batch['img'], batch['msk'], batch['lab']
    # The model's third output is not used here.
    logits, pred_mask, _ = MODEL.model(img)

    loss_l1 = LOSS_L1(pred_mask, target_mask)
    loss_cse = LOSS_CSE(logits, labels)
    total = loss_l1 + loss_cse

    _, pred = torch.max(logits, dim=1)
    acc = (pred == labels).float().mean()

    tensors = {
        'lab': labels,
        'msk': target_mask,
        'score': logits,
        'pred': pred,
        'mask': pred_mask,
    }
    results = {key: t.squeeze().cpu().numpy() for key, t in tensors.items()}
    losses = {'loss': total, 'loss_l1': loss_l1, 'loss_cse': loss_cse, 'acc': acc}
    return losses, results
def process_batch(batch, mode):
    """Evaluate one batch with the model in eval mode and autograd disabled.

    Args:
        batch: batch mapping forwarded unchanged to calculate_losses.
        mode: accepted for interface compatibility; not used by this function.

    Returns:
        The ``(losses, results)`` pair produced by calculate_losses.
    """
    MODEL.model.eval()
    with torch.no_grad():
        return calculate_losses(batch)
def run_step(di, e, s, resultdir):
    """Evaluate the next batch of sub-dataset ``di`` and save its outputs.

    Args:
        di: index of the sub-dataset within DATA_TEST to draw a batch from.
        e: checkpoint epoch; not used inside this function.
        s: 1-based step counter; used in the output filename and to print
            progress every 10th step.
        resultdir: output directory prefix (including trailing '/').

    Returns:
        True if ``DATA_TEST.get_batch(di)`` returned None (nothing is saved);
        otherwise False, after writing ``<resultdir><di>_<s>.mat``.
    """
    batch = DATA_TEST.get_batch(di)
    if batch is None:
        return True

    losses, results = process_batch(batch, 'test')
    savemat('{0}{1}_{2}.mat'.format(resultdir, di, s), results)

    if s % 10 == 0:
        # '\r' rewrites the same console line; the caller prints the final newline.
        summary = ', '.join(
            '{0}: {1:.3f}'.format(name, value.cpu().detach().numpy())
            for name, value in losses.items()
        )
        print('\r{0} - '.format(s) + summary, end='')
    return False
def run_epoch(di, e, resultdir):
    """Call run_step on sub-dataset ``di`` until it reports exhaustion.

    Steps are numbered from 1 and passed to run_step together with ``e`` and
    ``resultdir``. Returns None.
    """
    step = 1
    while not run_step(di, e, step, resultdir):
        step += 1
LAST_EPOCH = 75

# Evaluate every fifth checkpoint, from LAST_EPOCH up to (not including) MAX_EPOCHS.
checkpoint_epochs = range(LAST_EPOCH, MAX_EPOCHS, 5)
for e in checkpoint_epochs:
    resultdir = '{0}results/{1}/'.format(MODEL_DIR, e)

    # Start from an empty directory so no stale .mat files from an earlier run remain.
    if os.path.exists(resultdir):
        shutil.rmtree(resultdir)
    os.makedirs(resultdir, exist_ok=True)

    MODEL.load(e, MODEL_DIR)
    # Fresh Dataset per checkpoint; run_step reads it through the DATA_TEST global.
    DATA_TEST = get_dataset()
    for di, _ in enumerate(DATA_TEST.datasets):
        run_epoch(di, e, resultdir)
        # Terminate run_step's '\r' progress line before the next sub-dataset.
        print()

print('Testing complete')