Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ constexpr const char* M = "m"; // PQ param for IVFPQ
constexpr const char* SSIZE = "ssize";
constexpr const char* REORDER_K = "reorder_k";
constexpr const char* WITH_RAW_DATA = "with_raw_data";
constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full";
// RAFT Params
constexpr const char* REFINE_RATIO = "refine_ratio";
constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device";
Expand Down
15 changes: 12 additions & 3 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,26 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSet& dataset, const Config&
distances[i + offset] = static_cast<float>(i_distances[i + offset]);
}
}
} else if constexpr (std::is_same<IndexType, faiss::IndexIVFFlat>::value) {
} else if constexpr (std::is_same<IndexType, faiss::IndexIVFFlatCC>::value) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not using IndexIVFFlat, just IndexIVFFlatCC all the time, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I dont know why we need a seperate code branch for IndexIVFFlat, it is the same with IVFPQ, IVFSQ, etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we do need a new code branch for IndexIVFFlatCC, so I replace it.

auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
cur_query = copied_query.get();
}

faiss::IVFSearchParameters ivf_search_params;
ivf_search_params.nprobe = nprobe;
ivf_search_params.max_codes = 0;

ivf_search_params.sel = id_selector;
ivf_search_params.ensure_topk_full = ivf_cfg.ensure_topk_full.value();
if (ivf_search_params.ensure_topk_full) {
ivf_search_params.nprobe = index_->nlist;
// use max_codes to early termination
ivf_search_params.max_codes =
(nprobe * 1.0 / index_->nlist) * (index_->ntotal - bitset.count());
} else {
ivf_search_params.nprobe = nprobe;
ivf_search_params.max_codes = 0;
}

index_->search(1, cur_query, k, distances + offset, ids + offset, &ivf_search_params);
} else if constexpr (std::is_same<IndexType, faiss::IndexScaNN>::value) {
Expand Down
5 changes: 5 additions & 0 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class IvfConfig : public BaseConfig {
CFG_INT nlist;
CFG_INT nprobe;
CFG_BOOL use_elkan;
CFG_BOOL ensure_topk_full;
KNOHWERE_DECLARE_CONFIG(IvfConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(nlist)
.set_default(128)
Expand All @@ -36,6 +37,10 @@ class IvfConfig : public BaseConfig {
.set_default(true)
.description("whether to use elkan algorithm")
.for_train();
KNOWHERE_CONFIG_DECLARE_FIELD(ensure_topk_full)
.set_default(true)
.description("whether to make sure topk results full")
.for_search();
}
};

Expand Down
1 change: 1 addition & 0 deletions tests/ut/test_ivfflat_cc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 128;
json[knowhere::indexparam::NPROBE] = 16;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
return json;
};

Expand Down
26 changes: 26 additions & 0 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,32 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
REQUIRE(recall > kBruteForceRecallThreshold);
}

SECTION("Test Search with IVFFLATCC ensure topk full") {
using std::make_tuple;
auto ivfflatcc_gen_ = [base_gen, nb]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 16;
json[knowhere::indexparam::NPROBE] = 1;
json[knowhere::indexparam::SSIZE] = 48;
json[knowhere::meta::TOPK] = nb;
return json;
};
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);

auto results = idx.Search(*query_ds, json, nullptr);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, nullptr);
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > kBruteForceRecallThreshold);
}

SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
Expand Down
16 changes: 11 additions & 5 deletions thirdparty/faiss/faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ void IndexIVF::search_preassigned(

const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
idx_t max_codes = params ? params->max_codes : this->max_codes;
bool ensure_topk_full = params ? params->ensure_topk_full : false;
IDSelector* sel = params ? params->sel : nullptr;
const IDSelectorRange* selr = dynamic_cast<const IDSelectorRange*>(sel);
if (selr) {
Expand Down Expand Up @@ -545,7 +546,7 @@ void IndexIVF::search_preassigned(

return list_size;
} else {
size_t scan_cnt = 0;
size_t scan_cnt = 0; // only record valid cnt

size_t segment_num = invlists->get_segment_num(key);
for (size_t segment_idx = 0; segment_idx < segment_num; segment_idx++) {
Expand All @@ -570,8 +571,8 @@ void IndexIVF::search_preassigned(
ids,
simi,
idxi,
k);
scan_cnt += segment_size;
k,
scan_cnt);
}

return scan_cnt;
Expand Down Expand Up @@ -613,7 +614,9 @@ void IndexIVF::search_preassigned(
simi,
idxi,
max_codes - nscan);
if (nscan >= max_codes) {

// if ensure_topk_full enabled, also make sure nscan >= k, then stop search further
if (nscan >= max_codes && (!ensure_topk_full || nscan >= k)) {
break;
}
}
Expand Down Expand Up @@ -1306,13 +1309,15 @@ size_t InvertedListScanner::scan_codes(
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const {
size_t k,
size_t& scan_cnt) const {
size_t nup = 0;

if (!keep_max) {
for (size_t j = 0; j < list_size; j++) {
// // todo aguzhva: use int64_t id instead of j ?
if (!sel || sel->is_member(j)) {
scan_cnt++;
float dis = distance_to_code(codes);
if (code_norms) {
dis /= code_norms[j];
Expand All @@ -1329,6 +1334,7 @@ size_t InvertedListScanner::scan_codes(
for (size_t j = 0; j < list_size; j++) {
// // todo aguzhva: use int64_t id instead of j ?
if (!sel || sel->is_member(j)) {
scan_cnt++;
float dis = distance_to_code(codes);
if (code_norms) {
dis /= code_norms[j];
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct Level1Quantizer {
struct SearchParametersIVF : SearchParameters {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
bool ensure_topk_full = false; ///< indicate whether we make sure topk result is full
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add more comments here. It did not become clear what this parameter does before I read the code :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will add more comments in the next pr.

SearchParameters* quantizer_params = nullptr;

virtual ~SearchParametersIVF() {}
Expand Down Expand Up @@ -493,7 +494,8 @@ struct InvertedListScanner {
const idx_t* ids,
float* distances,
idx_t* labels,
size_t k) const;
size_t k,
size_t& scan_cnt) const;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add comments to the description of the function.


// same as scan_codes, using an iterator
virtual size_t iterate_codes(
Expand Down
12 changes: 8 additions & 4 deletions thirdparty/faiss/faiss/IndexIVFFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,20 @@ struct IVFFlatScanner : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
const float* list_vecs = (const float*)codes;
size_t nup = 0;

// the lambda that filters acceptable elements.
auto filter =
[&](const size_t j) { return (!use_sel || sel->is_member(ids[j])); };

// the lambda that applies a filtered element.
// the lambda that applies a valid element.
auto apply =
[&](const float dis_in, const size_t j) {
const float dis = (code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]);
scan_cnt++;
if (C::cmp(simi[0], dis)) {
const int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
heap_replace_top<C>(k, simi, idxi, dis, id);
Expand Down Expand Up @@ -389,18 +391,20 @@ struct IVFFlatBitsetViewScanner : InvertedListScanner {
const idx_t* __restrict ids,
float* __restrict simi,
idx_t* __restrict idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
const float* list_vecs = (const float*)codes;
size_t nup = 0;

// the lambda that filters acceptable elements.
auto filter =
[&](const size_t j) { return (!use_sel || !bitset.test(ids[j])); };

// the lambda that applies a filtered element.
// the lambda that applies a valid element.
auto apply =
[&](const float dis_in, const size_t j) {
const float dis = (code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]);
scan_cnt++;
if (C::cmp(simi[0], dis)) {
const int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
heap_replace_top<C>(k, simi, idxi, dis, id);
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,8 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
const idx_t* ids,
float* heap_sim,
idx_t* heap_ids,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
KnnSearchResults<C, use_sel> res = {
/* key */ this->key,
/* ids */ this->store_pairs ? nullptr : ids,
Expand Down
5 changes: 3 additions & 2 deletions thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,13 @@ struct IVFScanner : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
size_t nup = 0;
for (size_t j = 0; j < list_size; j++) {
if (!sel || sel->is_member(ids[j])) {
float dis = hc.compute(codes);

scan_cnt++;
if (dis < simi[0]) {
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
maxheap_replace_top(k, simi, idxi, dis, id);
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ void IndexScalarQuantizer::search(
minheap_heapify(k, D, I);
}
scanner->set_query(x + i * d);
scanner->scan_codes(ntotal, codes.data(), nullptr, nullptr, D, I, k);
size_t scan_cnt = 0;
scanner->scan_codes(ntotal, codes.data(), nullptr, nullptr, D, I, k, scan_cnt);

// re-order heap
if (metric_type == METRIC_L2) {
Expand Down
6 changes: 4 additions & 2 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct IVFSQScannerIP : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
size_t nup = 0;

for (size_t j = 0; j < list_size; j++, codes += code_size) {
Expand Down Expand Up @@ -215,7 +216,8 @@ struct IVFSQScannerL2 : InvertedListScanner {
const idx_t* ids,
float* simi,
idx_t* idxi,
size_t k) const override {
size_t k,
size_t& scan_cnt) const override {
size_t nup = 0;

// // baseline
Expand Down
14 changes: 9 additions & 5 deletions thirdparty/faiss/tests/test_lowlevel_ivf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,15 @@ void test_lowlevel_access(const char* index_key, MetricType metric) {

// here we get the inverted lists from the InvertedLists
// object but they could come from anywhere

size_t scan_cnt = 0;
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
D.data(),
I.data(),
k);
k,
scan_cnt);

if (j == 0) {
// all results so far come from list_no, so let's check if
Expand Down Expand Up @@ -338,14 +339,15 @@ void test_lowlevel_access_binary(const char* index_key) {

// here we get the inverted lists from the InvertedLists
// object but they could come from anywhere

size_t scan_cnt = 0;
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
D.data(),
I.data(),
k);
k,
scan_cnt);

if (j == 0) {
// all results so far come from list_no, so let's check if
Expand Down Expand Up @@ -500,13 +502,15 @@ void test_threaded_search(const char* index_key, MetricType metric) {
continue;
scanner->set_list(list_no, q_dis[i * nprobe + j]);

size_t scan_cnt = 0;
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
local_D,
local_I,
k);
k,
scan_cnt);
}
};

Expand Down