-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: test.py
129 lines (112 loc) · 4.55 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Torch imports
import torch
import numpy as np
# Python imports
import tqdm
from tqdm import tqdm
import os
from os.path import join as ospj
# Local imports
from data import dataset as dset
from model.common import Evaluator
from utils.utils import load_args
from config_model import configure_model
from flags import parser
def main():
    """Evaluate a trained CANet model on the configured CZSL dataset's test split.

    Parses command-line flags, pins the dataset/paths, optionally overrides
    defaults from a per-dataset YAML config, builds the test dataset and
    dataloader, restores the saved checkpoint, and runs evaluation under
    ``torch.no_grad()``. Returns early (with a message) when no checkpoint
    directory exists on disk.
    """
    args = parser.parse_args()
    args.dataset = 'ut-zap50k'  # Choose from ut-zap50k | mit-states | cgqa
    args.main_root = os.path.dirname(__file__)
    args.data_root = '/root/datasets'
    device = 0  # Your GPU order. If you don't have a GPU, ignore this.

    # Get arguments and start logging
    print('> Initialize parameters')
    config_path = ospj(args.main_root, 'configs', args.dataset, 'CANet.yml')
    if os.path.exists(config_path):
        # YAML values override the defaults declared in flags.py.
        load_args(config_path, args)
        print(' Load parameter values from file {}'.format(config_path))
    else:
        print(' No yml file found. Keep default parameter values in flags.py')
    args.device = 'cuda:{}'.format(device) if torch.cuda.is_available() else 'cpu'
    print('> Choose device: {}'.format(args.device))

    # Get dataset
    print('> Load dataset')
    args.phase = 'test'
    testset = dset.CompositionDataset(
        args=args,
        root=os.path.join(args.data_root, args.dataset),
        phase=args.phase,
        split=args.splitname,
        model=args.image_extractor,
        update_image_features=args.update_image_features,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.test_batch_size,
        shuffle=False,  # deterministic order for evaluation
        num_workers=args.num_workers)

    # Get model and optimizer
    args.train = False
    image_extractor, model = configure_model(args, testset, train=args.train)
    print(model)

    # Load the saved checkpoint; bail out if training has not been run yet.
    print('> Load saved trained model')
    save_path = os.path.join(args.main_root, 'saved model')
    if not os.path.exists(save_path):
        print(' No saved model found in local disk. Please run train.py to train the model first')
        return
    checkpoint = torch.load(ospj(save_path, 'saved_{}.t7'.format(args.dataset)), map_location=args.device)
    if image_extractor:
        try:
            image_extractor.load_state_dict(checkpoint['image_extractor'])
            image_extractor.eval()
        except KeyError:
            # Checkpoint was saved without fine-tuned extractor weights;
            # keep using the freshly configured extractor. (Was a bare
            # `except:` that hid real load_state_dict failures too.)
            print(' No saved image extractor in checkpoint file')
    model.load_state_dict(checkpoint['net'])
    model.eval()

    print('> Initialize evaluator')
    evaluator = Evaluator(testset, args)
    with torch.no_grad():
        test(image_extractor, model, testloader, evaluator, args)
def test(image_extractor, model, testloader, evaluator, args):
    """Run one evaluation pass over *testloader* and print summary metrics.

    Args:
        image_extractor: optional feature backbone; when truthy, raw images
            (``data[0]``) are mapped to features before ``model.val_forward``.
        model: model exposing ``val_forward(data) -> (_, predictions_dict)``.
        testloader: DataLoader yielding batches where, per the indexing below,
            positions 1/2/3 hold attribute/object/pair ground truth.
        evaluator: project Evaluator used to score and aggregate predictions.
        args: namespace providing ``device``, ``bias`` and ``topk``.
    """
    if image_extractor:
        image_extractor.eval()
    model.eval()

    all_attr_gt, all_obj_gt, all_pair_gt, all_pred = [], [], [], []
    # No index is needed per batch, so iterate the loader directly
    # (the original enumerated and discarded the counter).
    for data in tqdm(testloader, total=len(testloader), desc='Testing'):
        data = [d.to(args.device) for d in data]
        if image_extractor:
            data[0] = image_extractor(data[0])
        _, predictions = model.val_forward(data)
        attr_truth, obj_truth, pair_truth = data[1], data[2], data[3]
        all_pred.append(predictions)
        all_attr_gt.append(attr_truth)
        all_obj_gt.append(obj_truth)
        all_pair_gt.append(pair_truth)

    if not all_pred:
        # Empty loader: nothing to score (original code crashed here).
        print(' Empty test loader; nothing to evaluate')
        return

    all_attr_gt = torch.cat(all_attr_gt).to('cpu')
    all_obj_gt = torch.cat(all_obj_gt).to('cpu')
    all_pair_gt = torch.cat(all_pair_gt).to('cpu')

    # Gather values as dict of (attr, obj) as key and list of predictions as values
    all_pred_dict = {
        k: torch.cat([pred[k].to('cpu') for pred in all_pred])
        for k in all_pred[0].keys()
    }

    # Calculate best unseen accuracy
    results = evaluator.score_model(all_pred_dict, all_obj_gt, bias=args.bias, topk=args.topk)
    stats = evaluator.evaluate_predictions(results, all_attr_gt, all_obj_gt, all_pair_gt, all_pred_dict,
                                           topk=args.topk)
    attr_acc = stats['closed_attr_match']
    obj_acc = stats['closed_obj_match']
    seen_acc = stats['best_seen']
    unseen_acc = stats['best_unseen']
    HM = stats['best_hm']
    AUC = stats['AUC']
    print('|----Test Finished: Attr Acc: {:.2f}% | Obj Acc: {:.2f}% | Seen Acc: {:.2f}% | Unseen Acc: {:.2f}% | HM: {:.2f}% | AUC: {:.2f}'.\
        format(attr_acc*100, obj_acc*100, seen_acc*100, unseen_acc*100, HM*100, AUC*100))
# Script entry point: announce start, run the pipeline, announce completion.
if __name__ == '__main__':
    for banner in ('======== Welcome! ========', '> Program start'):
        print(banner)
    main()
    print('=== Program terminated ===')