From 4b15abbb4347a6e718dfb8c2e07a0381889723b3 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 20 Jun 2023 09:25:43 -0700 Subject: [PATCH] Switch //faiss/gpu to use templates instead of macros (#2914) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2914 The macros are part of a system to reduce compilation time via separate compilation units. Unfortunately, the parallelization is across C++ template functions instead of NVCC invocations on kernel compilation, which would be much more effective. This diff removes the preprocessor macros and expands them into templates. Compilation time after this diff is given by [this buck2 output](https://www.internalfb.com/buck2/ae9e6b28-a1bd-4d46-8af8-2895e6f182c8) with 1,043s through impl/scan/IVFInterleaved2048.cu Differential Revision: D46549341 fbshipit-source-id: 2175702ff58dd0b8fca32e44d94a249e1de2cdf1 --- faiss/gpu/impl/IVFInterleaved.cu | 35 +- faiss/gpu/impl/IVFInterleaved.cuh | 215 +---------- faiss/gpu/impl/scan/IVFInterleaved1.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved1024.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved128.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved2048.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved256.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved32.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved512.cu | 2 +- faiss/gpu/impl/scan/IVFInterleaved64.cu | 2 +- faiss/gpu/impl/scan/IVFInterleavedImpl.cuh | 414 +++++++++++++++++---- 11 files changed, 392 insertions(+), 288 deletions(-) diff --git a/faiss/gpu/impl/IVFInterleaved.cu b/faiss/gpu/impl/IVFInterleaved.cu index 3aca959e78..c9ee87ee42 100644 --- a/faiss/gpu/impl/IVFInterleaved.cu +++ b/faiss/gpu/impl/IVFInterleaved.cu @@ -192,24 +192,43 @@ void runIVFInterleavedScan( // caught for exceptions at a higher level FAISS_ASSERT(k <= GPU_MAX_SELECTION_K); + const auto ivf_interleaved_call = [&](const auto func) { + func(queries, + listIds, + listData, + listIndices, + indicesOptions, + listLengths, + k, + metric, + useResidual, + residualBase, + scalarQ, + outDistances, + outIndices, + res); + }; + if (k == 1) { - IVF_INTERLEAVED_CALL(1); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 32) { - IVF_INTERLEAVED_CALL(32); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 64) { - IVF_INTERLEAVED_CALL(64); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 128) { - IVF_INTERLEAVED_CALL(128); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 256) { - IVF_INTERLEAVED_CALL(256); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 512) { - IVF_INTERLEAVED_CALL(512); + ivf_interleaved_call(ivfInterleavedScanImpl); } else if (k <= 1024) { - IVF_INTERLEAVED_CALL(1024); + ivf_interleaved_call( + ivfInterleavedScanImpl); } #if GPU_MAX_SELECTION_K >= 2048 else if (k <= 2048) { - IVF_INTERLEAVED_CALL(2048); + ivf_interleaved_call( + ivfInterleavedScanImpl); } #endif } diff --git a/faiss/gpu/impl/IVFInterleaved.cuh b/faiss/gpu/impl/IVFInterleaved.cuh index 79d8feb32b..053f99db09 100644 --- a/faiss/gpu/impl/IVFInterleaved.cuh +++ b/faiss/gpu/impl/IVFInterleaved.cuh @@ -122,9 +122,10 @@ __global__ void ivfInterleavedScan( // whole blocks for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) { - int loadDim = dBase + laneId; - float queryReg = query[loadDim]; - float residualReg = Residual ? residualBaseSlice[loadDim] : 0; + const int loadDim = dBase + laneId; + const float queryReg = query[loadDim]; + [[maybe_unused]] const float residualReg = + Residual ? residualBaseSlice[loadDim] : 0; constexpr int kUnroll = 4; @@ -151,7 +152,7 @@ __global__ void ivfInterleavedScan( decV[j] = codec.decodeNew(dBase + d, encV[j]); } - if (Residual) { + if constexpr (Residual) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { int d = i * kUnroll + j; @@ -169,13 +170,13 @@ __global__ void ivfInterleavedScan( } // remainder - int loadDim = dimBlocks + laneId; - bool loadDimInBounds = loadDim < dim; + const int loadDim = dimBlocks + laneId; + const bool loadDimInBounds = loadDim < dim; - float queryReg = loadDimInBounds ? query[loadDim] : 0; - float residualReg = Residual && loadDimInBounds - ? residualBaseSlice[loadDim] - : 0; + const float queryReg = loadDimInBounds ? query[loadDim] : 0; + [[maybe_unused]] const float residualReg = + Residual && loadDimInBounds ? residualBaseSlice[loadDim] + : 0; for (int d = 0; d < dim - dimBlocks; ++d, data += wordsPerVectorBlockDim) { @@ -186,7 +187,7 @@ __global__ void ivfInterleavedScan( enc = WarpPackedBits::postRead( laneId, enc); float dec = codec.decodeNew(dimBlocks + d, enc); - if (Residual) { + if constexpr (Residual) { dec += SHFL_SYNC(residualReg, d, kWarpSize); } @@ -217,198 +218,6 @@ __global__ void ivfInterleavedScan( // compile time using these macros to define the function body // -#define IVFINT_RUN(CODEC_TYPE, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ - do { \ - dim3 grid(nprobe, std::min(nq, (idx_t)getMaxGridCurrentDevice().y)); \ - if (useResidual) { \ - ivfInterleavedScan< \ - CODEC_TYPE, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q, \ - true><<>>( \ - queries, \ - residualBase, \ - listIds, \ - listData.data(), \ - listLengths.data(), \ - codec, \ - metric, \ - k, \ - distanceTemp, \ - indicesTemp); \ - } else { \ - ivfInterleavedScan< \ - CODEC_TYPE, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q, \ - false><<>>( \ - queries, \ - residualBase, \ - listIds, \ - listData.data(), \ - listLengths.data(), \ - codec, \ - metric, \ - k, \ - distanceTemp, \ - indicesTemp); \ - } \ - \ - runIVFInterleavedScan2( \ - distanceTemp, \ - indicesTemp, \ - listIds, \ - k, \ - listIndices, \ - indicesOptions, \ - METRIC_TYPE::kDirection, \ - outDistances, \ - outIndices, \ - stream); \ - } while (0); - -#define IVFINT_CODECS(METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ - do { \ - if (!scalarQ) { \ - using CodecT = CodecFloat; \ - CodecT codec(dim * sizeof(float)); \ - IVFINT_RUN( \ - CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \ - } else { \ - switch (scalarQ->qtype) { \ - case ScalarQuantizer::QuantizerType::QT_8bit: { \ - using CodecT = \ - Codec; \ - CodecT codec( \ - scalarQ->code_size, \ - scalarQ->gpuTrained.data(), \ - scalarQ->gpuTrained.data() + dim); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_8bit_uniform: { \ - using CodecT = Codec< \ - ScalarQuantizer::QuantizerType::QT_8bit_uniform, \ - 1>; \ - CodecT codec( \ - scalarQ->code_size, \ - scalarQ->trained[0], \ - scalarQ->trained[1]); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_fp16: { \ - using CodecT = \ - Codec; \ - CodecT codec(scalarQ->code_size); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_8bit_direct: { \ - using CodecT = Codec< \ - ScalarQuantizer::QuantizerType::QT_8bit_direct, \ - 1>; \ - Codec \ - codec(scalarQ->code_size); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_6bit: { \ - using CodecT = \ - Codec; \ - Codec codec( \ - scalarQ->code_size, \ - scalarQ->gpuTrained.data(), \ - scalarQ->gpuTrained.data() + dim); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_4bit: { \ - using CodecT = \ - Codec; \ - Codec codec( \ - scalarQ->code_size, \ - scalarQ->gpuTrained.data(), \ - scalarQ->gpuTrained.data() + dim); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - case ScalarQuantizer::QuantizerType::QT_4bit_uniform: { \ - using CodecT = Codec< \ - ScalarQuantizer::QuantizerType::QT_4bit_uniform, \ - 1>; \ - Codec \ - codec(scalarQ->code_size, \ - scalarQ->trained[0], \ - scalarQ->trained[1]); \ - IVFINT_RUN( \ - CodecT, \ - METRIC_TYPE, \ - THREADS, \ - NUM_WARP_Q, \ - NUM_THREAD_Q); \ - } break; \ - default: \ - FAISS_ASSERT(false); \ - } \ - } \ - } while (0) - -#define IVFINT_METRICS(THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ - do { \ - auto stream = res->getDefaultStreamCurrentDevice(); \ - auto nq = queries.getSize(0); \ - auto dim = queries.getSize(1); \ - auto nprobe = listIds.getSize(1); \ - \ - DeviceTensor distanceTemp( \ - res, \ - makeTempAlloc(AllocType::Other, stream), \ - {queries.getSize(0), listIds.getSize(1), k}); \ - DeviceTensor indicesTemp( \ - res, \ - makeTempAlloc(AllocType::Other, stream), \ - {queries.getSize(0), listIds.getSize(1), k}); \ - \ - if (metric == MetricType::METRIC_L2) { \ - L2Distance metric; \ - IVFINT_CODECS(L2Distance, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \ - } else if (metric == MetricType::METRIC_INNER_PRODUCT) { \ - IPDistance metric; \ - IVFINT_CODECS(IPDistance, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \ - } else { \ - FAISS_ASSERT(false); \ - } \ - } while (0) - // Top-level IVF scan function for the interleaved by 32 layout // with all implementations void runIVFInterleavedScan( diff --git a/faiss/gpu/impl/scan/IVFInterleaved1.cu b/faiss/gpu/impl/scan/IVFInterleaved1.cu index 01c94a64ef..c898ec2d6d 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved1.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved1.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 1, 1) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_1_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved1024.cu b/faiss/gpu/impl/scan/IVFInterleaved1024.cu index dc62d9f52a..d067a8b228 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved1024.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved1024.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 1024, 8) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_1024_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved128.cu b/faiss/gpu/impl/scan/IVFInterleaved128.cu index 18e0c514f7..1814df4074 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved128.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved128.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 128, 3) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_128_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved2048.cu b/faiss/gpu/impl/scan/IVFInterleaved2048.cu index e7b8398124..1ffb6fc9aa 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved2048.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved2048.cu @@ -11,7 +11,7 @@ namespace faiss { namespace gpu { #if GPU_MAX_SELECTION_K >= 2048 -IVF_INTERLEAVED_IMPL(64, 2048, 8) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_2048_PARAMS) #endif } // namespace gpu diff --git a/faiss/gpu/impl/scan/IVFInterleaved256.cu b/faiss/gpu/impl/scan/IVFInterleaved256.cu index ecfe8693d2..c7817460f4 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved256.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved256.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 256, 4) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_256_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved32.cu b/faiss/gpu/impl/scan/IVFInterleaved32.cu index 3b19a89002..401a2f5ab2 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved32.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved32.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 32, 2) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_32_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved512.cu b/faiss/gpu/impl/scan/IVFInterleaved512.cu index 27bb8c8347..ac3c0d3e22 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved512.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved512.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 512, 8) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_512_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleaved64.cu b/faiss/gpu/impl/scan/IVFInterleaved64.cu index 0788f84428..56a02b5054 100644 --- a/faiss/gpu/impl/scan/IVFInterleaved64.cu +++ b/faiss/gpu/impl/scan/IVFInterleaved64.cu @@ -10,7 +10,7 @@ namespace faiss { namespace gpu { -IVF_INTERLEAVED_IMPL(128, 64, 3) +IVF_INTERLEAVED_IMPL(IVFINTERLEAVED_64_PARAMS) } } // namespace faiss diff --git a/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh b/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh index 6a553e59a9..7511374a69 100644 --- a/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh +++ b/faiss/gpu/impl/scan/IVFInterleavedImpl.cuh @@ -9,81 +9,357 @@ #include #include +#include #include -#define IVF_INTERLEAVED_IMPL(THREADS, WARP_Q, THREAD_Q) \ - \ - void ivfInterleavedScanImpl_##WARP_Q##_( \ - Tensor& queries, \ - Tensor& listIds, \ - DeviceVector& listData, \ - DeviceVector& listIndices, \ - IndicesOptions indicesOptions, \ - DeviceVector& listLengths, \ - int k, \ - faiss::MetricType metric, \ - bool useResidual, \ - Tensor& residualBase, \ - GpuScalarQuantizer* scalarQ, \ - Tensor& outDistances, \ - Tensor& outIndices, \ - GpuResources* res) { \ - FAISS_ASSERT(k <= WARP_Q); \ - \ - IVFINT_METRICS(THREADS, WARP_Q, THREAD_Q); \ - \ - CUDA_TEST_ERROR(); \ +namespace faiss { +namespace gpu { + +template < + typename CODEC_TYPE, + typename METRIC_TYPE, + int THREADS, + int NUM_WARP_Q, + int NUM_THREAD_Q> +void IVFINT_RUN( + CODEC_TYPE& codec, + Tensor& queries, + Tensor& listIds, + DeviceVector& listData, + DeviceVector& listIndices, + IndicesOptions indicesOptions, + DeviceVector& listLengths, + const int k, + METRIC_TYPE metric, + const bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + Tensor& outDistances, + Tensor& outIndices, + GpuResources* res) { + const auto nq = queries.getSize(0); + const auto dim = queries.getSize(1); + const auto nprobe = listIds.getSize(1); + + const auto stream = res->getDefaultStreamCurrentDevice(); + + DeviceTensor distanceTemp( + res, + makeTempAlloc(AllocType::Other, stream), + {queries.getSize(0), listIds.getSize(1), k}); + DeviceTensor indicesTemp( + res, + makeTempAlloc(AllocType::Other, stream), + {queries.getSize(0), listIds.getSize(1), k}); + + const dim3 grid(nprobe, std::min(nq, (idx_t)getMaxGridCurrentDevice().y)); + + if (useResidual) { + ivfInterleavedScan< + CODEC_TYPE, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q, + true><<>>( + queries, + residualBase, + listIds, + listData.data(), + listLengths.data(), + codec, + metric, + k, + distanceTemp, + indicesTemp); + } else { + ivfInterleavedScan< + CODEC_TYPE, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q, + false><<>>( + queries, + residualBase, + listIds, + listData.data(), + listLengths.data(), + codec, + metric, + k, + distanceTemp, + indicesTemp); } -#define IVF_INTERLEAVED_DECL(WARP_Q) \ - \ - void ivfInterleavedScanImpl_##WARP_Q##_( \ - Tensor& queries, \ - Tensor& listIds, \ - DeviceVector& listData, \ - DeviceVector& listIndices, \ - IndicesOptions indicesOptions, \ - DeviceVector& listLengths, \ - int k, \ - faiss::MetricType metric, \ - bool useResidual, \ - Tensor& residualBase, \ - GpuScalarQuantizer* scalarQ, \ - Tensor& outDistances, \ - Tensor& outIndices, \ - GpuResources* res) - -#define IVF_INTERLEAVED_CALL(WARP_Q) \ - ivfInterleavedScanImpl_##WARP_Q##_( \ - queries, \ - listIds, \ - listData, \ - listIndices, \ - indicesOptions, \ - listLengths, \ - k, \ - metric, \ - useResidual, \ - residualBase, \ - scalarQ, \ - outDistances, \ - outIndices, \ - res) + runIVFInterleavedScan2( + distanceTemp, + indicesTemp, + listIds, + k, + listIndices, + indicesOptions, + METRIC_TYPE::kDirection, + outDistances, + outIndices, + stream); +} -namespace faiss { -namespace gpu { +template +void IVFINT_CODECS( + Tensor& queries, + Tensor& listIds, + DeviceVector& listData, + DeviceVector& listIndices, + IndicesOptions indicesOptions, + DeviceVector& listLengths, + const int k, + METRIC_TYPE metric, + const bool useResidual, + Tensor& residualBase, + GpuScalarQuantizer* scalarQ, + Tensor& outDistances, + Tensor& outIndices, + GpuResources* res) { + const auto dim = queries.getSize(1); + + const auto call_ivfint_run = [&](const auto& func, auto& codec) { + func(codec, + queries, + listIds, + listData, + listIndices, + indicesOptions, + listLengths, + k, + metric, + useResidual, + residualBase, + scalarQ, + outDistances, + outIndices, + res); + }; + + if (!scalarQ) { + using CodecT = CodecFloat; + CodecT codec(dim * sizeof(float)); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } else { + switch (scalarQ->qtype) { + case ScalarQuantizer::QuantizerType::QT_8bit: { + using CodecT = + Codec; + CodecT codec( + scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_8bit_uniform: { + using CodecT = + Codec; + CodecT codec( + scalarQ->code_size, + scalarQ->trained[0], + scalarQ->trained[1]); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_fp16: { + using CodecT = + Codec; + CodecT codec(scalarQ->code_size); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_8bit_direct: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_6bit: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_4bit: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size, + scalarQ->gpuTrained.data(), + scalarQ->gpuTrained.data() + dim); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + case ScalarQuantizer::QuantizerType::QT_4bit_uniform: { + using CodecT = + Codec; + Codec codec( + scalarQ->code_size, + scalarQ->trained[0], + scalarQ->trained[1]); + call_ivfint_run( + IVFINT_RUN< + CodecT, + METRIC_TYPE, + THREADS, + NUM_WARP_Q, + NUM_THREAD_Q>, + codec); + } break; + default: + FAISS_ASSERT(false); + } + } +} + +#define IVF_INTERLEAVED_SCAN_IMPL_ARGS \ + (Tensor & queries, \ + Tensor & listIds, \ + DeviceVector & listData, \ + DeviceVector & listIndices, \ + IndicesOptions indicesOptions, \ + DeviceVector & listLengths, \ + const int k, \ + faiss::MetricType metric_name, \ + const bool useResidual, \ + Tensor& residualBase, \ + GpuScalarQuantizer* scalarQ, \ + Tensor& outDistances, \ + Tensor& outIndices, \ + GpuResources* res) + +template +void IVF_METRICS IVF_INTERLEAVED_SCAN_IMPL_ARGS { + FAISS_ASSERT(k <= NUM_WARP_Q); + + const auto call_codec = [&](const auto& func, const auto& metric) { + func(queries, + listIds, + listData, + listIndices, + indicesOptions, + listLengths, + k, + metric, + useResidual, + residualBase, + scalarQ, + outDistances, + outIndices, + res); + }; + + if (metric_name == MetricType::METRIC_L2) { + L2Distance metric; + call_codec( + IVFINT_CODECS, + metric); + } else if (metric_name == MetricType::METRIC_INNER_PRODUCT) { + IPDistance metric; + call_codec( + IVFINT_CODECS, + metric); + } else { + FAISS_ASSERT(false); + } + + CUDA_TEST_ERROR(); +} + +template +void ivfInterleavedScanImpl IVF_INTERLEAVED_SCAN_IMPL_ARGS; + +#define IVF_INTERLEAVED_IMPL_HELPER(THREADS, NUM_WARP_Q, NUM_THREAD_Q) \ + template <> \ + void ivfInterleavedScanImpl \ + IVF_INTERLEAVED_SCAN_IMPL_ARGS { \ + IVF_METRICS( \ + queries, \ + listIds, \ + listData, \ + listIndices, \ + indicesOptions, \ + listLengths, \ + k, \ + metric_name, \ + useResidual, \ + residualBase, \ + scalarQ, \ + outDistances, \ + outIndices, \ + res); \ + } + +#define IVF_INTERLEAVED_IMPL(...) IVF_INTERLEAVED_IMPL_HELPER(__VA_ARGS__) -IVF_INTERLEAVED_DECL(1); -IVF_INTERLEAVED_DECL(32); -IVF_INTERLEAVED_DECL(64); -IVF_INTERLEAVED_DECL(128); -IVF_INTERLEAVED_DECL(256); -IVF_INTERLEAVED_DECL(512); -IVF_INTERLEAVED_DECL(1024); - -#if GPU_MAX_SELECTION_K >= 2048 -IVF_INTERLEAVED_DECL(2048); -#endif +// clang-format off +#define IVFINTERLEAVED_1_PARAMS 128,1,1 +#define IVFINTERLEAVED_32_PARAMS 128,32,2 +#define IVFINTERLEAVED_64_PARAMS 128,64,3 +#define IVFINTERLEAVED_128_PARAMS 128,128,3 +#define IVFINTERLEAVED_256_PARAMS 128,256,4 +#define IVFINTERLEAVED_512_PARAMS 128,512,8 +#define IVFINTERLEAVED_1024_PARAMS 128,1024,8 +#define IVFINTERLEAVED_2048_PARAMS 64,2048,8 +// clang-format on } // namespace gpu } // namespace faiss