Skip to content

Commit d227c8f

Browse files
committed
Add binary format support with IVF method in Faiss Engine
Signed-off-by: Junqiu Lei <[email protected]>
1 parent 517506a commit d227c8f

File tree

61 files changed

+1778
-455
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1778
-455
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1717
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
1818
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
1919
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
20+
* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784)
2021
### Enhancements
2122
### Bug Fixes
2223
* Fixing the arithmetic to find the number of vectors to stream from java to jni layer.[#1804](https://github.com/opensearch-project/k-NN/pull/1804)

jni/include/faiss_wrapper.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ namespace knn_jni {
2929
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
3030
jobject parametersJ);
3131

32+
// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
33+
// based off of the template index passed in. The index is serialized to indexPathJ.
34+
void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
35+
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
36+
jobject parametersJ);
37+
3238
// Load an index from indexPathJ into memory.
3339
//
3440
// Return a pointer to the loaded index
@@ -96,6 +102,13 @@ namespace knn_jni {
96102
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
97103
jlong trainVectorsPointerJ);
98104

105+
// Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with
106+
// the vector of floats located at trainVectorsPointerJ.
107+
//
108+
// Return the serialized representation
109+
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
110+
jlong trainVectorsPointerJ);
111+
99112
/*
100113
* Perform a range search with filter against the index located in memory at indexPointerJ.
101114
*

jni/include/org_opensearch_knn_jni_FaissService.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
4343
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
4444
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);
4545

46+
/*
47+
* Class: org_opensearch_knn_jni_FaissService
48+
* Method: createBinaryIndexFromTemplate
49+
* Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
50+
*/
51+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate
52+
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject);
53+
4654
/*
4755
* Class: org_opensearch_knn_jni_FaissService
4856
* Method: loadIndex
@@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary
139147
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
140148
(JNIEnv *, jclass, jobject, jint, jlong);
141149

150+
/*
151+
* Class: org_opensearch_knn_jni_FaissService
152+
* Method: trainBinaryIndex
153+
* Signature: (Ljava/util/Map;IJ)[B
154+
*/
155+
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex
156+
(JNIEnv *, jclass, jobject, jint, jlong);
157+
142158
/*
143159
* Class: org_opensearch_knn_jni_FaissService
144160
* Method: transferVectors

jni/src/faiss_wrapper.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
7070
// Train an index with data provided
7171
void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);
7272

73+
// Train a binary index with data provided
74+
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);
75+
7376
// Converts the int FilterIds to Faiss ids type array.
7477
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);
7578

@@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
223226
faiss::write_index(&idMap, indexPathCpp.c_str());
224227
}
225228

229+
void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
230+
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
231+
jbyteArray templateIndexJ, jobject parametersJ) {
232+
if (idsJ == nullptr) {
233+
throw std::runtime_error("IDs cannot be null");
234+
}
235+
236+
if (vectorsAddressJ <= 0) {
237+
throw std::runtime_error("VectorsAddress cannot be less than 0");
238+
}
239+
240+
if(dimJ <= 0) {
241+
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
242+
}
243+
244+
if (indexPathJ == nullptr) {
245+
throw std::runtime_error("Index path cannot be null");
246+
}
247+
248+
if (templateIndexJ == nullptr) {
249+
throw std::runtime_error("Template index cannot be null");
250+
}
251+
252+
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
253+
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
254+
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
255+
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
256+
omp_set_num_threads(threadCount);
257+
}
258+
jniUtil->DeleteLocalRef(env, parametersJ);
259+
260+
// Read data set
261+
// Read vectors from memory address
262+
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
263+
int dim = (int)dimJ;
264+
if (dim % 8 != 0) {
265+
throw std::runtime_error("Dimensions should be multiply of 8");
266+
}
267+
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
268+
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
269+
if (numIds != numVectors) {
270+
throw std::runtime_error("Number of IDs does not match number of vectors");
271+
}
272+
273+
// Get vector of bytes from jbytearray
274+
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
275+
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);
276+
277+
faiss::VectorIOReader vectorIoReader;
278+
for (int i = 0; i < indexBytesCount; i++) {
279+
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
280+
}
281+
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);
282+
283+
// Create faiss index
284+
std::unique_ptr<faiss::IndexBinary> indexWriter;
285+
indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0));
286+
287+
auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
288+
faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());
289+
idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data());
290+
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
291+
// This is not the ideal approach, please refer this gh issue for long term solution:
292+
// https://github.com/opensearch-project/k-NN/issues/1600
293+
delete inputVectors;
294+
// Write the index to disk
295+
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
296+
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
297+
}
298+
226299
jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
227300
if (indexPathJ == nullptr) {
228301
throw std::runtime_error("Index path cannot be null");
@@ -624,6 +697,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
624697
return ret;
625698
}
626699

700+
jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
701+
jint dimensionJ, jlong trainVectorsPointerJ) {
702+
// First, we need to build the index
703+
if (parametersJ == nullptr) {
704+
throw std::runtime_error("Parameters cannot be null");
705+
}
706+
707+
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
708+
709+
jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
710+
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
711+
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);
712+
713+
// Create faiss index
714+
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
715+
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
716+
717+
std::unique_ptr<faiss::IndexBinary> indexWriter;
718+
indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str()));
719+
720+
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
721+
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
722+
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
723+
omp_set_num_threads(threadCount);
724+
}
725+
726+
// Train index if needed
727+
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
728+
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
729+
if(!indexWriter->is_trained) {
730+
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
731+
}
732+
jniUtil->DeleteLocalRef(env, parametersJ);
733+
734+
// Now that indexWriter is trained, we just load the bytes into an array and return
735+
faiss::VectorIOWriter vectorIoWriter;
736+
faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);
737+
738+
// Wrap in smart pointer
739+
std::unique_ptr<jbyte[]> jbytesBuffer;
740+
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
741+
int c = 0;
742+
for (auto b : vectorIoWriter.data) {
743+
jbytesBuffer[c++] = (jbyte) b;
744+
}
745+
746+
jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
747+
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
748+
return ret;
749+
}
750+
627751
faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) {
628752
if (spaceType == knn_jni::L2) {
629753
return faiss::METRIC_L2;
@@ -682,6 +806,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
682806
}
683807
}
684808

809+
void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
810+
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
811+
indexIvf->make_direct_map();
812+
}
813+
if (!index->is_trained) {
814+
index->train(n, reinterpret_cast<const uint8_t*>(x));
815+
}
816+
}
817+
685818
std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
686819
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
687820
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);

jni/src/org_opensearch_knn_jni_FaissService.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
9090
}
9191
}
9292

93+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls,
94+
jintArray idsJ,
95+
jlong vectorsAddressJ,
96+
jint dimJ,
97+
jstring indexPathJ,
98+
jbyteArray templateIndexJ,
99+
jobject parametersJ)
100+
{
101+
try {
102+
knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
103+
} catch (...) {
104+
jniUtil.CatchCppExceptionAndThrowJava(env);
105+
}
106+
}
107+
93108
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
94109
{
95110
try {
@@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
220235
return nullptr;
221236
}
222237

238+
JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls,
239+
jobject parametersJ,
240+
jint dimensionJ,
241+
jlong trainVectorsPointerJ)
242+
{
243+
try {
244+
return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
245+
} catch (...) {
246+
jniUtil.CatchCppExceptionAndThrowJava(env);
247+
}
248+
return nullptr;
249+
}
250+
223251
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls,
224252
jlong vectorsPointerJ,
225253
jobjectArray vectorsJ)

src/main/java/org/opensearch/knn/common/KNNConstants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public class KNNConstants {
6767
public static final String SEARCH_SIZE_PARAMETER = "search_size";
6868

6969
public static final String VECTOR_DATA_TYPE_FIELD = "data_type";
70+
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
7071
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;
7172

7273
public static final String RADIAL_SEARCH_KEY = "radial_search";

0 commit comments

Comments
 (0)