diff --git a/tests/test_io.py b/tests/test_io.py index d2871f2eee..3cbd0a6e10 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,8 +12,14 @@ import sys import pickle from multiprocessing.pool import ThreadPool +from common_faiss_tests import get_dataset_2 +d = 32 +nt = 2000 +nb = 1000 +nq = 200 + class TestIOVariants(unittest.TestCase): def test_io_error(self): @@ -338,6 +344,113 @@ def test_read_vector_transform(self): os.unlink(fname) +class Test_IO_PQ(unittest.TestCase): + """ + test read and write PQ. + """ + def test_io_pq(self): + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + index = faiss.IndexPQ(d, 4, 4) + index.train(xt) + + fd, fname = tempfile.mkstemp() + os.close(fd) + + try: + faiss.write_ProductQuantizer(index.pq, fname) + + read_pq = faiss.read_ProductQuantizer(fname) + + self.assertEqual(index.pq.M, read_pq.M) + self.assertEqual(index.pq.nbits, read_pq.nbits) + self.assertEqual(index.pq.dsub, read_pq.dsub) + self.assertEqual(index.pq.ksub, read_pq.ksub) + np.testing.assert_array_equal( + faiss.vector_to_array(index.pq.centroids), + faiss.vector_to_array(read_pq.centroids) + ) + + finally: + if os.path.exists(fname): + os.unlink(fname) + + +class Test_IO_IndexLSH(unittest.TestCase): + """ + test read and write IndexLSH. + """ + def test_io_lsh(self): + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + index_lsh = faiss.IndexLSH(d, 32, True, True) + index_lsh.train(xt) + index_lsh.add(xb) + D, I = index_lsh.search(xq, 10) + + fd, fname = tempfile.mkstemp() + os.close(fd) + + try: + faiss.write_index(index_lsh, fname) + + reader = faiss.BufferedIOReader( + faiss.FileIOReader(fname), 1234) + read_index_lsh = faiss.read_index(reader) + # Delete reader to prevent [WinError 32] The process cannot + # access the file because it is being used by another process + del reader + + self.assertEqual(index_lsh.d, read_index_lsh.d) + np.testing.assert_array_equal( + faiss.vector_to_array(index_lsh.codes), + faiss.vector_to_array(read_index_lsh.codes) + ) + D_read, I_read = read_index_lsh.search(xq, 10) + + np.testing.assert_array_equal(D, D_read) + np.testing.assert_array_equal(I, I_read) + + finally: + if os.path.exists(fname): + os.unlink(fname) + + +class Test_IO_IndexIVFSpectralHash(unittest.TestCase): + """ + test read and write IndexIVFSpectralHash. + """ + def test_io_ivf_spectral_hash(self): + nlist = 1000 + xt, xb, xq = get_dataset_2(d, nt, nb, nq) + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, 8, 1.0) + index.train(xt) + index.add(xb) + D, I = index.search(xq, 10) + + fd, fname = tempfile.mkstemp() + os.close(fd) + + try: + faiss.write_index(index, fname) + + reader = faiss.BufferedIOReader( + faiss.FileIOReader(fname), 1234) + read_index = faiss.read_index(reader) + del reader + + self.assertEqual(index.d, read_index.d) + self.assertEqual(index.nbit, read_index.nbit) + self.assertEqual(index.period, read_index.period) + self.assertEqual(index.threshold_type, read_index.threshold_type) + + D_read, I_read = read_index.search(xq, 10) + np.testing.assert_array_equal(D, D_read) + np.testing.assert_array_equal(I, I_read) + + finally: + if os.path.exists(fname): + os.unlink(fname) + class TestIVFPQRead(unittest.TestCase): def test_reader(self): d, n = 32, 1000