diff --git a/faiss/cppcontrib/factory_tools.cpp b/faiss/cppcontrib/factory_tools.cpp
new file mode 100644
index 0000000000..76cbe73efe
--- /dev/null
+++ b/faiss/cppcontrib/factory_tools.cpp
@@ -0,0 +1,180 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include <faiss/cppcontrib/factory_tools.h>
+#include <map>
+
+namespace faiss {
+
+namespace {
+
+// Factory-string name of every scalar quantizer type handled below. Each
+// value already includes its "SQ" prefix.
+const std::map<faiss::ScalarQuantizer::QuantizerType, std::string> sq_types = {
+        {faiss::ScalarQuantizer::QT_8bit, "SQ8"},
+        {faiss::ScalarQuantizer::QT_4bit, "SQ4"},
+        {faiss::ScalarQuantizer::QT_6bit, "SQ6"},
+        {faiss::ScalarQuantizer::QT_fp16, "SQfp16"},
+        {faiss::ScalarQuantizer::QT_bf16, "SQbf16"},
+        {faiss::ScalarQuantizer::QT_8bit_direct_signed, "SQ8_direct_signed"},
+        {faiss::ScalarQuantizer::QT_8bit_direct, "SQ8_direct"},
+};
+
+// Returns the sq_types entry for `qtype`, or an empty string when the type is
+// not listed, so that callers never hit the exception thrown by map::at().
+std::string sq_type_name(faiss::ScalarQuantizer::QuantizerType qtype) {
+    const auto it = sq_types.find(qtype);
+    return it == sq_types.end() ? std::string() : it->second;
+}
+
+// Derives the M written in "HNSW<M>" factory strings as half of
+// hnsw.cum_nneighbor_per_level[1]. Returns 0 when that entry does not exist.
+int get_hnsw_M(const faiss::IndexHNSW* index) {
+    const auto& cum_nneighbor = index->hnsw.cum_nneighbor_per_level;
+    // Element [1] is read below, so at least two entries are required.
+    if (cum_nneighbor.size() > 1) {
+        return cum_nneighbor[1] / 2;
+    }
+    // Avoid runtime error, just return 0.
+    return 0;
+}
+
+} // namespace
+
+// Reference for reverse_index_factory:
+// https://github.com/facebookresearch/faiss/blob/838612c9d7f2f619811434ec9209c020f44107cb/contrib/factory_tools.py#L81
+//
+// Builds the factory string that describes `index`. Index types that are not
+// handled yield an empty string rather than an error, because the result is
+// meant for logging.
+std::string reverse_index_factory(const faiss::Index* index) {
+    std::string prefix;
+    if (dynamic_cast<const faiss::IndexFlat*>(index)) {
+        return "Flat";
+    } else if (
+            const faiss::IndexIVF* ivf_index =
+                    dynamic_cast<const faiss::IndexIVF*>(index)) {
+        const faiss::Index* quantizer = ivf_index->quantizer;
+
+        if (dynamic_cast<const faiss::IndexFlat*>(quantizer)) {
+            prefix = "IVF" + std::to_string(ivf_index->nlist);
+        } else if (
+                const faiss::MultiIndexQuantizer* miq =
+                        dynamic_cast<const faiss::MultiIndexQuantizer*>(
+                                quantizer)) {
+            prefix = "IMI" + std::to_string(miq->pq.M) + "x" +
+                    std::to_string(miq->pq.nbits);
+        } else if (
+                const faiss::IndexHNSW* hnsw_index =
+                        dynamic_cast<const faiss::IndexHNSW*>(quantizer)) {
+            prefix = "IVF" + std::to_string(ivf_index->nlist) + "_HNSW" +
+                    std::to_string(get_hnsw_M(hnsw_index));
+        } else {
+            prefix = "IVF" + std::to_string(ivf_index->nlist) + "(" +
+                    reverse_index_factory(quantizer) + ")";
+        }
+
+        if (dynamic_cast<const faiss::IndexIVFFlat*>(ivf_index)) {
+            return prefix + ",Flat";
+        } else if (
+                auto sq_index =
+                        dynamic_cast<const faiss::IndexIVFScalarQuantizer*>(
+                                ivf_index)) {
+            const std::string sq_name = sq_type_name(sq_index->sq.qtype);
+            if (sq_name.empty()) {
+                // Avoid runtime error, just return empty string for logging.
+                return "";
+            }
+            return prefix + "," + sq_name;
+        } else if (
+                const faiss::IndexIVFPQ* ivfpq_index =
+                        dynamic_cast<const faiss::IndexIVFPQ*>(ivf_index)) {
+            return prefix + ",PQ" + std::to_string(ivfpq_index->pq.M) + "x" +
+                    std::to_string(ivfpq_index->pq.nbits);
+        } else if (
+                const faiss::IndexIVFPQFastScan* ivfpqfs_index =
+                        dynamic_cast<const faiss::IndexIVFPQFastScan*>(
+                                ivf_index)) {
+            return prefix + ",PQ" + std::to_string(ivfpqfs_index->pq.M) + "x" +
+                    std::to_string(ivfpqfs_index->pq.nbits) + "fs";
+        }
+    } else if (
+            const faiss::IndexPreTransform* pretransform_index =
+                    dynamic_cast<const faiss::IndexPreTransform*>(index)) {
+        if (pretransform_index->chain.size() != 1) {
+            // Avoid runtime error, just return empty string for logging.
+            return "";
+        }
+        const faiss::VectorTransform* vt = pretransform_index->chain.at(0);
+        if (const faiss::OPQMatrix* opq_matrix =
+                    dynamic_cast<const faiss::OPQMatrix*>(vt)) {
+            prefix = "OPQ" + std::to_string(opq_matrix->M) + "_" +
+                    std::to_string(opq_matrix->d_out);
+        } else if (
+                const faiss::ITQTransform* itq_transform =
+                        dynamic_cast<const faiss::ITQTransform*>(vt)) {
+            prefix = "ITQ" + std::to_string(itq_transform->itq.d_out);
+        } else if (
+                const faiss::PCAMatrix* pca_matrix =
+                        dynamic_cast<const faiss::PCAMatrix*>(vt)) {
+            if (pca_matrix->eigen_power != 0) {
+                // The strings built here cannot express a non-zero
+                // eigen_power; return empty string instead of asserting.
+                return "";
+            }
+            prefix = "PCA" +
+                    std::string(pca_matrix->random_rotation ? "R" : "") +
+                    std::to_string(pca_matrix->d_out);
+        } else {
+            // Avoid runtime error, just return empty string for logging.
+            return "";
+        }
+        return prefix + "," + reverse_index_factory(pretransform_index->index);
+    } else if (
+            const faiss::IndexHNSW* hnsw_index =
+                    dynamic_cast<const faiss::IndexHNSW*>(index)) {
+        return "HNSW" + std::to_string(get_hnsw_M(hnsw_index));
+    } else if (
+            const faiss::IndexRefine* refine_index =
+                    dynamic_cast<const faiss::IndexRefine*>(index)) {
+        return reverse_index_factory(refine_index->base_index) + ",Refine(" +
+                reverse_index_factory(refine_index->refine_index) + ")";
+    } else if (
+            const faiss::IndexPQFastScan* pqfs_index =
+                    dynamic_cast<const faiss::IndexPQFastScan*>(index)) {
+        return std::string("PQ") + std::to_string(pqfs_index->pq.M) + "x" +
+                std::to_string(pqfs_index->pq.nbits) + "fs";
+    } else if (
+            const faiss::IndexPQ* pq_index =
+                    dynamic_cast<const faiss::IndexPQ*>(index)) {
+        return std::string("PQ") + std::to_string(pq_index->pq.M) + "x" +
+                std::to_string(pq_index->pq.nbits);
+    } else if (
+            const faiss::IndexLSH* lsh_index =
+                    dynamic_cast<const faiss::IndexLSH*>(index)) {
+        std::string result = "LSH";
+        if (lsh_index->rotate_data) {
+            result += "r";
+        }
+        if (lsh_index->train_thresholds) {
+            result += "t";
+        }
+        return result;
+    } else if (
+            const faiss::IndexScalarQuantizer* sq_index =
+                    dynamic_cast<const faiss::IndexScalarQuantizer*>(index)) {
+        // sq_types values already start with "SQ"; prepending it again would
+        // produce e.g. "SQSQ8".
+        return sq_type_name(sq_index->sq.qtype);
+    }
+    // Avoid runtime error, just return empty string for logging.
+    return "";
+}
+
+} // namespace faiss
diff --git a/faiss/cppcontrib/factory_tools.h b/faiss/cppcontrib/factory_tools.h
new file mode 100644
index 0000000000..bde4971854
--- /dev/null
+++ b/faiss/cppcontrib/factory_tools.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#pragma once
+
+#include <faiss/IndexHNSW.h>
+#include <faiss/IndexIVFFlat.h>
+#include <faiss/IndexIVFPQFastScan.h>
+#include <faiss/IndexLSH.h>
+#include <faiss/IndexPQFastScan.h>
+#include <faiss/IndexPreTransform.h>
+#include <faiss/IndexRefine.h>
+#include <string>
+
+namespace faiss {
+
+/** Build the factory string that describes an index.
+ *
+ * @param index  index to describe
+ * @return       factory string for `index`, or an empty string when the index
+ *               type is not handled
+ */
+std::string reverse_index_factory(const faiss::Index* index);
+
+} // namespace faiss
diff --git a/tests/test_factory_tools.cpp b/tests/test_factory_tools.cpp
new file mode 100644
index 0000000000..c8d7f21c3f
--- /dev/null
+++ b/tests/test_factory_tools.cpp
@@ -0,0 +1,59 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#include <faiss/cppcontrib/factory_tools.h>
+#include <faiss/index_factory.h>
+#include <gtest/gtest.h>
+
+using namespace faiss;
+
+TEST(TestFactoryTools, TestReverseIndexFactory) {
+    auto factory_string = "Flat";
+    auto index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "IMI2x5,PQ8x8";
+    index = faiss::index_factory(32, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "IVF32_HNSW32,SQ8";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "IVF8,Flat";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "IVF8,SQ4";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "IVF8,PQ4x8";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "LSHrt";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "PQ4x8";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "HNSW32";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+
+    factory_string = "SQ8";
+    index = faiss::index_factory(64, factory_string);
+    EXPECT_EQ(factory_string, reverse_index_factory(index));
+    delete index;
+}