1
+ import warnings
1
2
import numpy as np
2
3
from scipy .cluster .vq import kmeans2 , vq
3
4
4
5
6
+ def dist_l2 (q , x ):
7
+ return np .linalg .norm (q - x , ord = 2 , axis = 1 ) ** 2
8
+
9
+
10
+ def dist_ip (q , x ):
11
+ return np .matmul (x , q [None , :].T ).sum (axis = - 1 )
12
+
13
+
14
+ def dist_angular (q , x ):
15
+ return dist_ip (q , x )
16
+
17
+
18
+ metric_function_map = {
19
+ 'l2' : dist_l2 ,
20
+ 'angular' : dist_angular ,
21
+ 'dot' : dist_ip
22
+ }
23
+
24
+
5
25
class PQ (object ):
6
26
"""Pure python implementation of Product Quantization (PQ) [Jegou11]_.
7
27
@@ -19,12 +39,14 @@ class PQ(object):
19
39
M (int): The number of sub-space
20
40
Ks (int): The number of codewords for each subspace
21
41
(typically 256, so that each sub-vector is quantized
22
- into 8 bits = 1 byte = uint8)
42
+ into 256 bits = 1 byte = uint8)
43
+ metric (str): Type of metric used among vectors
23
44
verbose (bool): Verbose flag
24
45
25
46
Attributes:
26
47
M (int): The number of sub-space
27
48
Ks (int): The number of codewords for each subspace
49
+ metric (str): Type of metric used among vectors
28
50
verbose (bool): Verbose flag
29
51
code_dtype (object): dtype of PQ-code. Either np.uint{8, 16, 32}
30
52
codewords (np.ndarray): shape=(M, Ks, Ds) with dtype=np.float32.
@@ -33,17 +55,22 @@ class PQ(object):
33
55
34
56
"""
35
57
36
- def __init__ (self , M , Ks = 256 , verbose = True ):
58
+ def __init__ (self , M , Ks = 256 , metric = 'l2' , minit = 'random' , verbose = True ):
37
59
assert 0 < Ks <= 2 ** 32
38
- self .M , self .Ks , self .verbose = M , Ks , verbose
60
+ assert metric in ['l2' , 'dot' , 'angular' ]
61
+ assert minit in ['random' , '++' , 'points' , 'matrix' ]
62
+ self .M , self .Ks , self .verbose , self .metric = M , Ks , verbose , metric
39
63
self .code_dtype = (
40
64
np .uint8 if Ks <= 2 ** 8 else (np .uint16 if Ks <= 2 ** 16 else np .uint32 )
41
65
)
42
66
self .codewords = None
43
67
self .Ds = None
68
+ self .metric = metric
69
+ self .minit = minit
44
70
45
71
if verbose :
46
- print ("M: {}, Ks: {}, code_dtype: {}" .format (M , Ks , self .code_dtype ))
72
+ print ("M: {}, Ks: {}, metric : {}, code_dtype: {} minit: {}" .format (
73
+ M , Ks , self .code_dtype , metric , minit ))
47
74
48
75
def __eq__ (self , other ):
49
76
if isinstance (other , PQ ):
@@ -88,9 +115,9 @@ def fit(self, vecs, iter=20, seed=123):
88
115
for m in range (self .M ):
89
116
if self .verbose :
90
117
print ("Training the subspace: {} / {}" .format (m , self .M ))
91
- vecs_sub = vecs [:, m * self .Ds : (m + 1 ) * self .Ds ]
92
- self .codewords [m ], _ = kmeans2 (vecs_sub , self . Ks , iter = iter , minit = "points" )
93
-
118
+ vecs_sub = vecs [:, m * self .Ds : (m + 1 ) * self .Ds ]
119
+ self .codewords [m ], _ = kmeans2 (
120
+ vecs_sub , self . Ks , iter = iter , minit = self . minit )
94
121
return self
95
122
96
123
def encode (self , vecs ):
@@ -167,10 +194,11 @@ def dtable(self, query):
167
194
# dtable[m][ks] : distance between m-th subvec and ks-th codeword of m-th codewords
168
195
dtable = np .empty ((self .M , self .Ks ), dtype = np .float32 )
169
196
for m in range (self .M ):
170
- query_sub = query [m * self .Ds : (m + 1 ) * self .Ds ]
171
- dtable [m , :] = np .linalg .norm (self .codewords [m ] - query_sub , axis = 1 ) ** 2
197
+ query_sub = query [m * self .Ds : (m + 1 ) * self .Ds ]
198
+ dtable [m , :] = metric_function_map [self .metric ](
199
+ query_sub , self .codewords [m ])
172
200
173
- return DistanceTable (dtable )
201
+ return DistanceTable (dtable , D = D , metric = self . metric )
174
202
175
203
176
204
class DistanceTable (object ):
@@ -183,6 +211,7 @@ class DistanceTable(object):
183
211
Args:
184
212
dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32
185
213
computed by :func:`PQ.dtable` or :func:`OPQ.dtable`
214
+ metric (str): metric type to calculate distance
186
215
187
216
Attributes:
188
217
dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32.
@@ -191,10 +220,13 @@ class DistanceTable(object):
191
220
192
221
"""
193
222
194
- def __init__ (self , dtable ):
223
+ def __init__ (self , dtable , D , metric = 'l2' ):
195
224
assert dtable .ndim == 2
196
225
assert dtable .dtype == np .float32
226
+ assert metric in ['l2' , 'dot' , 'angular' ]
197
227
self .dtable = dtable
228
+ self .metric = metric
229
+ self .D = D
198
230
199
231
def adist (self , codes ):
200
232
"""Given PQ-codes, compute Asymmetric Distances between the query (self.dtable)
@@ -215,6 +247,8 @@ def adist(self, codes):
215
247
216
248
# Fetch distance values using codes. The following codes are
217
249
dists = np .sum (self .dtable [range (M ), codes ], axis = 1 )
250
+ if self .metric == 'angular' :
251
+ dists = 1 - dists
218
252
219
253
# The above line is equivalent to the followings:
220
254
# dists = np.zeros((N, )).astype(np.float32)
0 commit comments