Skip to content

Commit dafd79b

Browse files
committed
Fix memory leak on test code (#1776)
Signed-off-by: Heemin Kim <[email protected]>
1 parent 57a081e commit dafd79b

File tree

6 files changed

+23
-17
lines changed

6 files changed

+23
-17
lines changed

jni/include/faiss_wrapper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ namespace knn_jni {
7878
//
7979
// Return an array of KNNQueryResults
8080
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
81-
jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);
81+
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);
8282

8383
// Free the index located in memory at indexPointerJ
8484
void Free(jlong indexPointer);

jni/src/commons.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,5 @@ int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilIn
7171
}
7272

7373
return defaultValue;
74+
}
7475
#endif //OPENSEARCH_KNN_COMMONS_H

jni/src/faiss_wrapper.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -490,10 +490,12 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
490490
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
491491
std::vector<uint64_t> idGrouperBitmap;
492492
auto hnswReader = dynamic_cast<const faiss::IndexBinaryHNSW*>(indexReader->index);
493-
if(hnswReader!= nullptr) {
493+
// TODO currently, search parameter is not supported in binary index
494+
// To avoid test failure, we skip setting ef search when methodPramsJ is null temporary
495+
if(hnswReader!= nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) {
494496
// Query param efsearch supersedes ef_search provided during index setting.
495497
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
496-
if(parentIdsJ != nullptr) {
498+
if (parentIdsJ != nullptr) {
497499
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
498500
hnswParams.grp = idGrouper.get();
499501
}

jni/tests/faiss_wrapper_test.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ TEST(FaissQueryBinaryIndexTest, BasicAssertions) {
425425
knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(
426426
&mockJNIUtil, jniEnv,
427427
reinterpret_cast<jlong>(&createdIndexWithData),
428-
reinterpret_cast<jbyteArray>(&query), k, nullptr, 0, nullptr)));
428+
reinterpret_cast<jbyteArray>(&query), k, nullptr, nullptr, 0, nullptr)));
429429

430430
ASSERT_EQ(k, results->size());
431431

@@ -556,6 +556,10 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) {
556556
// Setup jni
557557
JNIEnv *jniEnv = nullptr;
558558
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
559+
EXPECT_CALL(mockJNIUtil,
560+
GetJavaIntArrayLength(
561+
jniEnv, reinterpret_cast<jintArray>(&parentIds)))
562+
.WillRepeatedly(Return(parentIds.size()));
559563
for (auto query : queries) {
560564
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
561565
reinterpret_cast<std::vector<std::pair<int, float> *> *>(
@@ -635,13 +639,13 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
635639
// Define the data
636640
faiss::idx_t numIds = 200;
637641
std::vector<faiss::idx_t> ids;
638-
auto *vectors = new std::vector<float>();
642+
std::vector<float> vectors;
639643
int dim = 2;
640-
vectors->reserve(dim * numIds);
644+
vectors.reserve(dim * numIds);
641645
for (int64_t i = 0; i < numIds; ++i) {
642646
ids.push_back(i);
643647
for (int j = 0; j < dim; ++j) {
644-
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
648+
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
645649
}
646650
}
647651

@@ -660,14 +664,14 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
660664
EXPECT_CALL(mockJNIUtil,
661665
GetJavaObjectArrayLength(
662666
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
663-
.WillRepeatedly(Return(vectors->size()));
667+
.WillRepeatedly(Return(vectors.size()));
664668

665669
// Create the index
666670
std::unique_ptr<FaissMethods> faissMethods(new FaissMethods());
667671
knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods));
668672
knn_jni::faiss_wrapper::CreateIndex(
669673
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
670-
(jlong)vectors, dim, (jstring)&indexPath,
674+
(jlong)&vectors, dim, (jstring)&indexPath,
671675
(jobject)&parametersMap, &IndexService);
672676

673677
// Make sure index can be loaded

jni/tests/faiss_wrapper_unit_test.cpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,23 @@
2222
#include "faiss/IndexIDMap.h"
2323

2424
using ::testing::NiceMock;
25-
2625
using idx_t = faiss::idx_t;
2726

28-
struct MockIndex : faiss::IndexHNSW {
29-
explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) {
27+
struct FaissMockIndex : faiss::IndexHNSW {
28+
explicit FaissMockIndex(idx_t d) : faiss::IndexHNSW(d, 32) {
3029
}
3130
};
3231

3332

34-
struct MockIdMap : faiss::IndexIDMap {
33+
struct FaissMockIdMap : faiss::IndexIDMap {
3534
mutable idx_t nCalled;
3635
mutable const float *xCalled;
3736
mutable idx_t kCalled;
3837
mutable float *distancesCalled;
3938
mutable idx_t *labelsCalled;
4039
mutable const faiss::SearchParametersHNSW *paramsCalled;
4140

42-
explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate<faiss::Index>(index) {
41+
explicit FaissMockIdMap(FaissMockIndex *index) : faiss::IndexIDMapTemplate<faiss::Index>(index) {
4342
}
4443

4544
void search(
@@ -85,8 +84,8 @@ class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<Quer
8584
};
8685

8786
protected:
88-
MockIndex index_;
89-
MockIdMap id_map_;
87+
FaissMockIndex index_;
88+
FaissMockIdMap id_map_;
9089
};
9190

9291
namespace query_index_test {

src/test/java/org/opensearch/knn/jni/JNIServiceTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ public void testQueryBinaryIndex_faiss_valid() {
984984
assertNotEquals(0, pointer);
985985

986986
for (byte[] query : testData.binaryQueries) {
987-
KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, Collections.emptyMap(), KNNEngine.FAISS, null, 0, null);
987+
KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, null, KNNEngine.FAISS, null, 0, null);
988988
assertEquals(k, results.length);
989989
}
990990
}

0 commit comments

Comments
 (0)