@@ -298,7 +298,7 @@ static void search_typed( //
298
298
if (!threads)
299
299
threads = std::thread::hardware_concurrency ();
300
300
301
- std::vector<std::mutex> vectors_mutexes (static_cast <std::size_t >(vectors_count));
301
+ std::vector<std::mutex> query_mutexes (static_cast <std::size_t >(vectors_count));
302
302
executor_default_t {threads}.execute_bulk (indexes.shards_ .size (), [&](std::size_t , std::size_t task_idx) {
303
303
dense_index_py_t & index = *indexes.shards_ [task_idx].get ();
304
304
@@ -318,7 +318,7 @@ static void search_typed( //
318
318
dense_search_result_t result = index .search (vector, wanted, config);
319
319
result.error .raise ();
320
320
{
321
- std::unique_lock<std::mutex> lock (vectors_mutexes [vector_idx]);
321
+ std::unique_lock<std::mutex> lock (query_mutexes [vector_idx]);
322
322
counts_py1d (vector_idx) = static_cast <Py_ssize_t>(result.merge_into ( //
323
323
&keys_py2d (vector_idx, 0 ), //
324
324
&distances_py2d (vector_idx, 0 ), //
@@ -348,7 +348,7 @@ static py::tuple search_many_in_index( //
348
348
index_at& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) {
349
349
350
350
if (wanted == 0 )
351
- return py::tuple (3 );
351
+ return py::tuple (5 );
352
352
353
353
if (index .limits ().threads_search < threads)
354
354
throw std::invalid_argument (" Can't use that many threads!" );
@@ -388,6 +388,123 @@ static py::tuple search_many_in_index( //
388
388
return results;
389
389
}
390
390
391
+ template <typename scalar_at>
392
+ static void search_typed_brute_force ( //
393
+ py::buffer_info& dataset_info, py::buffer_info& queries_info, //
394
+ std::size_t wanted, std::size_t threads, metric_t const & metric, //
395
+ py::array_t <key_t >& keys_py, py::array_t <distance_t >& distances_py, py::array_t <Py_ssize_t>& counts_py) {
396
+
397
+ auto keys_py2d = keys_py.template mutable_unchecked <2 >();
398
+ auto distances_py2d = distances_py.template mutable_unchecked <2 >();
399
+ auto counts_py1d = counts_py.template mutable_unchecked <1 >();
400
+
401
+ std::size_t dataset_count = static_cast <std::size_t >(dataset_info.shape [0 ]);
402
+ std::size_t queries_count = static_cast <std::size_t >(queries_info.shape [0 ]);
403
+ std::size_t dimensions = static_cast <std::size_t >(dataset_info.shape [1 ]);
404
+
405
+ byte_t const * dataset_data = reinterpret_cast <byte_t const *>(dataset_info.ptr );
406
+ byte_t const * queries_data = reinterpret_cast <byte_t const *>(queries_info.ptr );
407
+ for (std::size_t query_idx = 0 ; query_idx != queries_count; ++query_idx)
408
+ counts_py1d (query_idx) = 0 ;
409
+
410
+ if (!threads)
411
+ threads = std::thread::hardware_concurrency ();
412
+
413
+ std::size_t tasks_count = static_cast <std::size_t >(dataset_count * queries_count);
414
+ visits_bitset_t query_mutexes (static_cast <std::size_t >(queries_count));
415
+ if (!query_mutexes)
416
+ throw std::bad_alloc ();
417
+
418
+ executor_default_t {threads}.execute_bulk (tasks_count, [&](std::size_t , std::size_t task_idx) {
419
+ //
420
+ std::size_t dataset_idx = task_idx / queries_count;
421
+ std::size_t query_idx = task_idx % queries_count;
422
+
423
+ byte_t const * dataset = dataset_data + dataset_idx * dataset_info.strides [0 ];
424
+ byte_t const * query = queries_data + query_idx * queries_info.strides [0 ];
425
+ distance_t distance = metric (dataset, query);
426
+
427
+ {
428
+ auto lock = query_mutexes.lock (query_idx);
429
+ key_t * keys = &keys_py2d (query_idx, 0 );
430
+ distance_t * distances = &distances_py2d (query_idx, 0 );
431
+ std::size_t & matches = reinterpret_cast <std::size_t &>(counts_py1d (query_idx));
432
+ if (matches == wanted)
433
+ if (distances[wanted - 1 ] <= distance)
434
+ return ;
435
+
436
+ std::size_t offset = std::lower_bound (distances, distances + matches, distance) - distances;
437
+
438
+ std::size_t count_worse = matches - offset - (wanted == matches);
439
+ std::memmove (keys + offset + 1 , keys + offset, count_worse * sizeof (key_t ));
440
+ std::memmove (distances + offset + 1 , distances + offset, count_worse * sizeof (distance_t ));
441
+ keys[offset] = static_cast <key_t >(dataset_idx);
442
+ distances[offset] = distance;
443
+ matches += matches != wanted;
444
+ }
445
+
446
+ if (PyErr_CheckSignals () != 0 )
447
+ throw py::error_already_set ();
448
+ });
449
+ }
450
+
451
+ static py::tuple search_many_brute_force ( //
452
+ py::buffer dataset, py::buffer queries, //
453
+ std::size_t wanted, std::size_t threads, //
454
+ metric_kind_t metric_kind, //
455
+ metric_signature_t metric_signature, //
456
+ std::uintptr_t metric_uintptr) {
457
+
458
+ if (wanted == 0 )
459
+ return py::tuple (5 );
460
+
461
+ py::buffer_info dataset_info = dataset.request ();
462
+ py::buffer_info queries_info = queries.request ();
463
+ if (dataset_info.ndim != 2 || queries_info.ndim != 2 )
464
+ throw std::invalid_argument (" Expects a matrix of dataset to add!" );
465
+
466
+ Py_ssize_t dataset_count = dataset_info.shape [0 ];
467
+ Py_ssize_t dataset_dimensions = dataset_info.shape [1 ];
468
+ Py_ssize_t queries_count = queries_info.shape [0 ];
469
+ Py_ssize_t queries_dimensions = queries_info.shape [1 ];
470
+ if (dataset_dimensions != queries_dimensions)
471
+ throw std::invalid_argument (" The number of vector dimensions doesn't match!" );
472
+
473
+ scalar_kind_t dataset_kind = numpy_string_to_kind (dataset_info.format );
474
+ scalar_kind_t queries_kind = numpy_string_to_kind (queries_info.format );
475
+ if (dataset_kind != queries_kind)
476
+ throw std::invalid_argument (" The types of vectors don't match!" );
477
+
478
+ py::array_t <key_t > keys_py ({dataset_count, static_cast <Py_ssize_t>(wanted)});
479
+ py::array_t <distance_t > distances_py ({dataset_count, static_cast <Py_ssize_t>(wanted)});
480
+ py::array_t <Py_ssize_t> counts_py (dataset_count);
481
+
482
+ std::size_t dimensions = static_cast <std::size_t >(queries_dimensions);
483
+ metric_t metric = //
484
+ metric_uintptr //
485
+ ? udf (metric_kind, metric_signature, metric_uintptr, queries_kind, dimensions)
486
+ : metric_t (dimensions, metric_kind, queries_kind);
487
+
488
+ // clang-format off
489
+ switch (dataset_kind) {
490
+ case scalar_kind_t ::b1x8_k: search_typed_brute_force<b1x8_t >(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break ;
491
+ case scalar_kind_t ::i8_k: search_typed_brute_force<i8_bits_t >(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break ;
492
+ case scalar_kind_t ::f16_k: search_typed_brute_force<f16_t >(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break ;
493
+ case scalar_kind_t ::f32_k: search_typed_brute_force<f32_t >(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break ;
494
+ case scalar_kind_t ::f64_k: search_typed_brute_force<f64_t >(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break ;
495
+ default : throw std::invalid_argument (" Incompatible vector types: " + dataset_info.format );
496
+ }
497
+ // clang-format on
498
+
499
+ py::tuple results (5 );
500
+ results[0 ] = keys_py;
501
+ results[1 ] = distances_py;
502
+ results[2 ] = counts_py;
503
+ results[3 ] = 0 ;
504
+ results[4 ] = static_cast <std::size_t >(dataset_count * queries_count);
505
+ return results;
506
+ }
507
+
391
508
static std::unordered_map<key_t , key_t > join_index ( //
392
509
dense_index_py_t const & a, dense_index_py_t const & b, //
393
510
std::size_t max_proposals, bool exact) {
@@ -505,7 +622,7 @@ PYBIND11_MODULE(compiled, m) {
505
622
py::enum_<metric_kind_t >(m, " MetricKind" )
506
623
.value (" Unknown" , metric_kind_t ::unknown_k)
507
624
508
- .value (" IP" , metric_kind_t ::ip_k )
625
+ .value (" IP" , metric_kind_t ::cos_k )
509
626
.value (" Cos" , metric_kind_t ::cos_k)
510
627
.value (" L2sq" , metric_kind_t ::l2sq_k)
511
628
@@ -517,7 +634,7 @@ PYBIND11_MODULE(compiled, m) {
517
634
.value (" Sorensen" , metric_kind_t ::sorensen_k)
518
635
519
636
.value (" Cosine" , metric_kind_t ::cos_k)
520
- .value (" InnerProduct" , metric_kind_t ::ip_k );
637
+ .value (" InnerProduct" , metric_kind_t ::cos_k );
521
638
522
639
py::enum_<scalar_kind_t >(m, " ScalarKind" )
523
640
.value (" Unknown" , scalar_kind_t ::unknown_k)
@@ -562,13 +679,24 @@ PYBIND11_MODULE(compiled, m) {
562
679
return result;
563
680
});
564
681
682
+ m.def (" exact_search" , &search_many_brute_force, //
683
+ py::arg (" dataset" ), //
684
+ py::arg (" queries" ), //
685
+ py::arg (" count" ) = 10 , //
686
+ py::kw_only (), //
687
+ py::arg (" threads" ) = 0 , //
688
+ py::arg (" metric_kind" ) = metric_kind_t ::cos_k, //
689
+ py::arg (" metric_signature" ) = metric_signature_t ::array_array_k, //
690
+ py::arg (" metric_pointer" ) = 0 //
691
+ );
692
+
565
693
auto i = py::class_<dense_index_py_t , std::shared_ptr<dense_index_py_t >>(m, " Index" );
566
694
567
695
i.def (py::init (&make_index), //
568
696
py::kw_only (), //
569
697
py::arg (" ndim" ) = 0 , //
570
698
py::arg (" dtype" ) = scalar_kind_t ::f32_k, //
571
- py::arg (" metric_kind" ) = metric_kind_t ::ip_k, //
699
+ py::arg (" metric_kind" ) = metric_kind_t ::cos_k, //
572
700
py::arg (" connectivity" ) = default_connectivity (), //
573
701
py::arg (" expansion_add" ) = default_expansion_add (), //
574
702
py::arg (" expansion_search" ) = default_expansion_search (), //
0 commit comments