diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index 9de9c6231..1ed5caf13 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -21,7 +21,7 @@ class IvfConfig : public BaseConfig { CFG_INT nlist; CFG_INT nprobe; CFG_BOOL use_elkan; - CFG_BOOL ensure_topk_full; + CFG_BOOL ensure_topk_full; // only takes effect on temp index(IVF_FLAT_CC) now KNOHWERE_DECLARE_CONFIG(IvfConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(nlist) .set_default(128) diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 56e8f51f7..4ea059ca7 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -216,15 +216,21 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { using std::make_tuple; auto ivfflatcc_gen_ = [base_gen, nb]() { knowhere::Json json = base_gen(); - json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NLIST] = 32; json[knowhere::indexparam::NPROBE] = 1; json[knowhere::indexparam::SSIZE] = 48; json[knowhere::meta::TOPK] = nb; return json; }; - auto [name, gen] = GENERATE_REF(table>({ - make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_), - })); + auto ivfflatcc_gen_no_ensure_topk_ = [ivfflatcc_gen_, nb]() { + knowhere::Json json = ivfflatcc_gen_(); + json[knowhere::meta::TOPK] = nb / 2; + json[knowhere::indexparam::ENSURE_TOPK_FULL] = false; + return json; + }; + auto [name, gen] = GENERATE_REF(table>( + {make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_no_ensure_topk_)})); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); @@ -235,7 +241,27 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto results = idx.Search(*query_ds, json, nullptr); auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); float recall = GetKNNRecall(*gt.value(), *results.value()); - 
REQUIRE(recall > kBruteForceRecallThreshold); + if (ivfflatcc_gen_().dump() == cfg_json) { + REQUIRE(recall > kBruteForceRecallThreshold); + } else { + REQUIRE(recall < kBruteForceRecallThreshold); + } + + std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const auto bitset_percentages = 0.5f; + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(nb, bitset_percentages * nb); + knowhere::BitsetView bitset(bitset_data.data(), nb); + auto results = idx.Search(*query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + float recall = GetKNNRecall(*gt.value(), *results.value()); + if (ivfflatcc_gen_().dump() == cfg_json) { + REQUIRE(recall > kBruteForceRecallThreshold); + } else { + REQUIRE(recall < kBruteForceRecallThreshold); + } + } } SECTION("Test Search with Bitset") { diff --git a/thirdparty/faiss/faiss/IndexIVF.h b/thirdparty/faiss/faiss/IndexIVF.h index e2334e884..1ee546897 100644 --- a/thirdparty/faiss/faiss/IndexIVF.h +++ b/thirdparty/faiss/faiss/IndexIVF.h @@ -71,7 +71,12 @@ struct Level1Quantizer { struct SearchParametersIVF : SearchParameters { size_t nprobe = 1; ///< number of probes at query time size_t max_codes = 0; ///< max nb of codes to visit to do a query - bool ensure_topk_full = false; ///< indicate whether we make sure topk result is full + ///< indicate whether we should terminate early before the topk results are full when search reaches max_codes + ///< to minimize code change, when users only use nprobe to search, this config does not take effect since we will first retrieve the nearest nprobe buckets + ///< it is a bit heavy to further retrieve more buckets + ///< therefore to make sure we get topk results, use nprobe=nlist and use max_codes to narrow down the search range + bool ensure_topk_full = false; + SearchParameters* quantizer_params = nullptr; virtual ~SearchParametersIVF() {} @@ -485,6 +490,7 @@ 
struct InvertedListScanner { * @param distances heap distances (size k) * @param labels heap labels (size k) * @param k heap size + * @param scan_cnt valid number of codes that have been scanned * @return number of heap updates performed */ virtual size_t scan_codes(