1
- From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001
1
+ From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001
2
2
From: Heemin Kim <
[email protected] >
3
3
Date: Tue, 30 Jan 2024 14:43:56 -0800
4
4
Subject: [PATCH] Add IDGrouper for HNSW
7
7
---
8
8
faiss/CMakeLists.txt | 3 +
9
9
faiss/Index.h | 8 +-
10
- faiss/IndexHNSW.cpp | 13 ++ -
11
- faiss/IndexIDMap.cpp | 29 ++++++
12
- faiss/IndexIDMap.h | 22 +++++
13
- faiss/impl/HNSW.cpp | 10 +-
14
- faiss/impl/IDGrouper.cpp | 51 ++++++++++
15
- faiss/impl/IDGrouper.h | 51 ++++++++++
16
- faiss/impl/ResultHandler.h | 187 +++++++ +++++++++++++++++++++++++++++
17
- faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++
10
+ faiss/IndexHNSW.cpp | 13 +-
11
+ faiss/IndexIDMap.cpp | 29 +++++
12
+ faiss/IndexIDMap.h | 22 ++++
13
+ faiss/impl/HNSW.cpp | 6 +
14
+ faiss/impl/IDGrouper.cpp | 51 ++++++++
15
+ faiss/impl/IDGrouper.h | 51 ++++++++
16
+ faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++
17
+ faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++
18
18
tests/CMakeLists.txt | 2 +
19
- tests/test_group_heap.cpp | 98 +++++++++++++++++++
20
- tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++
21
- 13 files changed, 838 insertions(+), 7 deletions(-)
19
+ tests/test_group_heap.cpp | 98 +++++++++++++++
20
+ tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++
21
+ 13 files changed, 891 insertions(+), 5 deletions(-)
22
22
create mode 100644 faiss/impl/IDGrouper.cpp
23
23
create mode 100644 faiss/impl/IDGrouper.h
24
24
create mode 100644 faiss/utils/GroupHeap.h
@@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644
54
54
utils/WorkerThread.h
55
55
utils/distances.h
56
56
diff --git a/faiss/Index.h b/faiss/Index.h
57
- index 4b4b302b..3b673d1e 100644
57
+ index 3d1bdb99..a8622858 100644
58
58
--- a/faiss/Index.h
59
59
+++ b/faiss/Index.h
60
60
@@ -38,9 +38,10 @@
@@ -106,7 +106,7 @@ index 9a67332d..a5e0fea0 100644
106
106
if (is_similarity_metric(this->metric_type)) {
107
107
// we need to revert the negated distances
108
108
diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
109
- index e093bbda..e24365d5 100644
109
+ index dc84052b..3f375e7b 100644
110
110
--- a/faiss/IndexIDMap.cpp
111
111
+++ b/faiss/IndexIDMap.cpp
112
112
@@ -102,6 +102,23 @@ struct ScopedSelChange {
@@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644
198
198
+
199
199
} // namespace faiss
200
200
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
201
- index fb4de678..b6f602a0 100644
201
+ index a9fb9daf..33b56638 100644
202
202
--- a/faiss/impl/HNSW.cpp
203
203
+++ b/faiss/impl/HNSW.cpp
204
- @@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const {
205
- level,
206
- nb_neighbors(level));
207
- size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
208
- - #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
209
- - reduction(+: tot_reciprocal) reduction(+: n_node)
210
- + #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
211
- + reduction(+ : tot_reciprocal) reduction(+ : n_node)
212
- for (int i = 0; i < levels.size(); i++) {
213
- if (levels[i] > level) {
214
- n_node++;
215
204
@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
216
205
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
217
206
return hres->k;
@@ -340,19 +329,20 @@ index 00000000..d56113d9
340
329
+
341
330
+ } // namespace faiss
342
331
diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h
343
- index 270de8dc..2f7f3e7f 100644
332
+ index 270de8dc..3199634f 100644
344
333
--- a/faiss/impl/ResultHandler.h
345
334
+++ b/faiss/impl/ResultHandler.h
346
- @@ -12,6 +12,8 @@
335
+ @@ -12,6 +12,9 @@
347
336
#pragma once
348
337
349
338
#include <faiss/impl/AuxIndexStructures.h>
339
+ + #include <faiss/impl/FaissException.h>
350
340
+ #include <faiss/impl/IDGrouper.h>
351
341
+ #include <faiss/utils/GroupHeap.h>
352
342
#include <faiss/utils/Heap.h>
353
343
#include <faiss/utils/partitioning.h>
354
344
355
- @@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
345
+ @@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
356
346
}
357
347
};
358
348
@@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644
436
426
+ idx,
437
427
+ group_id,
438
428
+ &group_id_to_index_in_heap);
429
+ + threshold = heap_dis[0];
439
430
+ return true;
440
431
+ } else {
441
432
+ size_t pos = it_pos->second;
@@ -452,6 +443,7 @@ index 270de8dc..2f7f3e7f 100644
452
443
+ idx,
453
444
+ group_id,
454
445
+ &group_id_to_index_in_heap);
446
+ + threshold = heap_dis[0];
455
447
+ return true;
456
448
+ }
457
449
+ }
@@ -734,10 +726,10 @@ index 00000000..3b7078da
734
726
+ } // namespace faiss
735
727
\ No newline at end of file
736
728
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
737
- index cc0a4f4c..96e19328 100644
729
+ index 9017edc5..a8e9d30c 100644
738
730
--- a/tests/CMakeLists.txt
739
731
+++ b/tests/CMakeLists.txt
740
- @@ -26 ,6 +26 ,8 @@ set(FAISS_TEST_SRC
732
+ @@ -27 ,6 +27 ,8 @@ set(FAISS_TEST_SRC
741
733
test_approx_topk.cpp
742
734
test_RCQ_cropping.cpp
743
735
test_distances_simd.cpp
@@ -852,10 +844,10 @@ index 00000000..0e8fe7a7
852
844
+ }
853
845
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp
854
846
new file mode 100644
855
- index 00000000..2aed5500
847
+ index 00000000..6601795b
856
848
--- /dev/null
857
849
+++ b/tests/test_id_grouper.cpp
858
- @@ -0,0 +1,189 @@
850
+ @@ -0,0 +1,241 @@
859
851
+ /**
860
852
+ * Copyright (c) Facebook, Inc. and its affiliates.
861
853
+ *
@@ -920,6 +912,58 @@ index 00000000..2aed5500
920
912
+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1));
921
913
+ }
922
914
+
915
+ + TEST(IdGrouper, sanity_test) {
916
+ + int d = 1; // dimension
917
+ + int nb = 10; // database size
918
+ +
919
+ + std::mt19937 rng;
920
+ + std::uniform_real_distribution<> distrib;
921
+ +
922
+ + float* xb = new float[d * nb];
923
+ +
924
+ + for (int i = 0; i < nb; i++) {
925
+ + for (int j = 0; j < d; j++)
926
+ + xb[d * i + j] = distrib(rng);
927
+ + xb[d * i] += i / 1000.;
928
+ + }
929
+ +
930
+ + uint64_t bitmap[1] = {};
931
+ + faiss::IDGrouperBitmap id_grouper(1, bitmap);
932
+ + for (int i = 0; i < nb; i++) {
933
+ + id_grouper.set_group(i);
934
+ + }
935
+ +
936
+ + int k = 5;
937
+ + int m = 8;
938
+ + faiss::Index* index =
939
+ + new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2);
940
+ + index->add(nb, xb); // add vectors to the index
941
+ +
942
+ + // search
943
+ + auto pSearchParameters = new faiss::SearchParametersHNSW();
944
+ +
945
+ + idx_t* expectedI = new idx_t[k];
946
+ + float* expectedD = new float[k];
947
+ + index->search(1, xb, k, expectedD, expectedI, pSearchParameters);
948
+ +
949
+ + idx_t* I = new idx_t[k];
950
+ + float* D = new float[k];
951
+ + pSearchParameters->grp = &id_grouper;
952
+ + index->search(1, xb, k, D, I, pSearchParameters);
953
+ +
954
+ + // compare
955
+ + for (int j = 0; j < k; j++) {
956
+ + ASSERT_EQ(expectedI[j], I[j]);
957
+ + ASSERT_EQ(expectedD[j], D[j]);
958
+ + }
959
+ +
960
+ + delete[] expectedI;
961
+ + delete[] expectedD;
962
+ + delete[] I;
963
+ + delete[] D;
964
+ + delete[] xb;
965
+ + }
966
+ +
923
967
+ TEST(IdGrouper, bitmap_with_hnsw) {
924
968
+ int d = 1; // dimension
925
969
+ int nb = 10; // database size
0 commit comments