Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 73 additions & 51 deletions cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/detail/macros.hpp> // RAFT_WEAK_FUNCTION
#include <raft/distance/distance_types.hpp> // raft::distance::DistanceType
#include <raft/neighbors/detail/ivf_pq_fp_8bit.cuh> // raft::neighbors::ivf_pq::detail::fp_8bit
#include <raft/neighbors/detail/sample_filter.cuh> // NoneSampleFilter
#include <raft/neighbors/ivf_pq_types.hpp> // raft::neighbors::ivf_pq::codebook_gen
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm::cuda_stream_view
Expand All @@ -36,6 +37,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, ui

template <typename OutT,
typename LutT,
typename SampleFilterT,
uint32_t PqBits,
int Capacity,
bool PrecompBaseDiff,
Expand All @@ -45,6 +47,7 @@ __global__ void compute_similarity_kernel(uint32_t n_rows,
uint32_t n_probes,
uint32_t pq_dim,
uint32_t n_queries,
uint32_t queries_offset,
distance::DistanceType metric,
codebook_gen codebook_kind,
uint32_t topk,
Expand All @@ -57,32 +60,34 @@ __global__ void compute_similarity_kernel(uint32_t n_rows,
const float* queries,
const uint32_t* index_list,
float* query_kths,
SampleFilterT sample_filter,
LutT* lut_scores,
OutT* _out_scores,
uint32_t* _out_indices) RAFT_EXPLICIT;

// The signature of the kernel defined by a minimal set of template parameters
template <typename OutT, typename LutT>
template <typename OutT, typename LutT, typename SampleFilterT>
using compute_similarity_kernel_t =
decltype(&compute_similarity_kernel<OutT, LutT, 8, 0, true, true>);
decltype(&compute_similarity_kernel<OutT, LutT, SampleFilterT, 8, 0, true, true>);

template <typename OutT, typename LutT>
template <typename OutT, typename LutT, typename SampleFilterT>
struct selected {
compute_similarity_kernel_t<OutT, LutT> kernel;
compute_similarity_kernel_t<OutT, LutT, SampleFilterT> kernel;
dim3 grid_dim;
dim3 block_dim;
size_t smem_size;
size_t device_lut_size;
};

template <typename OutT, typename LutT>
void compute_similarity_run(selected<OutT, LutT> s,
template <typename OutT, typename LutT, typename SampleFilterT>
void compute_similarity_run(selected<OutT, LutT, SampleFilterT> s,
rmm::cuda_stream_view stream,
uint32_t n_rows,
uint32_t dim,
uint32_t n_probes,
uint32_t pq_dim,
uint32_t n_queries,
uint32_t queries_offset,
distance::DistanceType metric,
codebook_gen codebook_kind,
uint32_t topk,
Expand All @@ -95,6 +100,7 @@ void compute_similarity_run(selected<OutT, LutT> s,
const float* queries,
const uint32_t* index_list,
float* query_kths,
SampleFilterT sample_filter,
LutT* lut_scores,
OutT* _out_scores,
uint32_t* _out_indices) RAFT_EXPLICIT;
Expand All @@ -113,7 +119,7 @@ void compute_similarity_run(selected<OutT, LutT> s,
* beyond this limit do not consider increasing the number of active blocks per SM
* would improve locality anymore.
*/
template <typename OutT, typename LutT>
template <typename OutT, typename LutT, typename SampleFilterT>
auto compute_similarity_select(const cudaDeviceProp& dev_props,
bool manage_local_topk,
int locality_hint,
Expand All @@ -123,62 +129,78 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
uint32_t precomp_data_count,
uint32_t n_queries,
uint32_t n_probes,
uint32_t topk) -> selected<OutT, LutT> RAFT_EXPLICIT;
uint32_t topk) -> selected<OutT, LutT, SampleFilterT> RAFT_EXPLICIT;

} // namespace raft::neighbors::ivf_pq::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \
extern template auto raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT>( \
const cudaDeviceProp& dev_props, \
bool manage_local_topk, \
int locality_hint, \
double preferred_shmem_carveout, \
uint32_t pq_bits, \
uint32_t pq_dim, \
uint32_t precomp_data_count, \
uint32_t n_queries, \
uint32_t n_probes, \
uint32_t topk) \
->raft::neighbors::ivf_pq::detail::selected<OutT, LutT>; \
\
extern template void raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
uint32_t n_queries, \
raft::distance::DistanceType metric, \
raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
uint32_t topk, \
uint32_t max_samples, \
const float* cluster_centers, \
const float* pq_centers, \
const uint8_t* const* pq_dataset, \
const uint32_t* cluster_labels, \
const uint32_t* _chunk_indices, \
const float* queries, \
const uint32_t* index_list, \
float* query_kths, \
LutT* lut_scores, \
OutT* _out_scores, \
#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \
OutT, LutT, SampleFilterT) \
extern template auto \
raft::neighbors::ivf_pq::detail::compute_similarity_select<OutT, LutT, SampleFilterT>( \
const cudaDeviceProp& dev_props, \
bool manage_local_topk, \
int locality_hint, \
double preferred_shmem_carveout, \
uint32_t pq_bits, \
uint32_t pq_dim, \
uint32_t precomp_data_count, \
uint32_t n_queries, \
uint32_t n_probes, \
uint32_t topk) \
->raft::neighbors::ivf_pq::detail::selected<OutT, LutT, SampleFilterT>; \
\
extern template void \
raft::neighbors::ivf_pq::detail::compute_similarity_run<OutT, LutT, SampleFilterT>( \
raft::neighbors::ivf_pq::detail::selected<OutT, LutT, SampleFilterT> s, \
rmm::cuda_stream_view stream, \
uint32_t n_rows, \
uint32_t dim, \
uint32_t n_probes, \
uint32_t pq_dim, \
uint32_t n_queries, \
uint32_t queries_offset, \
raft::distance::DistanceType metric, \
raft::neighbors::ivf_pq::codebook_gen codebook_kind, \
uint32_t topk, \
uint32_t max_samples, \
const float* cluster_centers, \
const float* pq_centers, \
const uint8_t* const* pq_dataset, \
const uint32_t* cluster_labels, \
const uint32_t* _chunk_indices, \
const float* queries, \
const uint32_t* index_list, \
float* query_kths, \
SampleFilterT sample_filter, \
LutT* lut_scores, \
OutT* _out_scores, \
uint32_t* _out_indices);

#define COMMA ,
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>);
half,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::ivf_pq::detail::NoneSampleFilter);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float);
half,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
raft::neighbors::ivf_pq::detail::NoneSampleFilter);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>);
half, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>);
float, half, raft::neighbors::ivf_pq::detail::NoneSampleFilter);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, float, raft::neighbors::ivf_pq::detail::NoneSampleFilter);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::ivf_pq::detail::NoneSampleFilter);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
raft::neighbors::ivf_pq::detail::NoneSampleFilter);

#undef COMMA

Expand Down
Loading