-
Notifications
You must be signed in to change notification settings - Fork 1
/
cal_map.py
145 lines (118 loc) · 4.24 KB
/
cal_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from torch.autograd import Variable
import numpy as np
import torch
def extractab(test, model, classes=80):
    """Run *model* over every batch in *test*, collecting hash codes and raw outputs.

    :param test: iterable yielding ``(data, target, _)`` batches; ``data``
        must support ``.cuda()`` (e.g. a DataLoader of tensors)
    :param model: network mapping a batch to its hash-layer output H
    :param classes: unused here; kept for signature compatibility with the
        sibling functions in this module (label dimensionality)
    :return: ``(queryB, queryH)`` numpy arrays stacked over all batches —
        sign-binarized codes and the raw continuous outputs
    """
    queryB = []
    queryH = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    # (torch.autograd.Variable is a deprecated no-op wrapper and is dropped.)
    with torch.no_grad():
        for data, _target, _ in test:
            H = model(data.cuda())
            code = torch.sign(H)  # binarize to {-1, 0, +1}
            queryB.extend(code.cpu().numpy())
            queryH.extend(H.cpu().numpy())
    return np.array(queryB), np.array(queryH)
def extractab1(test, model, classes=80):
    """Run *model* over every batch in *test*, collecting hash codes and raw outputs.

    Duplicate of ``extractab`` kept so existing callers of either name work.

    :param test: iterable yielding ``(data, target, _)`` batches; ``data``
        must support ``.cuda()``
    :param model: network mapping a batch to its hash-layer output H
    :param classes: unused here; kept for signature compatibility
    :return: ``(queryB, queryH)`` numpy arrays — sign-binarized codes and
        the raw continuous outputs, stacked over all batches
    """
    queryB = []
    queryH = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for data, _target, _ in test:
            H = model(data.cuda())
            code = torch.sign(H)  # binarize to {-1, 0, +1}
            queryB.extend(code.cpu().numpy())
            queryH.extend(H.cpu().numpy())
    return np.array(queryB), np.array(queryH)
def compress(train, test, model, classes=80):
    """Binarize model outputs for the retrieval (train) and query (test) sets.

    Fixes the original hard-coded ``np.ones((1, 80))`` label accumulator,
    which ignored the ``classes`` parameter and broke for any label width
    other than 80; labels are now gathered per batch and concatenated once,
    with no dummy sentinel row to strip.

    :param train: iterable of ``(data, target, _)`` — retrieval/database set
    :param test: iterable of ``(data, target, _)`` — query set
    :param model: network mapping a batch to its hash-layer output
    :param classes: label vector length; targets are expected to be
        ``(batch, classes)`` multi-hot tensors
    :return: ``(retrievalB, retrievalL, queryB, queryL)`` numpy arrays
    """
    def _encode(loader):
        # One pass over a loader: sign-binarized codes + stacked label matrix.
        codes, labels = [], []
        with torch.no_grad():  # inference only — no autograd graph needed
            for data, target, _ in loader:
                H = model(data.cuda())
                codes.extend(torch.sign(H).cpu().numpy())
                labels.append(target.numpy())
        if labels:
            label_mat = np.concatenate(labels, axis=0)
        else:
            # Empty loader: match the original's (0, classes)-shaped result.
            label_mat = np.zeros((0, classes))
        return np.array(codes), label_mat

    retrievalB, retrievalL = _encode(train)
    queryB, queryL = _encode(test)
    return retrievalB, retrievalL, queryB, queryL
def calculate_hamming(B1, B2):
    """
    Hamming distance between one binary code and a batch of codes.

    :param B1: {-1,+1} vector of shape [n]
    :param B2: {-1,+1} matrix of shape [r, n]
    :return: vector of shape [r] — per-row hamming distances
    """
    n_bits = B2.shape[1]  # inner product of two identical codes equals n_bits
    # <B1, b> = (#agreements - #disagreements), so the number of differing
    # bits is (n_bits - <B1, b>) / 2.
    return 0.5 * (n_bits - B2.dot(B1))
def calculate_map(qB, rB, queryL, retrievalL):
    """
    Mean average precision over the full retrieval set.

    Fixes two issues in the original: the loop shadowed the builtins ``map``
    and ``iter``, and ``np.linspace(1, tsum, tsum)`` passed a float32 as the
    ``num`` argument, which raises TypeError on modern NumPy.

    :param qB: {-1,+1}^{mxq} query bits
    :param rB: {-1,+1}^{nxq} retrieval bits
    :param queryL: {0,1}^{mxl} query labels
    :param retrievalL: {0,1}^{nxl} retrieval labels
    :return: mAP averaged over all queries; queries with no relevant
        retrieval item contribute 0 to the mean
    """
    num_query = queryL.shape[0]
    total_ap = 0.0
    for q in range(num_query):
        # Relevant = retrieval items sharing at least one label with the query.
        relevant = (np.dot(queryL[q, :], retrievalL.transpose()) > 0).astype(np.float32)
        num_relevant = np.sum(relevant)
        if num_relevant == 0:
            continue  # no ground truth for this query
        # Rank retrieval items by hamming distance to the query code.
        order = np.argsort(calculate_hamming(qB[q, :], rB))
        relevant = relevant[order]
        # Precision at each rank where a relevant item appears:
        # count = 1..num_relevant, ranks = 1-based positions of hits.
        count = np.linspace(1, num_relevant, int(num_relevant))
        ranks = np.asarray(np.where(relevant == 1)) + 1.0
        total_ap += np.mean(count / ranks)
    return total_ap / num_query
def calculate_top_map(qB, rB, queryL, retrievalL, topk):
    """
    Mean average precision computed over only the top-k ranked items.

    :param qB: {-1,+1}^{mxq} query bits
    :param rB: {-1,+1}^{nxq} retrieval bits
    :param queryL: {0,1}^{mxl} query label
    :param retrievalL: {0,1}^{nxl} retrieval label
    :param topk: number of top-ranked retrieval items to score per query
    :return: top-k mAP averaged over all queries; queries with no relevant
        item in their top-k contribute 0
    """
    num_query = queryL.shape[0]
    topkmap = 0
    for q_idx in range(num_query):
        # Relevance flag for every retrieval item w.r.t. this query's labels.
        matches = (np.dot(queryL[q_idx, :], retrievalL.transpose()) > 0).astype(np.float32)
        # Order the flags by increasing hamming distance, keep the top k.
        distances = calculate_hamming(qB[q_idx, :], rB)
        top_flags = matches[np.argsort(distances)][0:topk]
        num_hits = np.sum(top_flags)
        if num_hits == 0:
            continue
        # Average precision within the truncated list.
        precision_steps = np.linspace(1, num_hits, num_hits)
        positions = np.asarray(np.where(top_flags == 1)) + 1.0
        topkmap = topkmap + np.mean(precision_steps / positions)
    return topkmap / num_query