contrib/torch/clustering.py (1 change: 1 addition & 0 deletions)

@@ -13,6 +13,7 @@
 # the kmeans can produce both torch and numpy centroids
 from faiss.contrib.clustering import kmeans
 
+
 class DatasetAssign:
     """Wrapper for a tensor that offers a function to assign the vectors
     to centroids. All other implementations offer the same interface"""
contrib/torch/quantization.py (61 changes: 52 additions & 9 deletions)

@@ -7,33 +7,47 @@
 This contrib module contains Pytorch code for quantization.
 """
 
 import numpy as np
 import torch
 import faiss
 
 from faiss.contrib import torch_utils
+import math
+from faiss.contrib.torch import clustering
+# the kmeans can produce both torch and numpy centroids
 
 
 class Quantizer:
 
     def __init__(self, d, code_size):
+        """
+        d: dimension of vectors
+        code_size: nb of bytes of the code (per vector)
+        """
         self.d = d
         self.code_size = code_size
 
     def train(self, x):
+        """
+        takes a n-by-d array and performs training
+        """
         pass
 
     def encode(self, x):
+        """
+        takes a n-by-d float array, encodes to an n-by-code_size uint8 array
+        """
        pass
 
-    def decode(self, x):
+    def decode(self, codes):
+        """
+        takes a n-by-code_size uint8 array, returns a n-by-d array
+        """
        pass
 
 
 class VectorQuantizer(Quantizer):
 
     def __init__(self, d, k):
-        code_size = int(torch.ceil(torch.log2(k) / 8))
-
+        code_size = int(math.ceil(math.log2(k) / 8))
         Quantizer.__init__(self, d, code_size)
         self.k = k
 
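A quick aside on the code_size formula above (illustrative, not part of the diff): it packs one centroid index into the smallest whole number of bytes. A worked example, assuming k is a plain Python int:

# code_size worked example (illustration only, not in the diff)
import math

k = 1024                                 # 1024 centroids -> 10-bit indices
code_size = math.ceil(math.log2(k) / 8)  # ceil(10 / 8) = 2 bytes per vector
print(code_size)                         # 2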
@@ -42,12 +56,41 @@ def train(self, x):
 
 
 class ProductQuantizer(Quantizer):
 
     def __init__(self, d, M, nbits):
-        code_size = int(torch.ceil(M * nbits / 8))
-        Quantizer.__init__(d, code_size)
+        """ M: number of subvectors, d%M == 0
+        nbits: number of bits that each vector is encoded into
+        """
+        assert d % M == 0
+        assert nbits == 8  # todo: implement other nbits values
+        code_size = int(math.ceil(M * nbits / 8))
+        Quantizer.__init__(self, d, code_size)
         self.M = M
         self.nbits = nbits
+        self.code_size = code_size
 
     def train(self, x):
-        pass
+        nc = 2 ** self.nbits
+        sd = self.d // self.M
+        dev = x.device
+        dtype = x.dtype
+        self.codebook = torch.zeros((self.M, nc, sd), device=dev, dtype=dtype)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M: (m + 1) * self.d // self.M]
+            data = clustering.DatasetAssign(xsub.contiguous())
+            self.codebook[m] = clustering.kmeans(2 ** self.nbits, data)
+
+    def encode(self, x):
+        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
+        for m in range(self.M):
+            xsub = x[:, m * self.d // self.M:(m + 1) * self.d // self.M]
+            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
+            codes[:, m] = I.ravel()
+        return codes
+
+    def decode(self, codes):
+        idxs = [codes[:, m].long() for m in range(self.M)]
+        vectors = [self.codebook[m, idxs[m], :] for m in range(self.M)]
+        stacked_vectors = torch.stack(vectors, dim=1)
+        cbd = self.codebook.shape[-1]
+        x_rec = stacked_vectors.reshape(-1, cbd * self.M)
+        return x_rec
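For context, here is a minimal round-trip sketch with the ProductQuantizer added above (an illustration, not part of the diff; the sizes and the random data are assumptions):

# Illustrative round trip: train / encode / decode with the new torch PQ.
import torch
from faiss.contrib.torch import quantization

d, M, nbits = 64, 4, 8            # 4 sub-quantizers, 256 centroids each
x = torch.rand(10000, d)          # random training data, illustration only

pq = quantization.ProductQuantizer(d, M, nbits)
pq.train(x)                       # k-means per subspace -> (M, 256, d // M) codebook
codes = pq.encode(x)              # (n, 4) uint8 codes, one byte per subvector
x_rec = pq.decode(codes)          # (n, d) approximate reconstruction
err = ((x - x_rec) ** 2).sum()    # total squared reconstruction error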
tests/torch_test_contrib.py (33 changes: 29 additions & 4 deletions)

@@ -4,13 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch  # usort: skip
-import unittest # usort: skip
-import numpy as np # usort: skip
+import unittest  # usort: skip
+import numpy as np  # usort: skip
 
-import faiss # usort: skip
+import faiss  # usort: skip
 import faiss.contrib.torch_utils  # usort: skip
 from faiss.contrib import datasets
-from faiss.contrib.torch import clustering
+from faiss.contrib.torch import clustering, quantization
 
 
@@ -400,3 +401,27 @@ def test_python_kmeans(self):
         # 33498.332 33380.477
         # print(err, err2) 1/0
         self.assertLess(err2, err * 1.1)
+
+
+class TestQuantization(unittest.TestCase):
+    def test_python_product_quantization(self):
+        """ Test the python implementation of product quantization """
+        d = 64
+        n = 10000
+        cs = 4
+        nbits = 8
+        M = 4
+        x = np.random.random(size=(n, d)).astype('float32')
+        pq = faiss.ProductQuantizer(d, cs, nbits)
+        pq.train(x)
+        codes = pq.compute_codes(x)
+        x2 = pq.decode(codes)
+        diff = ((x - x2)**2).sum()
+        # vs pure pytorch impl
+        xt = torch.from_numpy(x)
+        my_pq = quantization.ProductQuantizer(d, M, nbits)
+        my_pq.train(xt)
+        my_codes = my_pq.encode(xt)
+        xt2 = my_pq.decode(my_codes)
+        my_diff = ((xt - xt2)**2).sum()
+        self.assertLess(abs(diff - my_diff), 100)
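A closing note on the new test (an aside, not part of the PR): both quantizers run with 4 sub-quantizers of 8 bits each, so every vector is compressed to a 4-byte code, and the assertion only demands that the two implementations land within 100 of each other in total squared reconstruction error over the 10,000 random vectors. The compression arithmetic, spelled out:

# PQ compression in this test, spelled out (illustrative).
d, M, nbits = 64, 4, 8
raw_bytes = d * 4                 # 64 float32 values = 256 bytes per vector
code_bytes = M * nbits // 8       # 4 bytes of PQ code per vector
print(raw_bytes // code_bytes)    # 64x compression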