Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@
import sys
import pickle
from multiprocessing.pool import ThreadPool
from common_faiss_tests import get_dataset_2


d = 32
nt = 2000
nb = 1000
nq = 200

class TestIOVariants(unittest.TestCase):

def test_io_error(self):
Expand Down Expand Up @@ -338,6 +344,113 @@ def test_read_vector_transform(self):
os.unlink(fname)


class Test_IO_PQ(unittest.TestCase):
"""
test read and write PQ.
"""
def test_io_pq(self):
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
index = faiss.IndexPQ(d, 4, 4)
index.train(xt)

fd, fname = tempfile.mkstemp()
os.close(fd)

try:
faiss.write_ProductQuantizer(index.pq, fname)

read_pq = faiss.read_ProductQuantizer(fname)

self.assertEqual(index.pq.M, read_pq.M)
self.assertEqual(index.pq.nbits, read_pq.nbits)
self.assertEqual(index.pq.dsub, read_pq.dsub)
self.assertEqual(index.pq.ksub, read_pq.ksub)
np.testing.assert_array_equal(
faiss.vector_to_array(index.pq.centroids),
faiss.vector_to_array(read_pq.centroids)
)

finally:
if os.path.exists(fname):
os.unlink(fname)


class Test_IO_IndexLSH(unittest.TestCase):
"""
test read and write IndexLSH.
"""
def test_io_lsh(self):
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
index_lsh = faiss.IndexLSH(d, 32, True, True)
index_lsh.train(xt)
index_lsh.add(xb)
D, I = index_lsh.search(xq, 10)

fd, fname = tempfile.mkstemp()
os.close(fd)

try:
faiss.write_index(index_lsh, fname)

reader = faiss.BufferedIOReader(
faiss.FileIOReader(fname), 1234)
read_index_lsh = faiss.read_index(reader)
# Delete reader to prevent [WinError 32] The process cannot
# access the file because it is being used by another process
del reader

self.assertEqual(index_lsh.d, read_index_lsh.d)
np.testing.assert_array_equal(
faiss.vector_to_array(index_lsh.codes),
faiss.vector_to_array(read_index_lsh.codes)
)
D_read, I_read = read_index_lsh.search(xq, 10)

np.testing.assert_array_equal(D, D_read)
np.testing.assert_array_equal(I, I_read)

finally:
if os.path.exists(fname):
os.unlink(fname)


class Test_IO_IndexIVFSpectralHash(unittest.TestCase):
"""
test read and write IndexIVFSpectralHash.
"""
def test_io_ivf_spectral_hash(self):
nlist = 1000
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, 8, 1.0)
index.train(xt)
index.add(xb)
D, I = index.search(xq, 10)

fd, fname = tempfile.mkstemp()
os.close(fd)

try:
faiss.write_index(index, fname)

reader = faiss.BufferedIOReader(
faiss.FileIOReader(fname), 1234)
read_index = faiss.read_index(reader)
del reader

self.assertEqual(index.d, read_index.d)
self.assertEqual(index.nbit, read_index.nbit)
self.assertEqual(index.period, read_index.period)
self.assertEqual(index.threshold_type, read_index.threshold_type)

D_read, I_read = read_index.search(xq, 10)
np.testing.assert_array_equal(D, D_read)
np.testing.assert_array_equal(I, I_read)

finally:
if os.path.exists(fname):
os.unlink(fname)

class TestIVFPQRead(unittest.TestCase):
def test_reader(self):
d, n = 32, 1000
Expand Down