@@ -186,7 +186,7 @@ static void add_typed_to_index( //
186
186
byte_t const * keys_data = reinterpret_cast <byte_t const *>(keys_info.ptr );
187
187
188
188
executor_default_t {threads}.execute_bulk (vectors_count, [&](std::size_t thread_idx, std::size_t task_idx) {
189
- index_dense_add_config_t config;
189
+ index_dense_update_config_t config;
190
190
config.force_vector_copy = force_copy;
191
191
config.thread = thread_idx;
192
192
key_t key = *reinterpret_cast <key_t const *>(keys_data + task_idx * keys_info.strides [0 ]);
@@ -246,7 +246,7 @@ static void search_typed( //
246
246
dense_index_py_t & index, py::buffer_info& vectors_info, //
247
247
std::size_t wanted, bool exact, std::size_t threads, //
248
248
py::array_t <key_t >& keys_py, py::array_t <distance_t >& distances_py, py::array_t <Py_ssize_t>& counts_py,
249
- std::atomic<std::size_t >& stats_lookups , std::atomic<std::size_t >& stats_measurements ) {
249
+ std::atomic<std::size_t >& stats_visited_members , std::atomic<std::size_t >& stats_computed_distances ) {
250
250
251
251
auto keys_py2d = keys_py.template mutable_unchecked <2 >();
252
252
auto distances_py2d = distances_py.template mutable_unchecked <2 >();
@@ -270,8 +270,8 @@ static void search_typed( //
270
270
counts_py1d (task_idx) =
271
271
static_cast <Py_ssize_t>(result.dump_to (&keys_py2d (task_idx, 0 ), &distances_py2d (task_idx, 0 )));
272
272
273
- stats_lookups += result.lookups ;
274
- stats_measurements += result.measurements ;
273
+ stats_visited_members += result.visited_members ;
274
+ stats_computed_distances += result.computed_distances ;
275
275
if (PyErr_CheckSignals () != 0 )
276
276
throw py::error_already_set ();
277
277
});
@@ -282,7 +282,7 @@ static void search_typed( //
282
282
dense_indexes_py_t & indexes, py::buffer_info& vectors_info, //
283
283
std::size_t wanted, bool exact, std::size_t threads, //
284
284
py::array_t <key_t >& keys_py, py::array_t <distance_t >& distances_py, py::array_t <Py_ssize_t>& counts_py,
285
- std::atomic<std::size_t >& stats_lookups , std::atomic<std::size_t >& stats_measurements ) {
285
+ std::atomic<std::size_t >& stats_visited_members , std::atomic<std::size_t >& stats_computed_distances ) {
286
286
287
287
auto keys_py2d = keys_py.template mutable_unchecked <2 >();
288
288
auto distances_py2d = distances_py.template mutable_unchecked <2 >();
@@ -324,8 +324,8 @@ static void search_typed( //
324
324
wanted));
325
325
}
326
326
327
- stats_lookups += result.lookups ;
328
- stats_measurements += result.measurements ;
327
+ stats_visited_members += result.visited_members ;
328
+ stats_computed_distances += result.computed_distances ;
329
329
if (PyErr_CheckSignals () != 0 )
330
330
throw py::error_already_set ();
331
331
}
@@ -363,16 +363,16 @@ static py::tuple search_many_in_index( //
363
363
py::array_t <key_t > keys_py ({vectors_count, static_cast <Py_ssize_t>(wanted)});
364
364
py::array_t <distance_t > distances_py ({vectors_count, static_cast <Py_ssize_t>(wanted)});
365
365
py::array_t <Py_ssize_t> counts_py (vectors_count);
366
- std::atomic<std::size_t > stats_lookups (0 );
367
- std::atomic<std::size_t > stats_measurements (0 );
366
+ std::atomic<std::size_t > stats_visited_members (0 );
367
+ std::atomic<std::size_t > stats_computed_distances (0 );
368
368
369
369
// clang-format off
370
370
switch (numpy_string_to_kind (vectors_info.format )) {
371
- case scalar_kind_t ::b1x8_k: search_typed<b1x8_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements ); break ;
372
- case scalar_kind_t ::f8_k: search_typed<f8_bits_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements ); break ;
373
- case scalar_kind_t ::f16_k: search_typed<f16_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements ); break ;
374
- case scalar_kind_t ::f32_k: search_typed<f32_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements ); break ;
375
- case scalar_kind_t ::f64_k: search_typed<f64_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_lookups, stats_measurements ); break ;
371
+ case scalar_kind_t ::b1x8_k: search_typed<b1x8_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_visited_members, stats_computed_distances ); break ;
372
+ case scalar_kind_t ::f8_k: search_typed<f8_bits_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_visited_members, stats_computed_distances ); break ;
373
+ case scalar_kind_t ::f16_k: search_typed<f16_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_visited_members, stats_computed_distances ); break ;
374
+ case scalar_kind_t ::f32_k: search_typed<f32_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_visited_members, stats_computed_distances ); break ;
375
+ case scalar_kind_t ::f64_k: search_typed<f64_t >(index , vectors_info, wanted, exact, threads, keys_py, distances_py, counts_py, stats_visited_members, stats_computed_distances ); break ;
376
376
default : throw std::invalid_argument (" Incompatible scalars in the query matrix: " + vectors_info.format );
377
377
}
378
378
// clang-format on
@@ -381,8 +381,8 @@ static py::tuple search_many_in_index( //
381
381
results[0 ] = keys_py;
382
382
results[1 ] = distances_py;
383
383
results[2 ] = counts_py;
384
- results[3 ] = stats_lookups .load ();
385
- results[4 ] = stats_measurements .load ();
384
+ results[3 ] = stats_visited_members .load ();
385
+ results[4 ] = stats_computed_distances .load ();
386
386
return results;
387
387
}
388
388
@@ -391,21 +391,18 @@ static std::unordered_map<key_t, key_t> join_index( //
391
391
std::size_t max_proposals, bool exact) {
392
392
393
393
std::unordered_map<key_t , key_t > a_to_b;
394
+ dummy_label_to_label_mapping_t b_to_a;
394
395
a_to_b.reserve ((std::min)(a.size (), b.size ()));
395
396
396
- // index_join_config_t config;
397
-
398
- // config.max_proposals = max_proposals;
399
- // config.exact = exact;
400
- // config.expansion = (std::max)(a.expansion_search(), b.expansion_search());
401
- // std::size_t threads = (std::min)(a.limits().threads(), b.limits().threads());
402
- // executor_default_t executor{threads};
403
- // join_result_t result = dense_index_py_t::join( //
404
- // a, b, config, //
405
- // a_to_b, //
406
- // dummy_label_to_label_mapping_t{}, //
407
- // executor);
408
- // result.error.raise();
397
+ index_join_config_t config;
398
+ config.max_proposals = max_proposals;
399
+ config.exact = exact;
400
+ config.expansion = (std::max)(a.expansion_search (), b.expansion_search ());
401
+ std::size_t threads = (std::min)(a.limits ().threads (), b.limits ().threads ());
402
+ executor_default_t executor{threads};
403
+ join_result_t result = a.join (b, config, a_to_b, b_to_a, executor);
404
+ result.error .raise ();
405
+
409
406
return a_to_b;
410
407
}
411
408
@@ -418,6 +415,16 @@ static dense_index_py_t copy_index(dense_index_py_t const& index) {
418
415
return std::move (result.index );
419
416
}
420
417
418
+ static void compact_index (dense_index_py_t & index, std::size_t threads) {
419
+
420
+ if (!threads)
421
+ threads = std::thread::hardware_concurrency ();
422
+ if (!index .reserve (index_limits_t (index .size (), threads)))
423
+ throw std::invalid_argument (" Out of memory!" );
424
+
425
+ index .compact (executor_default_t {threads});
426
+ }
427
+
421
428
// clang-format off
422
429
/**
 *  @brief  Calls `index.save` with `path` as a C string and then invokes
 *          `raise()` on the `error` member of the returned result.
 *
 *  NOTE(review): presumably `raise()` throws when the save failed - confirm
 *  against the error type used by the index.
 */
template <typename index_at>
void save_index(index_at const& index, std::string const& path) {
    auto&& outcome = index.save(path.c_str());
    outcome.error.raise();
}
423
430
/**
 *  @brief  Calls `index.load` with `path` as a C string and then invokes
 *          `raise()` on the `error` member of the returned result.
 *
 *  NOTE(review): presumably `raise()` throws when the load failed - confirm
 *  against the error type used by the index.
 */
template <typename index_at>
void load_index(index_at& index, std::string const& path) {
    auto&& outcome = index.load(path.c_str());
    outcome.error.raise();
}
@@ -601,7 +608,7 @@ PYBIND11_MODULE(compiled, m) {
601
608
if (!index .reserve (index_limits_t (index .size (), threads)))
602
609
throw std::invalid_argument (" Out of memory!" );
603
610
604
- index .compact (executor_default_t {threads});
611
+ index .isolate (executor_default_t {threads});
605
612
return result.completed ;
606
613
},
607
614
py::arg (" key" ), py::arg (" compact" ), py::arg (" threads" ));
@@ -619,7 +626,7 @@ PYBIND11_MODULE(compiled, m) {
619
626
if (!index .reserve (index_limits_t (index .size (), threads)))
620
627
throw std::invalid_argument (" Out of memory!" );
621
628
622
- index .compact (executor_default_t {threads});
629
+ index .isolate (executor_default_t {threads});
623
630
return result.completed ;
624
631
},
625
632
py::arg (" key" ), py::arg (" compact" ), py::arg (" threads" ));
@@ -651,6 +658,7 @@ PYBIND11_MODULE(compiled, m) {
651
658
i.def (" reset" , &reset_index<dense_index_py_t >, py::call_guard<py::gil_scoped_release>());
652
659
i.def (" clear" , &clear_index<dense_index_py_t >, py::call_guard<py::gil_scoped_release>());
653
660
i.def (" copy" , &copy_index, py::call_guard<py::gil_scoped_release>());
661
+ i.def (" compact" , &compact_index, py::call_guard<py::gil_scoped_release>());
654
662
i.def (" join" , &join_index, py::arg (" other" ), py::arg (" max_proposals" ) = 0 , py::arg (" exact" ) = false ,
655
663
py::call_guard<py::gil_scoped_release>());
656
664
0 commit comments