From 9b9b0230d15ad0a0424b16d052105f2c827b5092 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 7 Nov 2024 20:14:31 -0500 Subject: [PATCH 1/2] add range_search() to IndexRefine Signed-off-by: Alexandr Guzhva --- faiss/IndexRefine.cpp | 39 +++++++++++++++++++++++++++++++++++++++ faiss/IndexRefine.h | 7 +++++++ 2 files changed, 46 insertions(+) diff --git a/faiss/IndexRefine.cpp b/faiss/IndexRefine.cpp index 8bc429a5e9..6f1f588e2e 100644 --- a/faiss/IndexRefine.cpp +++ b/faiss/IndexRefine.cpp @@ -166,6 +166,45 @@ void IndexRefine::search( } } +void IndexRefine::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params_in) const { + const IndexRefineSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexRefine params have incorrect type"); + } + + SearchParameters* base_index_params = + (params != nullptr) ? params->base_index_params : nullptr; + + base_index->range_search(n, x, radius, result, base_index_params); + +#pragma omp parallel if (n > 1) + { + std::unique_ptr dc( + refine_index->get_distance_computer()); + +#pragma omp for + for (idx_t i = 0; i < n; i++) { + dc->set_query(x + i * d); + + // reevaluate distances + const size_t idx_start = result->lims[i]; + const size_t idx_end = result->lims[i + 1]; + + for (size_t j = idx_start; j < idx_end; j++) { + const auto label = result->labels[j]; + result->distances[j] = (*dc)(label); + } + } + } +} + void IndexRefine::reconstruct(idx_t key, float* recons) const { refine_index->reconstruct(key, recons); } diff --git a/faiss/IndexRefine.h b/faiss/IndexRefine.h index 9ad4e4be29..255271695f 100644 --- a/faiss/IndexRefine.h +++ b/faiss/IndexRefine.h @@ -54,6 +54,13 @@ struct IndexRefine : Index { idx_t* labels, const SearchParameters* params = nullptr) const override; + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + // reconstruct is routed to the refine_index void reconstruct(idx_t key, float* recons) const override; From f28ff9bc61e0966a397a2b9d099c5aaf9c4e08c4 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 24 Dec 2024 12:04:04 -0500 Subject: [PATCH 2/2] implement a unit test Signed-off-by: Alexandr Guzhva --- tests/test_refine.py | 52 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/test_refine.py b/tests/test_refine.py index f272584245..9b9ce73d0d 100644 --- a/tests/test_refine.py +++ b/tests/test_refine.py @@ -8,7 +8,7 @@ import unittest import faiss -from faiss.contrib import datasets +from faiss.contrib import datasets, evaluation class TestDistanceComputer(unittest.TestCase): @@ -119,3 +119,53 @@ def test_rflat(self): def test_refine_sq8(self): # this case uses the IndexRefine class self.do_test("IVF8,PQ2x4np,Refine(SQ8)") + + +class TestIndexRefineRangeSearch(unittest.TestCase): + + def do_test(self, factory_string): + d = 32 + radius = 8 + + ds = datasets.SyntheticDataset(d, 1024, 512, 256) + + index = faiss.index_factory(d, factory_string) + index.train(ds.get_train()) + index.add(ds.get_database()) + xq = ds.get_queries() + xb = ds.get_database() + + # perform a range_search + lims_1, D1, I1 = index.range_search(xq, radius) + + # create a baseline (FlatL2) + index_flat = faiss.IndexFlatL2(d) + index_flat.train(ds.get_train()) + index_flat.add(ds.get_database()) + + lims_ref, Dref, Iref = index_flat.range_search(xq, radius) + + # add a refine index on top of the index + index_r = faiss.IndexRefine(index, index_flat) + lims_2, D2, I2 = index_r.range_search(xq, radius) + + # validate: refined range_search() keeps indices untouched + precision_1, recall_1 = evaluation.range_PR(lims_ref, Iref, lims_1, I1) + + precision_2, recall_2 = evaluation.range_PR(lims_ref, Iref, lims_2, I2) + + self.assertAlmostEqual(recall_1, recall_2) + + # validate: refined range_search() updates distances, and new distances are correct L2 distances + for iq in range(0, ds.nq): + start_lim = lims_2[iq] + end_lim = lims_2[iq + 1] + for i_lim in range(start_lim, end_lim): + idx = I2[i_lim] + l2_dis = np.sum(np.square(xq[iq : iq + 1,] - xb[idx : idx + 1,])) + + self.assertAlmostEqual(l2_dis, D2[i_lim], places=4) + + + def test_refine_1(self): + self.do_test("SQ4")