From 9ab5967931ad2fcc79bdb4acd5d0f58af726eebd Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 11 Oct 2023 11:18:04 -0700 Subject: [PATCH] use precomputed norms for raft brute_force knn calls --- faiss/gpu/GpuDistance.cu | 83 ++++++++++++++++----------------- faiss/gpu/impl/RaftFlatIndex.cu | 37 ++++++--------- 2 files changed, 52 insertions(+), 68 deletions(-) diff --git a/faiss/gpu/GpuDistance.cu b/faiss/gpu/GpuDistance.cu index 5965518445..c363aa4bb8 100644 --- a/faiss/gpu/GpuDistance.cu +++ b/faiss/gpu/GpuDistance.cu @@ -236,89 +236,84 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) { raft::device_resources& handle = res->getRaftHandleCurrentDevice(); auto stream = res->getDefaultStreamCurrentDevice(); - idx_t dims = args.dims; - idx_t num_vectors = args.numVectors; - idx_t num_queries = args.numQueries; + int64_t dims = args.dims; + int64_t num_vectors = args.numVectors; + int64_t num_queries = args.numQueries; int k = args.k; float metric_arg = args.metricArg; - auto inds = raft::make_writeback_temporary_device_buffer( - handle, - reinterpret_cast(args.outIndices), - raft::matrix_extent(num_queries, (idx_t)k)); - auto dists = raft::make_writeback_temporary_device_buffer( - handle, - reinterpret_cast(args.outDistances), - raft::matrix_extent(num_queries, (idx_t)k)); + auto inds = + raft::make_writeback_temporary_device_buffer( + handle, + reinterpret_cast(args.outIndices), + raft::matrix_extent(num_queries, (int64_t)k)); + auto dists = + raft::make_writeback_temporary_device_buffer( + handle, + reinterpret_cast(args.outDistances), + raft::matrix_extent(num_queries, (int64_t)k)); if (args.queriesRowMajor) { auto index = raft::make_readonly_temporary_device_buffer< const float, - idx_t, + int64_t, raft::row_major>( handle, const_cast( reinterpret_cast(args.vectors)), - raft::matrix_extent(num_vectors, dims)); + raft::matrix_extent(num_vectors, dims)); auto search = raft::make_readonly_temporary_device_buffer< const float, - idx_t, + int64_t, raft::row_major>( handle, const_cast( reinterpret_cast(args.queries)), - raft::matrix_extent(num_queries, dims)); + raft::matrix_extent(num_queries, dims)); - // For now, use RAFT's fused KNN when k <= 64 and L2 metric is used - if (args.k <= 64 && args.metric == MetricType::METRIC_L2 && - args.numVectors > 0) { - RAFT_LOG_INFO("Invoking flat fused_l2_knn"); - brute_force::fused_l2_knn( - handle, - index.view(), - search.view(), - inds.view(), - dists.view(), - distance); - } else { - std::vector>> + norms; + std::optional> + norms_view; + if (args.vectorNorms) { + norms = raft::make_readonly_temporary_device_buffer< const float, - idx_t, - raft::row_major>> - index_vec = {index.view()}; - RAFT_LOG_INFO("Invoking flat bfknn"); - brute_force::knn( + int64_t>( handle, - index_vec, - search.view(), - inds.view(), - dists.view(), - distance, - metric_arg); + args.vectorNorms, + raft::vector_extent(num_queries)); + norms_view = norms->view(); } + raft::neighbors::brute_force::index idx( + handle, index.view(), norms_view, distance, metric_arg); + raft::neighbors::brute_force::search( + handle, idx, search.view(), inds.view(), dists.view()); } else { auto index = raft::make_readonly_temporary_device_buffer< const float, - idx_t, + int64_t, raft::col_major>( handle, const_cast( reinterpret_cast(args.vectors)), - raft::matrix_extent(num_vectors, dims)); + raft::matrix_extent(num_vectors, dims)); auto search = raft::make_readonly_temporary_device_buffer< const float, - idx_t, + int64_t, raft::col_major>( handle, const_cast( reinterpret_cast(args.queries)), - raft::matrix_extent(num_queries, dims)); + raft::matrix_extent(num_queries, dims)); std::vector> index_vec = {index.view()}; RAFT_LOG_INFO("Invoking flat bfknn"); diff --git a/faiss/gpu/impl/RaftFlatIndex.cu b/faiss/gpu/impl/RaftFlatIndex.cu index fb0f815368..8f5c491163 100644 --- a/faiss/gpu/impl/RaftFlatIndex.cu +++ b/faiss/gpu/impl/RaftFlatIndex.cu @@ -77,41 +77,30 @@ void RaftFlatIndex::query( raft::device_resources& handle = resources_->getRaftHandleCurrentDevice(); - auto index = raft::make_device_matrix_view( + auto index = raft::make_device_matrix_view( vectors_.data(), vectors_.getSize(0), vectors_.getSize(1)); - auto search = raft::make_device_matrix_view( + auto search = raft::make_device_matrix_view( input.data(), input.getSize(0), input.getSize(1)); - auto inds = raft::make_device_matrix_view( + + auto inds = raft::make_device_matrix_view( outIndices.data(), outIndices.getSize(0), outIndices.getSize(1)); - auto dists = raft::make_device_matrix_view( + auto dists = raft::make_device_matrix_view( outDistances.data(), outDistances.getSize(0), outDistances.getSize(1)); DistanceType distance = faiss_to_raft(metric, exactDistance); - std::vector> index_vec = { - index}; - - // For now, use RAFT's fused KNN when k <= 64 and L2 metric is used - if (k <= 64 && metric == MetricType::METRIC_L2 && - vectors_.getSize(0) > 0) { - RAFT_LOG_INFO("Invoking flat fused_l2_knn"); - brute_force::fused_l2_knn( - handle, index, search, inds, dists, distance); - } else { - RAFT_LOG_INFO("Invoking flat bfknn"); - brute_force::knn( - handle, - index_vec, - search, - inds, - dists, - distance, - metricArg); - } + std::optional> + norms_view = raft::make_device_vector_view( + norms_.data(), norms_.getSize(0)); + + raft::neighbors::brute_force::index idx( + handle, index, norms_view, distance, metricArg); + raft::neighbors::brute_force::search( + handle, idx, search, inds, dists); if (metric == MetricType::METRIC_Lp) { raft::linalg::unary_op(