@@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
70
70
// Train an index with data provided
71
71
void InternalTrainIndex (faiss::Index * index, faiss::idx_t n, const float * x);
72
72
73
+ // Train a binary index with data provided
74
+ void InternalTrainBinaryIndex (faiss::IndexBinary * index, faiss::idx_t n, const float * x);
75
+
73
76
// Converts the int FilterIds to Faiss ids type array.
74
77
void convertFilterIdsToFaissIdType (const int * filterIds, int filterIdsLength, faiss::idx_t * convertedFilterIds);
75
78
@@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
223
226
faiss::write_index (&idMap, indexPathCpp.c_str ());
224
227
}
225
228
229
+ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate (knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
230
+ jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
231
+ jbyteArray templateIndexJ, jobject parametersJ) {
232
+ if (idsJ == nullptr ) {
233
+ throw std::runtime_error (" IDs cannot be null" );
234
+ }
235
+
236
+ if (vectorsAddressJ <= 0 ) {
237
+ throw std::runtime_error (" VectorsAddress cannot be less than 0" );
238
+ }
239
+
240
+ if (dimJ <= 0 ) {
241
+ throw std::runtime_error (" Vectors dimensions cannot be less than or equal to 0" );
242
+ }
243
+
244
+ if (indexPathJ == nullptr ) {
245
+ throw std::runtime_error (" Index path cannot be null" );
246
+ }
247
+
248
+ if (templateIndexJ == nullptr ) {
249
+ throw std::runtime_error (" Template index cannot be null" );
250
+ }
251
+
252
+ // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
253
+ auto parametersCpp = jniUtil->ConvertJavaMapToCppMap (env, parametersJ);
254
+ if (parametersCpp.find (knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end ()) {
255
+ auto threadCount = jniUtil->ConvertJavaObjectToCppInteger (env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
256
+ omp_set_num_threads (threadCount);
257
+ }
258
+ jniUtil->DeleteLocalRef (env, parametersJ);
259
+
260
+ // Read data set
261
+ // Read vectors from memory address
262
+ auto *inputVectors = reinterpret_cast <std::vector<uint8_t >*>(vectorsAddressJ);
263
+ int dim = (int )dimJ;
264
+ if (dim % 8 != 0 ) {
265
+ throw std::runtime_error (" Dimensions should be multiply of 8" );
266
+ }
267
+ int numVectors = (int ) (inputVectors->size () / (uint64_t ) (dim / 8 ));
268
+ int numIds = jniUtil->GetJavaIntArrayLength (env, idsJ);
269
+ if (numIds != numVectors) {
270
+ throw std::runtime_error (" Number of IDs does not match number of vectors" );
271
+ }
272
+
273
+ // Get vector of bytes from jbytearray
274
+ int indexBytesCount = jniUtil->GetJavaBytesArrayLength (env, templateIndexJ);
275
+ jbyte * indexBytesJ = jniUtil->GetByteArrayElements (env, templateIndexJ, nullptr );
276
+
277
+ faiss::VectorIOReader vectorIoReader;
278
+ for (int i = 0 ; i < indexBytesCount; i++) {
279
+ vectorIoReader.data .push_back ((uint8_t ) indexBytesJ[i]);
280
+ }
281
+ jniUtil->ReleaseByteArrayElements (env, templateIndexJ, indexBytesJ, JNI_ABORT);
282
+
283
+ // Create faiss index
284
+ std::unique_ptr<faiss::IndexBinary> indexWriter;
285
+ indexWriter.reset (faiss::read_index_binary (&vectorIoReader, 0 ));
286
+
287
+ auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector (env, idsJ);
288
+ faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap (indexWriter.get ());
289
+ idMap.add_with_ids (numVectors, reinterpret_cast <const uint8_t *>(inputVectors->data ()), idVector.data ());
290
+ // Releasing the vectorsAddressJ memory as that is not required once we have created the index.
291
+ // This is not the ideal approach, please refer this gh issue for long term solution:
292
+ // https://github.com/opensearch-project/k-NN/issues/1600
293
+ delete inputVectors;
294
+ // Write the index to disk
295
+ std::string indexPathCpp (jniUtil->ConvertJavaStringToCppString (env, indexPathJ));
296
+ faiss::write_index_binary (&idMap, indexPathCpp.c_str ());
297
+ }
298
+
226
299
jlong knn_jni::faiss_wrapper::LoadIndex (knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
227
300
if (indexPathJ == nullptr ) {
228
301
throw std::runtime_error (" Index path cannot be null" );
@@ -634,6 +707,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
634
707
return ret;
635
708
}
636
709
710
+ jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex (knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
711
+ jint dimensionJ, jlong trainVectorsPointerJ) {
712
+ // First, we need to build the index
713
+ if (parametersJ == nullptr ) {
714
+ throw std::runtime_error (" Parameters cannot be null" );
715
+ }
716
+
717
+ auto parametersCpp = jniUtil->ConvertJavaMapToCppMap (env, parametersJ);
718
+
719
+ jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow (parametersCpp, knn_jni::SPACE_TYPE);
720
+ std::string spaceTypeCpp (jniUtil->ConvertJavaObjectToCppString (env, spaceTypeJ));
721
+ faiss::MetricType metric = TranslateSpaceToMetric (spaceTypeCpp);
722
+
723
+ // Create faiss index
724
+ jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow (parametersCpp, knn_jni::INDEX_DESCRIPTION);
725
+ std::string indexDescriptionCpp (jniUtil->ConvertJavaObjectToCppString (env, indexDescriptionJ));
726
+
727
+ std::unique_ptr<faiss::IndexBinary> indexWriter;
728
+ indexWriter.reset (faiss::index_binary_factory ((int ) dimensionJ, indexDescriptionCpp.c_str ()));
729
+
730
+ // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
731
+ if (parametersCpp.find (knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end ()) {
732
+ auto threadCount = jniUtil->ConvertJavaObjectToCppInteger (env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
733
+ omp_set_num_threads (threadCount);
734
+ }
735
+
736
+ // Train index if needed
737
+ auto *trainingVectorsPointerCpp = reinterpret_cast <std::vector<float >*>(trainVectorsPointerJ);
738
+ int numVectors = trainingVectorsPointerCpp->size ()/(int ) dimensionJ;
739
+ if (!indexWriter->is_trained ) {
740
+ InternalTrainBinaryIndex (indexWriter.get (), numVectors, trainingVectorsPointerCpp->data ());
741
+ }
742
+ jniUtil->DeleteLocalRef (env, parametersJ);
743
+
744
+ // Now that indexWriter is trained, we just load the bytes into an array and return
745
+ faiss::VectorIOWriter vectorIoWriter;
746
+ faiss::write_index_binary (indexWriter.get (), &vectorIoWriter);
747
+
748
+ // Wrap in smart pointer
749
+ std::unique_ptr<jbyte[]> jbytesBuffer;
750
+ jbytesBuffer.reset (new jbyte[vectorIoWriter.data .size ()]);
751
+ int c = 0 ;
752
+ for (auto b : vectorIoWriter.data ) {
753
+ jbytesBuffer[c++] = (jbyte) b;
754
+ }
755
+
756
+ jbyteArray ret = jniUtil->NewByteArray (env, vectorIoWriter.data .size ());
757
+ jniUtil->SetByteArrayRegion (env, ret, 0 , vectorIoWriter.data .size (), jbytesBuffer.get ());
758
+ return ret;
759
+ }
760
+
637
761
faiss::MetricType TranslateSpaceToMetric (const std::string& spaceType) {
638
762
if (spaceType == knn_jni::L2) {
639
763
return faiss::METRIC_L2;
@@ -692,6 +816,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
692
816
}
693
817
}
694
818
819
+ void InternalTrainBinaryIndex (faiss::IndexBinary * index, faiss::idx_t n, const float * x) {
820
+ if (auto * indexIvf = dynamic_cast <faiss::IndexBinaryIVF*>(index )) {
821
+ indexIvf->make_direct_map ();
822
+ }
823
+ if (!index ->is_trained ) {
824
+ index ->train (n, reinterpret_cast <const uint8_t *>(x));
825
+ }
826
+ }
827
+
695
828
std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap (knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t >* bitmap) {
696
829
int *parentIdsArray = jniUtil->GetIntArrayElements (env, parentIdsJ, nullptr );
697
830
int parentIdsLength = jniUtil->GetJavaIntArrayLength (env, parentIdsJ);
0 commit comments