From d969373dc7e346eabba5eded1cf306dd9cf67a48 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 13 Feb 2024 13:18:04 -0800 Subject: [PATCH] Skip PQ sdc init with new io flag Add new IO flag, IO_FLAG_PQ_SKIP_SDC_TABLE, so that when reading HNSWPQ from disk, it will skip building the sdc table. sdc table is only used during graph construction, so if this flag is set, the HNSWPQ index will not be updateable. In addition, adds cpp test case verifying functionality and build test util header file to share creation of temporary files amongst tests. Signed-off-by: John Mazanec --- faiss/impl/index_read.cpp | 2 +- faiss/index_io.h | 6 ++++ tests/CMakeLists.txt | 1 + tests/test_io.cpp | 61 +++++++++++++++++++++++++++++++++++++++ tests/test_merge.cpp | 35 ++++------------------ tests/test_util.h | 39 +++++++++++++++++++++++++ 6 files changed, 113 insertions(+), 31 deletions(-) create mode 100644 tests/test_io.cpp create mode 100644 tests/test_util.h diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index ac62e0269e..8d80329bf9 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -962,7 +962,7 @@ Index* read_index(IOReader* f, int io_flags) { read_HNSW(&idxhnsw->hnsw, f); idxhnsw->storage = read_index(f, io_flags); idxhnsw->own_fields = true; - if (h == fourcc("IHNp")) { + if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) { dynamic_cast(idxhnsw->storage)->pq.compute_sdc_table(); } idx = idxhnsw; diff --git a/faiss/index_io.h b/faiss/index_io.h index 8d52ee1afd..f73cd073b7 100644 --- a/faiss/index_io.h +++ b/faiss/index_io.h @@ -52,6 +52,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4; const int IO_FLAG_SKIP_IVF_DATA = 8; // don't initialize precomputed table after loading const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16; +// don't compute the sdc table for PQ-based indices +// this will prevent distances from being computed +// between elements in the index. For indices like HNSWPQ, +// this will prevent graph building because sdc +// computations are required to construct the graph +const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32; // try to memmap data (useful to load an ArrayInvertedLists as an // OnDiskInvertedLists) const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8522fa613d..73bc50b25e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -32,6 +32,7 @@ set(FAISS_TEST_SRC test_hnsw.cpp test_partitioning.cpp test_fastscan_perf.cpp + test_io.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_io.cpp b/tests/test_io.cpp new file mode 100644 index 0000000000..f423243d4e --- /dev/null +++ b/tests/test_io.cpp @@ -0,0 +1,61 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include "faiss/Index.h" +#include "faiss/IndexHNSW.h" +#include "faiss/index_factory.h" +#include "faiss/index_io.h" +#include "test_util.h" + +pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; + +TEST(IO, TestReadHNSWPQ_whenSDCDisabledFlagPassed_thenDisableSDCTable) { + Tempfilename index_filename(&temp_file_mutex, "/tmp/faiss_TestReadHNSWPQ"); + int d = 32, n = 256; + std::default_random_engine rng(123); + std::uniform_real_distribution u(0, 100); + std::vector vectors(n * d); + for (size_t i = 0; i < n * d; i++) { + vectors[i] = u(rng); + } + + // Build the index and write it to the temp file + { + std::unique_ptr index_writer( + faiss::index_factory(d, "HNSW8,PQ4", faiss::METRIC_L2)); + index_writer->train(n, vectors.data()); + index_writer->add(n, vectors.data()); + + faiss::write_index(index_writer.get(), index_filename.c_str()); + } + + // Load index from disk. Confirm that the sdc table is equal to 0 when + // disable sdc is set + { + std::unique_ptr index_reader_read_write( + dynamic_cast( + faiss::read_index(index_filename.c_str()))); + std::unique_ptr index_reader_sdc_disabled( + dynamic_cast(faiss::read_index( + index_filename.c_str(), + faiss::IO_FLAG_PQ_SKIP_SDC_TABLE))); + + ASSERT_NE( + dynamic_cast(index_reader_read_write->storage) + ->pq.sdc_table.size(), + 0); + ASSERT_EQ( + dynamic_cast( + index_reader_sdc_disabled->storage) + ->pq.sdc_table.size(), + 0); + } +} diff --git a/tests/test_merge.cpp b/tests/test_merge.cpp index 7e23f15f72..5a1d08cfba 100644 --- a/tests/test_merge.cpp +++ b/tests/test_merge.cpp @@ -6,47 +6,22 @@ */ #include -#include #include -#include - #include #include #include #include -#include #include #include #include -namespace { - -struct Tempfilename { - static pthread_mutex_t mutex; - - std::string filename = "/tmp/faiss_tmp_XXXXXX"; - - Tempfilename() { - pthread_mutex_lock(&mutex); - int fd = mkstemp(&filename[0]); - close(fd); - pthread_mutex_unlock(&mutex); - } - - ~Tempfilename() { - if (access(filename.c_str(), F_OK)) { - unlink(filename.c_str()); - } - } +#include "test_util.h" - const char* c_str() { - return filename.c_str(); - } -}; +namespace { -pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER; typedef faiss::idx_t idx_t; @@ -95,7 +70,7 @@ int compare_merged( std::vector refD(k * nq); index_shards->search(nq, cd.queries.data(), k, refD.data(), refI.data()); - Tempfilename filename; + Tempfilename filename(&temp_file_mutex, "/tmp/faiss_tmp_XXXXXX"); std::vector newI(k * nq); std::vector newD(k * nq); @@ -212,7 +187,7 @@ TEST(MERGE, merge_flat_vt) { TEST(MERGE, merge_flat_ondisk) { faiss::IndexShards index_shards(d, false, false); index_shards.own_indices = true; - Tempfilename filename; + Tempfilename filename(&temp_file_mutex, "/tmp/faiss_tmp_XXXXXX"); for (int i = 0; i < nindex; i++) { auto ivf = new faiss::IndexIVFFlat(&cd.quantizer, d, nlist); diff --git a/tests/test_util.h b/tests/test_util.h new file mode 100644 index 0000000000..3be0e35cff --- /dev/null +++ b/tests/test_util.h @@ -0,0 +1,39 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef FAISS_TEST_UTIL_H +#define FAISS_TEST_UTIL_H + +#include +#include +#include + +struct Tempfilename { + pthread_mutex_t* mutex; + std::string filename; + + Tempfilename(pthread_mutex_t* mutex, std::string filename) { + this->mutex = mutex; + this->filename = filename; + pthread_mutex_lock(mutex); + int fd = mkstemp(&filename[0]); + close(fd); + pthread_mutex_unlock(mutex); + } + + ~Tempfilename() { + if (access(filename.c_str(), F_OK)) { + unlink(filename.c_str()); + } + } + + const char* c_str() { + return filename.c_str(); + } +}; + +#endif // FAISS_TEST_UTIL_H