diff --git a/faiss/gpu/impl/IVFInterleaved.cu b/faiss/gpu/impl/IVFInterleaved.cu index 3aca959e78..c9ee87ee42 100644 --- a/faiss/gpu/impl/IVFInterleaved.cu +++ b/faiss/gpu/impl/IVFInterleaved.cu @@ -192,24 +192,43 @@ void runIVFInterleavedScan( // caught for exceptions at a higher level FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + const auto ivf_interleaved_call = [&](const auto func) { + func(queries, + listIds, + listData, + listIndices, + indicesOptions, + listLengths, + k, + metric, + useResidual, + residualBase, + scalarQ, + outDistances, + outIndices, + res); + }; + if (k == 1) { - IVF_INTERLEAVED_CALL(1); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 32) { - IVF_INTERLEAVED_CALL(32); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 64) { - IVF_INTERLEAVED_CALL(64); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 128) { - IVF_INTERLEAVED_CALL(128); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 256) { - IVF_INTERLEAVED_CALL(256); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 512) { - IVF_INTERLEAVED_CALL(512); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 1024) { - IVF_INTERLEAVED_CALL(1024); + ivf_interleaved_call( + ivfInterleavedScanImpl); } #if GPU_MAX_SELECTION_K >= 2048 else if (k <= 2048) { - IVF_INTERLEAVED_CALL(2048); + ivf_interleaved_call( + ivfInterleavedScanImpl); } #endif } diff --git a/faiss/gpu/impl/IVFInterleaved.cuh b/faiss/gpu/impl/IVFInterleaved.cuh index 79d8feb32b..053f99db09 100644 --- a/faiss/gpu/impl/IVFInterleaved.cuh +++ b/faiss/gpu/impl/IVFInterleaved.cuh @@ -122,9 +122,10 @@ __global__ void ivfInterleavedScan( // whole blocks for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) { - int loadDim = dBase + laneId; - float queryReg = query[loadDim]; - float residualReg = Residual ? residualBaseSlice[loadDim] : 0; + const int loadDim = dBase + laneId; + const float queryReg = query[loadDim]; + [[maybe_unused]] const float residualReg = + Residual ? residualBaseSlice[loadDim] : 0; constexpr int kUnroll = 4; @@ -151,7 +152,7 @@ __global__ void ivfInterleavedScan( decV[j] = codec.decodeNew(dBase + d, encV[j]); } - if (Residual) { + if constexpr (Residual) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { int d = i * kUnroll + j; @@ -169,13 +170,13 @@ __global__ void ivfInterleavedScan( } // remainder - int loadDim = dimBlocks + laneId; - bool loadDimInBounds = loadDim < dim; + const int loadDim = dimBlocks + laneId; + const bool loadDimInBounds = loadDim < dim; - float queryReg = loadDimInBounds ? query[loadDim] : 0; - float residualReg = Residual && loadDimInBounds - ? residualBaseSlice[loadDim] - : 0; + const float queryReg = loadDimInBounds ? query[loadDim] : 0; + [[maybe_unused]] const float residualReg = + Residual && loadDimInBounds ? residualBaseSlice[loadDim] + : 0; for (int d = 0; d < dim - dimBlocks; ++d, data += wordsPerVectorBlockDim) { @@ -186,7 +187,7 @@ __global__ void ivfInterleavedScan( enc = WarpPackedBits::postRead( laneId, enc); float dec = codec.decodeNew(dimBlocks + d, enc); - if (Residual) { + if constexpr (Residual) { dec += SHFL_SYNC(residualReg, d, kWarpSize); } @@ -217,198 +218,6 @@ __global__ void ivfInterleavedScan( // compile time using these macros to define the function body // -#define IVFINT_RUN(CODEC_TYPE, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ - do { \ - dim3 grid(nprobe, std::min(nq, (idx_t)getMaxGridCurrentDevice().y)); \ - if (useResidual) { \ - ivfInterleavedScan< \ - CODEC_TYPE, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q, \ - true><<>>( \ - queries, \ - residualBase, \ - listIds, \ - listData.data(), \ - listLengths.data(), \ - codec, \ - metric, \ - k, \ - distanceTemp, \ - indicesTemp); \ - } else { \ - ivfInterleavedScan< \ - CODEC_TYPE, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q, \ - false><<>>( \ - queries, \ - residualBase, \ - listIds, \ - listData.data(), \ - listLengths.data(), \ - codec, \ - metric, \ - k, \ - distanceTemp, \ - indicesTemp); \ - } \ - \ - runIVFInterleavedScan2( \ - distanceTemp, \ - indicesTemp, \ - listIds, \ - k, \ - listIndices, \ - indicesOptions, \ - METRIC_TYPE::kDirection, \ - outDistances, \ - outIndices, \ - stream); \ - } while (0); - -#define IVFINT_CODECS(METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ - do { \ - if (!scalarQ) { \ - using CodecT = CodecFloat; \ - CodecT codec(dim * sizeof(float)); \ - IVFINT_RUN( \ - CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \ - } else { \ - switch (scalarQ->qtype) { \ - case ScalarQuantizer::QuantizerType::QT_8bit: { \ - using CodecT = \ - Codec; \ - CodecT codec( \ - scalarQ->code_size, \ - scalarQ->gpuTrained.data(), \ - scalarQ->gpuTrained.data() + dim); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_8bit_uniform: { \ - using CodecT = Codec< \ - ScalarQuantizer::QuantizerType::QT_8bit_uniform, \ - 1>; \ - CodecT codec( \ - scalarQ->code_size, \ - scalarQ->trained[0], \ - scalarQ->trained[1]); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_fp16: { \ - using CodecT = \ - Codec; \ - CodecT codec(scalarQ->code_size); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_8bit_direct: { \ - using CodecT = Codec< \ - ScalarQuantizer::QuantizerType::QT_8bit_direct, \ - 1>; \ - Codec \ - codec(scalarQ->code_size); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_6bit: { \ - using CodecT = \ - Codec; \ - Codec codec( \ - scalarQ->code_size, \ - scalarQ->gpuTrained.data(), \ - scalarQ->gpuTrained.data() + dim); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_4bit: { \ - using CodecT = \ - Codec; \ - Codec codec( \ - scalarQ->code_size, \ - scalarQ->gpuTrained.data(), \ - scalarQ->gpuTrained.data() + dim); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_4bit_uniform: { \ - using CodecT = Codec< \ - ScalarQuantizer::QuantizerType::QT_4bit_uniform, \ - 1>; \ - Codec \ - codec(scalarQ->code_size, \ - scalarQ->trained[0], \ - scalarQ->trained[1]); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - default: \ - FAISS_ASSERT(false); \ - } \ - } \ - } while (0) - -#define IVFINT_METRICS(THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ - do { \ - auto stream = res->getDefaultStreamCurrentDevice(); \ - auto nq = queries.getSize(0); \ - auto dim = queries.getSize(1); \ - auto nprobe = listIds.getSize(1); \ - \ - DeviceTensor distanceTemp( \ - res, \ - makeTempAlloc(AllocType::Other, stream), \ - {queries.getSize(0), listIds.getSize(1), k}); \ - DeviceTensor indicesTemp( \ - res, \ - makeTempAlloc(AllocType::Other, stream), \ - {queries.getSize(0), listIds.getSize(1), k}); \ - \ - if (metric == MetricType::METRIC_L2) { \ - L2Distance metric; \ - IVFINT_CODECS(L2Distance, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \ - } else if (metric == MetricType::METRIC_INNER_PRODUCT) { \ - IPDistance metric; \ - IVFINT_CODECS(IPDistance, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \ - } else { \ - FAISS_ASSERT(false); \ - } \ - } while (0) - // Top-level IVF scan function for the interleaved by 32 layout // with all implementations void runIVFInterleavedScan( diff --git a/faiss/gpu/impl/scan/IVFInterleaved1.cu b/faiss/gpu/impl/scan/IVFInterleaved1.cu index 01c94a64ef..c898ec2d6d 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved1.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved1.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 1, 1) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_1_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved1024.cu b/faiss/gpu/impl/scan/IVFInterleaved1024.cu index dc62d9f52a..d067a8b228 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved1024.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved1024.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 1024, 8) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_1024_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved128.cu b/faiss/gpu/impl/scan/IVFInterleaved128.cu index 18e0c514f7..1814df4074 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved128.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved128.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 128, 3) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_128_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved2048.cu b/faiss/gpu/impl/scan/IVFInterleaved2048.cu index e7b8398124..1ffb6fc9aa 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved2048.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved2048.cu @@ -11,7 +11,7 @@ namespace faiss { namespace gpu { #if GPU_MAX_SELECTION_K >= 2048 -IVF_INTERLEAVED_IMPL(64, 2048, 8) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_2048_PARAMS) #endif } // namespace gpu diff --git a/faiss/gpu/impl/scan/IVFInterleaved256.cu b/faiss/gpu/impl/scan/IVFInterleaved256.cu index ecfe8693d2..c7817460f4 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved256.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved256.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 256, 4) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_256_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved32.cu b/faiss/gpu/impl/scan/IVFInterleaved32.cu index 3b19a89002..401a2f5ab2 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved32.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved32.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 32, 2) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_32_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved512.cu b/faiss/gpu/impl/scan/IVFInterleaved512.cu index 27bb8c8347..ac3c0d3e22 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved512.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved512.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 512, 8) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_512_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved64.cu b/faiss/gpu/impl/scan/IVFInterleaved64.cu index 0788f84428..56a02b5054 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved64.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved64.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 64, 3) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_64_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh b/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh index 6a553e59a9..7511374a69 100644 --- a/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh +++ b/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh @@ -9,81 +9,357 @@ #include #include +#include #include -#define IVF_INTERLEAVED_IMPL(THREADS, WARP_Q, THREAD_Q) \ - \ - void ivfInterleavedScanImpl_##WARP_Q##_( \ - Tensor& queries, \ - Tensor& listIds, \ - DeviceVector& listData, \ - DeviceVector& listIndices, \ - IndicesOptions indicesOptions, \ - DeviceVector& listLengths, \ - int k, \ - faiss::MetricType metric, \ - bool useResidual, \ - Tensor& residualBase, \ - GpuScalarQuantizer* scalarQ, \ - Tensor& outDistances, \ - Tensor& outIndices, \ - GpuResources* res) { \ - FAISS_ASSERT(k <= WARP_Q); \ - \ - IVFINT_METRICS(THREADS, WARP_Q, THREAD_Q); \ - \ - CUDA_TEST_ERROR(); \ +namespace faiss { +namespace gpu { + +template < + typename CODEC_TYPE, + typename METRIC_TYPE, + int THREADS, + int NUM_WARP_Q, + int NUM_THREAD_Q> +void IVFINT_RUN( + CODEC_TYPE& codec, + Tensor& queries, + Tensor& listIds, + DeviceVector& listData, + DeviceVector& listIndices, + IndicesOptions indicesOptions, + DeviceVector& listLengths, + const int k, + METRIC_TYPE metric, + const bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + Tensor& outDistances, + Tensor& outIndices, + GpuResources* res) { + const auto nq = queries.getSize(0); + const auto dim = queries.getSize(1); + const auto nprobe = listIds.getSize(1); + + const auto stream = res->getDefaultStreamCurrentDevice(); + + DeviceTensor distanceTemp( + res, + makeTempAlloc(AllocType::Other, stream), + {queries.getSize(0), listIds.getSize(1), k}); + DeviceTensor indicesTemp( + res, + makeTempAlloc(AllocType::Other, stream), + {queries.getSize(0), listIds.getSize(1), k}); + + const dim3 grid(nprobe, std::min(nq, (idx_t)getMaxGridCurrentDevice().y)); + + if (useResidual) { + ivfInterleavedScan< + CODEC_TYPE, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q, + true><<>>( + queries, + residualBase, + listIds, + listData.data(), + listLengths.data(), + codec, + metric, + k, + distanceTemp, + indicesTemp); + } else { + ivfInterleavedScan< + CODEC_TYPE, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q, + false><<>>( + queries, + residualBase, + listIds, + listData.data(), + listLengths.data(), + codec, + metric, + k, + distanceTemp, + indicesTemp); } -#define IVF_INTERLEAVED_DECL(WARP_Q) \ - \ - void ivfInterleavedScanImpl_##WARP_Q##_( \ - Tensor& queries, \ - Tensor& listIds, \ - DeviceVector& listData, \ - DeviceVector& listIndices, \ - IndicesOptions indicesOptions, \ - DeviceVector& listLengths, \ - int k, \ - faiss::MetricType metric, \ - bool useResidual, \ - Tensor& residualBase, \ - GpuScalarQuantizer* scalarQ, \ - Tensor& outDistances, \ - Tensor& outIndices, \ - GpuResources* res) - -#define IVF_INTERLEAVED_CALL(WARP_Q) \ - ivfInterleavedScanImpl_##WARP_Q##_( \ - queries, \ - listIds, \ - listData, \ - listIndices, \ - indicesOptions, \ - listLengths, \ - k, \ - metric, \ - useResidual, \ - residualBase, \ - scalarQ, \ - outDistances, \ - outIndices, \ - res) + runIVFInterleavedScan2( + distanceTemp, + indicesTemp, + listIds, + k, + listIndices, + indicesOptions, + METRIC_TYPE::kDirection, + outDistances, + outIndices, + stream); +} -namespace faiss { -namespace gpu { +template +void IVFINT_CODECS( + Tensor& queries, + Tensor& listIds, + DeviceVector& listData, + DeviceVector& listIndices, + IndicesOptions indicesOptions, + DeviceVector& listLengths, + const int k, + METRIC_TYPE metric, + const bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + Tensor& outDistances, + Tensor& outIndices, + GpuResources* res) { + const auto dim = queries.getSize(1); + + const auto call_ivfint_run = [&](const auto& func, auto& codec) { + func(codec, + queries, + listIds, + listData, + listIndices, + indicesOptions, + listLengths, + k, + metric, + useResidual, + residualBase, + scalarQ, + outDistances, + outIndices, + res); + }; + + if (!scalarQ) { + using CodecT = CodecFloat; + CodecT codec(dim * sizeof(float)); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } else { + switch (scalarQ->qtype) { + case ScalarQuantizer::QuantizerType::QT_8bit: { + using CodecT = + Codec; + CodecT codec( + scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_8bit_uniform: { + using CodecT = + Codec; + CodecT codec( + scalarQ->code_size, + scalarQ->trained[0], + scalarQ->trained[1]); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_fp16: { + using CodecT = + Codec; + CodecT codec(scalarQ->code_size); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_8bit_direct: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_6bit: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_4bit: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_4bit_uniform: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size, + scalarQ->trained[0], + scalarQ->trained[1]); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + default: + FAISS_ASSERT(false); + } + } +} + +#define IVF_INTERLEAVED_SCAN_IMPL_ARGS \ + (Tensor & queries, \ + Tensor & listIds, \ + DeviceVector & listData, \ + DeviceVector & listIndices, \ + IndicesOptions indicesOptions, \ + DeviceVector & listLengths, \ + const int k, \ + faiss::MetricType metric_name, \ + const bool useResidual, \ + Tensor& residualBase, \ + GpuScalarQuantizer* scalarQ, \ + Tensor& outDistances, \ + Tensor& outIndices, \ + GpuResources* res) + +template +void IVF_METRICS IVF_INTERLEAVED_SCAN_IMPL_ARGS { + FAISS_ASSERT(k <= NUM_WARP_Q); + + const auto call_codec = [&](const auto& func, const auto& metric) { + func(queries, + listIds, + listData, + listIndices, + indicesOptions, + listLengths, + k, + metric, + useResidual, + residualBase, + scalarQ, + outDistances, + outIndices, + res); + }; + + if (metric_name == MetricType::METRIC_L2) { + L2Distance metric; + call_codec( + IVFINT_CODECS, + metric); + } else if (metric_name == MetricType::METRIC_INNER_PRODUCT) { + IPDistance metric; + call_codec( + IVFINT_CODECS, + metric); + } else { + FAISS_ASSERT(false); + } + + CUDA_TEST_ERROR(); +} + +template +void ivfInterleavedScanImpl IVF_INTERLEAVED_SCAN_IMPL_ARGS; + +#define IVF_INTERLEAVED_IMPL_HELPER(THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ + template <> \ + void ivfInterleavedScanImpl \ + IVF_INTERLEAVED_SCAN_IMPL_ARGS { \ + IVF_METRICS( \ + queries, \ + listIds, \ + listData, \ + listIndices, \ + indicesOptions, \ + listLengths, \ + k, \ + metric_name, \ + useResidual, \ + residualBase, \ + scalarQ, \ + outDistances, \ + outIndices, \ + res); \ + } + +#define IVF_INTERLEAVED_IMPL(...) IVF_INTERLEAVED_IMPL_HELPER(__VA_ARGS__) -IVF_INTERLEAVED_DECL(1); -IVF_INTERLEAVED_DECL(32); -IVF_INTERLEAVED_DECL(64); -IVF_INTERLEAVED_DECL(128); -IVF_INTERLEAVED_DECL(256); -IVF_INTERLEAVED_DECL(512); -IVF_INTERLEAVED_DECL(1024); - -#if GPU_MAX_SELECTION_K >= 2048 -IVF_INTERLEAVED_DECL(2048); -#endif +// clang-format off +#define IVFINTERLEAVED_1_PARAMS 128,1,1 +#define IVFINTERLEAVED_32_PARAMS 128,32,2 +#define IVFINTERLEAVED_64_PARAMS 128,64,3 +#define IVFINTERLEAVED_128_PARAMS 128,128,3 +#define IVFINTERLEAVED_256_PARAMS 128,256,4 +#define IVFINTERLEAVED_512_PARAMS 128,512,8 +#define IVFINTERLEAVED_1024_PARAMS 128,1024,8 +#define IVFINTERLEAVED_2048_PARAMS 64,2048,8 +// clang-format on } // namespace gpu } // namespace faiss