Skip to content

Commit 7651d88

Browse files
authored
Merge pull request #24 from mpskex/main
add different metric type and initialization methods to nanopq
2 parents 8d1cc43 + 48c60c9 commit 7651d88

File tree

4 files changed

+87
-14
lines changed

4 files changed

+87
-14
lines changed

nanopq/convert_faiss.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from .opq import OPQ
1313
from .pq import PQ
1414

15+
faiss_metric_map = {
16+
'l2': faiss.METRIC_L2,
17+
'dot': faiss.METRIC_INNER_PRODUCT,
18+
'angular': faiss.METRIC_INNER_PRODUCT
19+
}
1520

1621
def nanopq_to_faiss(pq_nanopq):
1722
"""Convert a :class:`nanopq.PQ` instance to `faiss.IndexPQ <https://github.com/facebookresearch/faiss/blob/master/IndexPQ.h>`_.
@@ -31,7 +36,7 @@ def nanopq_to_faiss(pq_nanopq):
3136
D = pq_nanopq.Ds * pq_nanopq.M
3237
nbits = {np.uint8: 8, np.uint16: 16, np.uint32: 32}[pq_nanopq.code_dtype]
3338

34-
pq_faiss = faiss.IndexPQ(D, pq_nanopq.M, nbits)
39+
pq_faiss = faiss.IndexPQ(D, pq_nanopq.M, nbits, faiss_metric_map[pq_nanopq.metric])
3540

3641
for m in range(pq_nanopq.M):
3742
# Prepare std::vector<float>

nanopq/opq.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class OPQ(object):
2626
2727
"""
2828

29-
def __init__(self, M, Ks=256, verbose=True):
30-
self.pq = PQ(M, Ks, verbose)
29+
def __init__(self, M, Ks=256, metric='l2', minit='random', verbose=True):
30+
self.pq = PQ(M, Ks, metric=metric, minit=minit, verbose=verbose)
3131
self.R = None
3232

3333
def __eq__(self, other):

nanopq/pq.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,27 @@
1+
import warnings
12
import numpy as np
23
from scipy.cluster.vq import kmeans2, vq
34

45

6+
def dist_l2(q, x):
7+
return np.linalg.norm(q - x, ord=2, axis=1) ** 2
8+
9+
10+
def dist_ip(q, x):
11+
return np.matmul(x, q[None, :].T).sum(axis=-1)
12+
13+
14+
def dist_angular(q, x):
15+
return dist_ip(q, x)
16+
17+
18+
metric_function_map = {
19+
'l2': dist_l2,
20+
'angular': dist_angular,
21+
'dot': dist_ip
22+
}
23+
24+
525
class PQ(object):
626
"""Pure python implementation of Product Quantization (PQ) [Jegou11]_.
727
@@ -19,12 +39,14 @@ class PQ(object):
1939
M (int): The number of sub-space
2040
Ks (int): The number of codewords for each subspace
2141
(typically 256, so that each sub-vector is quantized
22-
into 8 bits = 1 byte = uint8)
42+
into 256 bits = 1 byte = uint8)
43+
metric (str): Type of metric used among vectors
2344
verbose (bool): Verbose flag
2445
2546
Attributes:
2647
M (int): The number of sub-space
2748
Ks (int): The number of codewords for each subspace
49+
metric (str): Type of metric used among vectors
2850
verbose (bool): Verbose flag
2951
code_dtype (object): dtype of PQ-code. Either np.uint{8, 16, 32}
3052
codewords (np.ndarray): shape=(M, Ks, Ds) with dtype=np.float32.
@@ -33,17 +55,22 @@ class PQ(object):
3355
3456
"""
3557

36-
def __init__(self, M, Ks=256, verbose=True):
58+
def __init__(self, M, Ks=256, metric='l2', minit='random', verbose=True):
3759
assert 0 < Ks <= 2 ** 32
38-
self.M, self.Ks, self.verbose = M, Ks, verbose
60+
assert metric in ['l2', 'dot', 'angular']
61+
assert minit in ['random', '++', 'points', 'matrix']
62+
self.M, self.Ks, self.verbose, self.metric = M, Ks, verbose, metric
3963
self.code_dtype = (
4064
np.uint8 if Ks <= 2 ** 8 else (np.uint16 if Ks <= 2 ** 16 else np.uint32)
4165
)
4266
self.codewords = None
4367
self.Ds = None
68+
self.metric = metric
69+
self.minit = minit
4470

4571
if verbose:
46-
print("M: {}, Ks: {}, code_dtype: {}".format(M, Ks, self.code_dtype))
72+
print("M: {}, Ks: {}, metric : {}, code_dtype: {} minit: {}".format(
73+
M, Ks, self.code_dtype, metric, minit))
4774

4875
def __eq__(self, other):
4976
if isinstance(other, PQ):
@@ -88,9 +115,9 @@ def fit(self, vecs, iter=20, seed=123):
88115
for m in range(self.M):
89116
if self.verbose:
90117
print("Training the subspace: {} / {}".format(m, self.M))
91-
vecs_sub = vecs[:, m * self.Ds : (m + 1) * self.Ds]
92-
self.codewords[m], _ = kmeans2(vecs_sub, self.Ks, iter=iter, minit="points")
93-
118+
vecs_sub = vecs[:, m * self.Ds: (m + 1) * self.Ds]
119+
self.codewords[m], _ = kmeans2(
120+
vecs_sub, self.Ks, iter=iter, minit=self.minit)
94121
return self
95122

96123
def encode(self, vecs):
@@ -167,10 +194,11 @@ def dtable(self, query):
167194
# dtable[m][ks] : distance between m-th subvec and ks-th codeword of m-th codewords
168195
dtable = np.empty((self.M, self.Ks), dtype=np.float32)
169196
for m in range(self.M):
170-
query_sub = query[m * self.Ds : (m + 1) * self.Ds]
171-
dtable[m, :] = np.linalg.norm(self.codewords[m] - query_sub, axis=1) ** 2
197+
query_sub = query[m * self.Ds: (m + 1) * self.Ds]
198+
dtable[m, :] = metric_function_map[self.metric](
199+
query_sub, self.codewords[m])
172200

173-
return DistanceTable(dtable)
201+
return DistanceTable(dtable, D=D, metric=self.metric)
174202

175203

176204
class DistanceTable(object):
@@ -183,6 +211,7 @@ class DistanceTable(object):
183211
Args:
184212
dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32
185213
computed by :func:`PQ.dtable` or :func:`OPQ.dtable`
214+
metric (str): metric type to calculate distance
186215
187216
Attributes:
188217
dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32.
@@ -191,10 +220,13 @@ class DistanceTable(object):
191220
192221
"""
193222

194-
def __init__(self, dtable):
223+
def __init__(self, dtable, D, metric='l2'):
195224
assert dtable.ndim == 2
196225
assert dtable.dtype == np.float32
226+
assert metric in ['l2', 'dot', 'angular']
197227
self.dtable = dtable
228+
self.metric = metric
229+
self.D = D
198230

199231
def adist(self, codes):
200232
"""Given PQ-codes, compute Asymmetric Distances between the query (self.dtable)
@@ -215,6 +247,8 @@ def adist(self, codes):
215247

216248
# Fetch distance values using codes. The following codes are
217249
dists = np.sum(self.dtable[range(M), codes], axis=1)
250+
if self.metric == 'angular':
251+
dists = 1 - dists
218252

219253
# The above line is equivalent to the followings:
220254
# dists = np.zeros((N, )).astype(np.float32)

tests/test_pq.py

+34
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,40 @@ def test_pickle(self):
9797
)
9898
self.assertTrue(np.allclose(pq.codewords, pq2.codewords))
9999
self.assertTrue(pq == pq2)
100+
101+
def test_ip(self):
102+
N, D, M, Ks = 100, 12, 4, 10
103+
X = np.random.random((N, D)).astype(np.float32)
104+
pq = nanopq.PQ(M=M, Ks=Ks, metric='dot')
105+
pq.fit(X)
106+
X_ = pq.encode(X)
107+
q = X[13]
108+
dist1 = pq.dtable(q).adist(X_)
109+
dtable = np.empty((pq.M, pq.Ks), dtype=np.float32)
110+
for m in range(pq.M):
111+
query_sub = q[m * pq.Ds : (m + 1) * pq.Ds]
112+
dtable[m, :] = np.matmul(pq.codewords[m], query_sub[None, :].T).sum(axis=-1)
113+
dist2 = np.sum(dtable[range(M), X_], axis=1)
114+
self.assertTrue((dist1 == dist2).all())
115+
self.assertTrue(abs(np.mean(np.matmul(X, q[:, None]).squeeze() - dist1)) < 1e-7)
116+
117+
def test_angular(self):
118+
N, D, M, Ks = 100, 12, 4, 10
119+
X = np.random.random((N, D)).astype(np.float32)
120+
X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1])
121+
X /= np.linalg.norm(X, ord=2, axis=-1)[:, None]
122+
pq = nanopq.PQ(M=M, Ks=Ks, metric='angular')
123+
pq.fit(X)
124+
X_ = pq.encode(X)
125+
q = X[13]
126+
dist1 = pq.dtable(q).adist(X_)
127+
dtable = np.empty((pq.M, pq.Ks), dtype=np.float32)
128+
for m in range(pq.M):
129+
query_sub = q[m * pq.Ds : (m + 1) * pq.Ds]
130+
dtable[m, :] = np.matmul(pq.codewords[m], query_sub[None, :].T).sum(axis=-1)
131+
dist2 = 1 - np.sum(dtable[range(M), X_], axis=1)
132+
self.assertTrue((dist1 == dist2).all())
133+
self.assertTrue(abs(np.mean((1-np.matmul(X, q[:, None]) / (np.linalg.norm(q) * np.linalg.norm(X, ord=2, axis=-1))) - dist1)) < 1e-7)
100134

101135

102136
if __name__ == "__main__":

0 commit comments

Comments
 (0)