Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class IvfConfig : public BaseConfig {
CFG_INT nlist;
CFG_INT nprobe;
CFG_BOOL use_elkan;
CFG_BOOL ensure_topk_full;
CFG_BOOL ensure_topk_full; // only takes effect on the temp index (IVF_FLAT_CC) for now
KNOHWERE_DECLARE_CONFIG(IvfConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(nlist)
.set_default(128)
Expand Down
36 changes: 31 additions & 5 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,21 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto ivfflatcc_gen_ = [base_gen, nb]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 16;
json[knowhere::indexparam::NLIST] = 32;
json[knowhere::indexparam::NPROBE] = 1;
json[knowhere::indexparam::SSIZE] = 48;
json[knowhere::meta::TOPK] = nb;
return json;
};
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_),
}));
auto ivfflatcc_gen_no_ensure_topk_ = [ivfflatcc_gen_, nb]() {
knowhere::Json json = ivfflatcc_gen_();
json[knowhere::meta::TOPK] = nb / 2;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
return json;
};
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_no_ensure_topk_)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand All @@ -235,7 +241,27 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
auto results = idx.Search(*query_ds, json, nullptr);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, nullptr);
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > kBruteForceRecallThreshold);
if (ivfflatcc_gen_().dump() == cfg_json) {
REQUIRE(recall > kBruteForceRecallThreshold);
} else {
REQUIRE(recall < kBruteForceRecallThreshold);
}

std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages = 0.5f;
for (const auto& gen_func : gen_bitset_funcs) {
auto bitset_data = gen_func(nb, bitset_percentages * nb);
knowhere::BitsetView bitset(bitset_data.data(), nb);
auto results = idx.Search(*query_ds, json, bitset);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, bitset);
float recall = GetKNNRecall(*gt.value(), *results.value());
if (ivfflatcc_gen_().dump() == cfg_json) {
REQUIRE(recall > kBruteForceRecallThreshold);
} else {
REQUIRE(recall < kBruteForceRecallThreshold);
}
}
}

SECTION("Test Search with Bitset") {
Expand Down
8 changes: 7 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,12 @@ struct Level1Quantizer {
struct SearchParametersIVF : SearchParameters {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
bool ensure_topk_full = false; ///< indicate whether we make sure topk result is full
///< controls whether the search may terminate early (before the topk results are full) once it reaches max_codes
///< to minimize code changes, this config does not take effect when users only use nprobe to search, since we first retrieve the nearest nprobe buckets
///< and it would be a bit heavy to retrieve further buckets after that
///< therefore, to make sure we get topk results, use nprobe=nlist and use max_codes to narrow down the search range
bool ensure_topk_full = false;

SearchParameters* quantizer_params = nullptr;

virtual ~SearchParametersIVF() {}
Expand Down Expand Up @@ -485,6 +490,7 @@ struct InvertedListScanner {
* @param distances heap distances (size k)
* @param labels heap labels (size k)
* @param k heap size
* @param scan_cnt number of valid codes that were scanned
* @return number of heap updates performed
*/
virtual size_t scan_codes(
Expand Down