Skip to content

Commit 0ae5524

Browse files
heemin32 authored and github-actions[bot] committed
Update threshold value after new result is added (#1715)
Signed-off-by: Heemin Kim <[email protected]> (cherry picked from commit 9023604)
1 parent c2c1660 commit 0ae5524

File tree

3 files changed

+84
-35
lines changed

3 files changed

+84
-35
lines changed

CHANGELOG.md

+1
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2020
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
2121
### Bug Fixes
2222
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
23+
* Update threshold value after new result is added [#1715](https://github.com/opensearch-project/k-NN/pull/1715)
2324
### Infrastructure
2425
### Documentation
2526
### Maintenance

jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch

+77-33
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001
1+
From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001
22
From: Heemin Kim <[email protected]>
33
Date: Tue, 30 Jan 2024 14:43:56 -0800
44
Subject: [PATCH] Add IDGrouper for HNSW
@@ -7,18 +7,18 @@ Signed-off-by: Heemin Kim <[email protected]>
77
---
88
faiss/CMakeLists.txt | 3 +
99
faiss/Index.h | 8 +-
10-
faiss/IndexHNSW.cpp | 13 ++-
11-
faiss/IndexIDMap.cpp | 29 ++++++
12-
faiss/IndexIDMap.h | 22 +++++
13-
faiss/impl/HNSW.cpp | 10 +-
14-
faiss/impl/IDGrouper.cpp | 51 ++++++++++
15-
faiss/impl/IDGrouper.h | 51 ++++++++++
16-
faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++
17-
faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++
10+
faiss/IndexHNSW.cpp | 13 +-
11+
faiss/IndexIDMap.cpp | 29 +++++
12+
faiss/IndexIDMap.h | 22 ++++
13+
faiss/impl/HNSW.cpp | 6 +
14+
faiss/impl/IDGrouper.cpp | 51 ++++++++
15+
faiss/impl/IDGrouper.h | 51 ++++++++
16+
faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++
17+
faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++
1818
tests/CMakeLists.txt | 2 +
19-
tests/test_group_heap.cpp | 98 +++++++++++++++++++
20-
tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++
21-
13 files changed, 838 insertions(+), 7 deletions(-)
19+
tests/test_group_heap.cpp | 98 +++++++++++++++
20+
tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++
21+
13 files changed, 891 insertions(+), 5 deletions(-)
2222
create mode 100644 faiss/impl/IDGrouper.cpp
2323
create mode 100644 faiss/impl/IDGrouper.h
2424
create mode 100644 faiss/utils/GroupHeap.h
@@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644
5454
utils/WorkerThread.h
5555
utils/distances.h
5656
diff --git a/faiss/Index.h b/faiss/Index.h
57-
index 4b4b302b..3b673d1e 100644
57+
index 3d1bdb99..a8622858 100644
5858
--- a/faiss/Index.h
5959
+++ b/faiss/Index.h
6060
@@ -38,9 +38,10 @@
@@ -106,7 +106,7 @@ index 9a67332d..a5e0fea0 100644
106106
if (is_similarity_metric(this->metric_type)) {
107107
// we need to revert the negated distances
108108
diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
109-
index e093bbda..e24365d5 100644
109+
index dc84052b..3f375e7b 100644
110110
--- a/faiss/IndexIDMap.cpp
111111
+++ b/faiss/IndexIDMap.cpp
112112
@@ -102,6 +102,23 @@ struct ScopedSelChange {
@@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644
198198
+
199199
} // namespace faiss
200200
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
201-
index fb4de678..b6f602a0 100644
201+
index a9fb9daf..33b56638 100644
202202
--- a/faiss/impl/HNSW.cpp
203203
+++ b/faiss/impl/HNSW.cpp
204-
@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const {
205-
level,
206-
nb_neighbors(level));
207-
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
208-
-#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
209-
- reduction(+: tot_reciprocal) reduction(+: n_node)
210-
+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
211-
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
212-
for (int i = 0; i < levels.size(); i++) {
213-
if (levels[i] > level) {
214-
n_node++;
215204
@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
216205
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
217206
return hres->k;
@@ -340,19 +329,20 @@ index 00000000..d56113d9
340329
+
341330
+} // namespace faiss
342331
diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h
343-
index 270de8dc..2f7f3e7f 100644
332+
index 270de8dc..3199634f 100644
344333
--- a/faiss/impl/ResultHandler.h
345334
+++ b/faiss/impl/ResultHandler.h
346-
@@ -12,6 +12,8 @@
335+
@@ -12,6 +12,9 @@
347336
#pragma once
348337

349338
#include <faiss/impl/AuxIndexStructures.h>
339+
+#include <faiss/impl/FaissException.h>
350340
+#include <faiss/impl/IDGrouper.h>
351341
+#include <faiss/utils/GroupHeap.h>
352342
#include <faiss/utils/Heap.h>
353343
#include <faiss/utils/partitioning.h>
354344

355-
@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
345+
@@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
356346
}
357347
};
358348

@@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644
436426
+ idx,
437427
+ group_id,
438428
+ &group_id_to_index_in_heap);
429+
+ threshold = heap_dis[0];
439430
+ return true;
440431
+ } else {
441432
+ size_t pos = it_pos->second;
@@ -452,6 +443,7 @@ index 270de8dc..2f7f3e7f 100644
452443
+ idx,
453444
+ group_id,
454445
+ &group_id_to_index_in_heap);
446+
+ threshold = heap_dis[0];
455447
+ return true;
456448
+ }
457449
+ }
@@ -734,10 +726,10 @@ index 00000000..3b7078da
734726
+} // namespace faiss
735727
\ No newline at end of file
736728
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
737-
index cc0a4f4c..96e19328 100644
729+
index 9017edc5..a8e9d30c 100644
738730
--- a/tests/CMakeLists.txt
739731
+++ b/tests/CMakeLists.txt
740-
@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC
732+
@@ -27,6 +27,8 @@ set(FAISS_TEST_SRC
741733
test_approx_topk.cpp
742734
test_RCQ_cropping.cpp
743735
test_distances_simd.cpp
@@ -852,10 +844,10 @@ index 00000000..0e8fe7a7
852844
+}
853845
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp
854846
new file mode 100644
855-
index 00000000..2aed5500
847+
index 00000000..6601795b
856848
--- /dev/null
857849
+++ b/tests/test_id_grouper.cpp
858-
@@ -0,0 +1,189 @@
850+
@@ -0,0 +1,241 @@
859851
+/**
860852
+ * Copyright (c) Facebook, Inc. and its affiliates.
861853
+ *
@@ -920,6 +912,58 @@ index 00000000..2aed5500
920912
+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1));
921913
+}
922914
+
915+
+TEST(IdGrouper, sanity_test) {
916+
+ int d = 1; // dimension
917+
+ int nb = 10; // database size
918+
+
919+
+ std::mt19937 rng;
920+
+ std::uniform_real_distribution<> distrib;
921+
+
922+
+ float* xb = new float[d * nb];
923+
+
924+
+ for (int i = 0; i < nb; i++) {
925+
+ for (int j = 0; j < d; j++)
926+
+ xb[d * i + j] = distrib(rng);
927+
+ xb[d * i] += i / 1000.;
928+
+ }
929+
+
930+
+ uint64_t bitmap[1] = {};
931+
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
932+
+ for (int i = 0; i < nb; i++) {
933+
+ id_grouper.set_group(i);
934+
+ }
935+
+
936+
+ int k = 5;
937+
+ int m = 8;
938+
+ faiss::Index* index =
939+
+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2);
940+
+ index->add(nb, xb); // add vectors to the index
941+
+
942+
+ // search
943+
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
944+
+
945+
+ idx_t* expectedI = new idx_t[k];
946+
+ float* expectedD = new float[k];
947+
+ index->search(1, xb, k, expectedD, expectedI, pSearchParameters);
948+
+
949+
+ idx_t* I = new idx_t[k];
950+
+ float* D = new float[k];
951+
+ pSearchParameters->grp = &id_grouper;
952+
+ index->search(1, xb, k, D, I, pSearchParameters);
953+
+
954+
+ // compare
955+
+ for (int j = 0; j < k; j++) {
956+
+ ASSERT_EQ(expectedI[j], I[j]);
957+
+ ASSERT_EQ(expectedD[j], D[j]);
958+
+ }
959+
+
960+
+ delete[] expectedI;
961+
+ delete[] expectedD;
962+
+ delete[] I;
963+
+ delete[] D;
964+
+ delete[] xb;
965+
+}
966+
+
923967
+TEST(IdGrouper, bitmap_with_hnsw) {
924968
+ int d = 1; // dimension
925969
+ int nb = 10; // database size

src/test/java/org/opensearch/knn/index/NestedSearchIT.java

+6-2
Original file line number | Diff line number | Diff line change
@@ -75,11 +75,13 @@ public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() {
7575
refreshIndex(INDEX_NAME);
7676
forceMergeKnnIndex(INDEX_NAME);
7777

78-
Float[] queryVector = { 1f, 1f };
78+
Float[] queryVector = { 14f, 14f };
7979
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
8080
String entity = EntityUtils.toString(response.getEntity());
8181
assertEquals(2, parseHits(entity));
8282
assertEquals(2, parseTotalSearchHits(entity));
83+
assertEquals("14", parseIds(entity).get(0));
84+
assertEquals("13", parseIds(entity).get(1));
8385
}
8486

8587
@SneakyThrows
@@ -97,11 +99,13 @@ public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() {
9799
refreshIndex(INDEX_NAME);
98100
forceMergeKnnIndex(INDEX_NAME);
99101

100-
Float[] queryVector = { 1f, 1f };
102+
Float[] queryVector = { 14f, 14f };
101103
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
102104
String entity = EntityUtils.toString(response.getEntity());
103105
assertEquals(2, parseHits(entity));
104106
assertEquals(2, parseTotalSearchHits(entity));
107+
assertEquals("14", parseIds(entity).get(0));
108+
assertEquals("13", parseIds(entity).get(1));
105109
}
106110

107111
/**

0 commit comments

Comments
 (0)