diff --git a/conda/environments/all_cuda-131_arch-aarch64.yaml b/conda/environments/all_cuda-131_arch-aarch64.yaml index 4259c753e1..af6b71990e 100644 --- a/conda/environments/all_cuda-131_arch-aarch64.yaml +++ b/conda/environments/all_cuda-131_arch-aarch64.yaml @@ -31,6 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libnvjitlink-dev - librmm==26.4.*,>=0.0.0a0 - make - nccl>=2.19 diff --git a/conda/environments/all_cuda-131_arch-x86_64.yaml b/conda/environments/all_cuda-131_arch-x86_64.yaml index c4d70ea4aa..20fe9b82a2 100644 --- a/conda/environments/all_cuda-131_arch-x86_64.yaml +++ b/conda/environments/all_cuda-131_arch-x86_64.yaml @@ -31,6 +31,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libnvjitlink-dev - librmm==26.4.*,>=0.0.0a0 - make - nccl>=2.19 diff --git a/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml index d71ee647a8..b11035fcd6 100644 --- a/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml @@ -30,6 +30,7 @@ dependencies: - libcusolver-dev - libcusparse-dev - libcuvs==26.4.*,>=0.0.0a0 +- libnvjitlink-dev - librmm==26.4.*,>=0.0.0a0 - matplotlib-base>=3.9 - nccl>=2.19 diff --git a/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml index ab0a93c967..48d203af8b 100644 --- a/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml @@ -32,6 +32,7 @@ dependencies: - libcusolver-dev - libcusparse-dev - libcuvs==26.4.*,>=0.0.0a0 +- libnvjitlink-dev - librmm==26.4.*,>=0.0.0a0 - matplotlib-base>=3.9 - mkl-devel=2023 diff --git a/conda/environments/go_cuda-131_arch-aarch64.yaml b/conda/environments/go_cuda-131_arch-aarch64.yaml index cf0e9ed50d..135f6a88cc 100644 --- a/conda/environments/go_cuda-131_arch-aarch64.yaml +++ 
b/conda/environments/go_cuda-131_arch-aarch64.yaml @@ -25,6 +25,7 @@ dependencies: - libcusolver-dev - libcusparse-dev - libcuvs==26.4.*,>=0.0.0a0 +- libnvjitlink-dev - libraft==26.4.*,>=0.0.0a0 - nccl>=2.19 - ninja diff --git a/conda/environments/go_cuda-131_arch-x86_64.yaml b/conda/environments/go_cuda-131_arch-x86_64.yaml index 537039da30..df6a779331 100644 --- a/conda/environments/go_cuda-131_arch-x86_64.yaml +++ b/conda/environments/go_cuda-131_arch-x86_64.yaml @@ -25,6 +25,7 @@ dependencies: - libcusolver-dev - libcusparse-dev - libcuvs==26.4.*,>=0.0.0a0 +- libnvjitlink-dev - libraft==26.4.*,>=0.0.0a0 - nccl>=2.19 - ninja diff --git a/conda/environments/rust_cuda-131_arch-aarch64.yaml b/conda/environments/rust_cuda-131_arch-aarch64.yaml index ec28e79286..062cbc8ea0 100644 --- a/conda/environments/rust_cuda-131_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-131_arch-aarch64.yaml @@ -22,6 +22,7 @@ dependencies: - libcusolver-dev - libcusparse-dev - libcuvs==26.4.*,>=0.0.0a0 +- libnvjitlink-dev - libraft==26.4.*,>=0.0.0a0 - make - nccl>=2.19 diff --git a/conda/environments/rust_cuda-131_arch-x86_64.yaml b/conda/environments/rust_cuda-131_arch-x86_64.yaml index ee36391fb6..2b96d4a64e 100644 --- a/conda/environments/rust_cuda-131_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-131_arch-x86_64.yaml @@ -22,6 +22,7 @@ dependencies: - libcusolver-dev - libcusparse-dev - libcuvs==26.4.*,>=0.0.0a0 +- libnvjitlink-dev - libraft==26.4.*,>=0.0.0a0 - make - nccl>=2.19 diff --git a/conda/recipes/libcuvs/recipe.yaml b/conda/recipes/libcuvs/recipe.yaml index c0c74209af..abd3031a94 100644 --- a/conda/recipes/libcuvs/recipe.yaml +++ b/conda/recipes/libcuvs/recipe.yaml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 schema_version: 1 @@ -72,6 +72,9 @@ cache: - ninja - ${{ stdlib("c") }} host: + - if: cuda_major == "13" + then: + - libnvjitlink-dev - librmm =${{ minor_version }} - libraft-headers =${{ minor_version }} - nccl ${{ nccl_version }} @@ -118,6 +121,9 @@ outputs: - libcurand-dev - libcusolver-dev - libcusparse-dev + - if: cuda_major == "13" + then: + - libnvjitlink-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - libraft-headers =${{ minor_version }} @@ -128,6 +134,9 @@ outputs: - libcurand - libcusolver - libcusparse + - if: cuda_major == "13" + then: + - libnvjitlink ignore_run_exports: by_name: - cuda-cudart @@ -141,6 +150,9 @@ outputs: - librmm - mkl - nccl + - if: cuda_major == "13" + then: + - libnvjitlink about: homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }} @@ -177,6 +189,9 @@ outputs: - libcurand-dev - libcusolver-dev - libcusparse-dev + - if: cuda_major == "13" + then: + - libnvjitlink-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libcuvs-headers", exact=True) }} @@ -188,6 +203,9 @@ outputs: - libcurand - libcusolver - libcusparse + - if: cuda_major == "13" + then: + - libnvjitlink ignore_run_exports: by_name: - cuda-cudart @@ -201,6 +219,9 @@ outputs: - librmm - mkl - nccl + - if: cuda_major == "13" + then: + - libnvjitlink about: homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }} @@ -235,6 +256,9 @@ outputs: - libcurand-dev - libcusolver-dev - libcusparse-dev + - if: cuda_major == "13" + then: + - libnvjitlink-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libcuvs-headers", exact=True) }} @@ -246,6 +270,9 @@ outputs: - libcurand - 
libcusolver - libcusparse + - if: cuda_major == "13" + then: + - libnvjitlink ignore_run_exports: by_name: - cuda-cudart @@ -256,6 +283,9 @@ outputs: - libcurand - libcusolver - libcusparse + - if: cuda_major == "13" + then: + - libnvjitlink - librmm - mkl - nccl @@ -393,6 +423,9 @@ outputs: - libcurand-dev - libcusolver-dev - libcusparse-dev + - if: cuda_major == "13" + then: + - libnvjitlink-dev run: - ${{ pin_subpackage("libcuvs-headers", exact=True) }} - ${{ pin_subpackage("libcuvs", exact=True) }} @@ -403,6 +436,9 @@ outputs: - libcurand - libcusolver - libcusparse + - if: cuda_major == "13" + then: + - libnvjitlink ignore_run_exports: by_name: - cuda-cudart @@ -413,6 +449,9 @@ outputs: - libcurand - libcusolver - libcusparse + - if: cuda_major == "13" + then: + - libnvjitlink - librmm - mkl - nccl diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6313db71ca..a75890737e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -343,6 +343,29 @@ if(NOT BUILD_CPU_ONLY) ) endif() + set(JIT_LTO_TARGET_ARCHITECTURE "") + set(JIT_LTO_COMPILATION OFF) + set(JIT_LTO_FILES "") + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + set(JIT_LTO_TARGET_ARCHITECTURE "75-real") + set(JIT_LTO_COMPILATION ON) + endif() + + if(JIT_LTO_COMPILATION) + # Generate interleaved scan kernel files at build time + include(cmake/modules/generate_jit_lto_kernels.cmake) + generate_jit_lto_kernels(cuvs_jit_lto_kernels) + add_library(cuvs::cuvs_jit_lto_kernels ALIAS cuvs_jit_lto_kernels) + + set(JIT_LTO_FILES + src/detail/jit_lto/AlgorithmLauncher.cu + src/detail/jit_lto/AlgorithmPlanner.cu + src/detail/jit_lto/FragmentDatabase.cu + src/detail/jit_lto/FragmentEntry.cu + src/detail/jit_lto/nvjitlink_checker.cpp + ) + endif() + add_library( cuvs_objs OBJECT src/cluster/detail/minClusterDistanceCompute.cu @@ -556,6 +579,7 @@ if(NOT BUILD_CPU_ONLY) src/stats/silhouette_score.cu src/stats/trustworthiness_score.cu ${CUVS_MG_ALGOS} + $<$:${JIT_LTO_FILES}> ) 
set_target_properties( @@ -572,8 +596,10 @@ if(NOT BUILD_CPU_ONLY) ) target_compile_definitions( - cuvs_objs PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> - $<$:NVTX_ENABLED> + cuvs_objs + PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> + $<$:CUVS_ENABLE_JIT_LTO> ) target_link_libraries( @@ -586,6 +612,13 @@ if(NOT BUILD_CPU_ONLY) $ ) + target_include_directories( + cuvs_objs + PUBLIC "$" + "$" + INTERFACE "$" + ) + # Endian detection include(TestBigEndian) test_big_endian(BIG_ENDIAN) @@ -640,8 +673,10 @@ if(NOT BUILD_CPU_ONLY) "$<$:${CUVS_CUDA_FLAGS}>" ) target_compile_definitions( - cuvs PUBLIC $<$:CUVS_BUILD_CAGRA_HNSWLIB> - $<$:NVTX_ENABLED> + cuvs + PUBLIC $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> + $<$:CUVS_ENABLE_JIT_LTO> ) target_link_libraries( @@ -653,8 +688,12 @@ if(NOT BUILD_CPU_ONLY) $> $> $<$:CUDA::nvtx3> - PRIVATE $ $ - $ + PRIVATE + $ + $ + $ + $<$:CUDA::nvJitLink> + $<$:$> ) # ensure CUDA symbols aren't relocated to the middle of the debug build binaries @@ -692,8 +731,10 @@ SECTIONS target_compile_options(cuvs_static PRIVATE "$<$:${CUVS_CXX_FLAGS}>") target_compile_definitions( - cuvs_static PUBLIC $<$:CUVS_BUILD_CAGRA_HNSWLIB> - $<$:NVTX_ENABLED> + cuvs_static + PUBLIC $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> + $<$:CUVS_ENABLE_JIT_LTO> ) target_include_directories(cuvs_static INTERFACE "$") @@ -709,8 +750,13 @@ SECTIONS ${CUVS_CTK_MATH_DEPENDENCIES} $ # needs to be public for DT_NEEDED $> # header only - PRIVATE $ $<$:CUDA::nvtx3> - $ $ + PRIVATE + $ + $<$:CUDA::nvJitLink> + $<$:CUDA::nvtx3> + $ + $ + $<$:$> ) endif() @@ -751,9 +797,11 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENAB include(GNUInstallDirs) include(CPack) - set(target_names cuvs cuvs_static cuvs_cpp_headers cuvs_c) - set(component_names cuvs_shared cuvs_static cuvs_cpp_headers c_api) - set(export_names cuvs-shared-exports cuvs-static-exports cuvs-cpp-headers-exports cuvs-c-exports) + set(target_names cuvs cuvs_static cuvs_jit_lto_kernels 
cuvs_cpp_headers cuvs_c) + set(component_names cuvs_shared cuvs_static cuvs_static cuvs_cpp_headers c_api) + set(export_names cuvs-shared-exports cuvs-static-exports cuvs-static-exports + cuvs-cpp-headers-exports cuvs-c-exports + ) foreach(target component export IN ZIP_LISTS target_names component_names export_names) if(TARGET ${target}) install( @@ -794,6 +842,8 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$:NVTX_ENAB ) endif() + list(REMOVE_DUPLICATES cuvs_components) + list(REMOVE_DUPLICATES cuvs_export_sets) include(cmake/modules/generate_cuvs_export.cmake) generate_cuvs_export(COMPONENTS ${cuvs_components} EXPORT_SETS ${cuvs_export_sets}) diff --git a/cpp/cmake/config.json b/cpp/cmake/config.json index a9f1b53007..aa46006a44 100644 --- a/cpp/cmake/config.json +++ b/cpp/cmake/config.json @@ -10,6 +10,15 @@ "ADDITIONAL_DEP": "?", "PATH": "*" } + }, + "embed_jit_lto_fatbin": { + "kwargs": { + "FATBIN_TARGET": 1, + "FATBIN_SOURCE": 1, + "EMBEDDED_TARGET": 1, + "EMBEDDED_HEADER": 1, + "EMBEDDED_ARRAY": 1 + } } } }, diff --git a/cpp/cmake/modules/generate_jit_lto_kernels.cmake b/cpp/cmake/modules/generate_jit_lto_kernels.cmake new file mode 100644 index 0000000000..671b08321c --- /dev/null +++ b/cpp/cmake/modules/generate_jit_lto_kernels.cmake @@ -0,0 +1,220 @@ +# ============================================================================= +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +# ============================================================================= + +include_guard(GLOBAL) + +function(embed_jit_lto_fatbin) + set(options) + set(one_value FATBIN_TARGET FATBIN_SOURCE EMBEDDED_TARGET EMBEDDED_HEADER EMBEDDED_ARRAY) + set(multi_value) + + cmake_parse_arguments(_JIT_LTO "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + find_package(CUDAToolkit REQUIRED) + find_program( + bin_to_c + NAMES bin2c + PATHS ${CUDAToolkit_BIN_DIR} + ) + + add_library(${_JIT_LTO_FATBIN_TARGET} OBJECT "${_JIT_LTO_FATBIN_SOURCE}") + target_compile_definitions(${_JIT_LTO_FATBIN_TARGET} PRIVATE BUILD_KERNEL) + target_include_directories( + ${_JIT_LTO_FATBIN_TARGET} + PRIVATE "$" + "$" + "$" + ) + target_compile_options( + ${_JIT_LTO_FATBIN_TARGET} + PRIVATE -Xfatbin=--compress-all + --compress-mode=size + "$<$:${CUVS_CXX_FLAGS}>" + "$<$:${CUVS_CUDA_FLAGS}>" + ) + set_target_properties( + ${_JIT_LTO_FATBIN_TARGET} + PROPERTIES CUDA_ARCHITECTURES ${JIT_LTO_TARGET_ARCHITECTURE} + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + CUDA_SEPARABLE_COMPILATION ON + CUDA_FATBIN_COMPILATION ON + POSITION_INDEPENDENT_CODE ON + INTERPROCEDURAL_OPTIMIZATION ON + ) + target_link_libraries(${_JIT_LTO_FATBIN_TARGET} PRIVATE rmm::rmm raft::raft CCCL::CCCL) + + add_custom_command( + OUTPUT "${_JIT_LTO_EMBEDDED_HEADER}" + COMMAND "${bin_to_c}" -c -p 0x0 --name "${_JIT_LTO_EMBEDDED_ARRAY}" --static + $ > "${_JIT_LTO_EMBEDDED_HEADER}" + DEPENDS $ + ) + target_sources( + ${_JIT_LTO_EMBEDDED_TARGET} PRIVATE "${_JIT_LTO_FATBIN_SOURCE}" "${_JIT_LTO_EMBEDDED_HEADER}" + ) + cmake_path(GET _JIT_LTO_EMBEDDED_HEADER PARENT_PATH header_dir) + target_include_directories(${_JIT_LTO_EMBEDDED_TARGET} PRIVATE "${header_dir}") +endfunction() + +function(parse_jit_lto_data_type_configs config) + set(options) + set(one_value DATA_TYPE ACC_TYPE VECLENS TYPE_ABBREV ACC_ABBREV) + set(multi_value) + + cmake_parse_arguments(_JIT_LTO 
"${options}" "${one_value}" "${multi_value}" ${ARGN}) + + if(config MATCHES [==[^([^,]+),([^,]+),\[([0-9]+(,[0-9]+)*)?\],([^,]+),([^,]+)$]==]) + if(_JIT_LTO_DATA_TYPE) + set(${_JIT_LTO_DATA_TYPE} + "${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + endif() + if(_JIT_LTO_ACC_TYPE) + set(${_JIT_LTO_ACC_TYPE} + "${CMAKE_MATCH_2}" + PARENT_SCOPE + ) + endif() + if(_JIT_LTO_VECLENS) + string(REPLACE "," ";" veclens_value "${CMAKE_MATCH_3}") + set(${_JIT_LTO_VECLENS} + "${veclens_value}" + PARENT_SCOPE + ) + endif() + if(_JIT_LTO_TYPE_ABBREV) + set(${_JIT_LTO_TYPE_ABBREV} + "${CMAKE_MATCH_5}" + PARENT_SCOPE + ) + endif() + if(_JIT_LTO_ACC_ABBREV) + set(${_JIT_LTO_ACC_ABBREV} + "${CMAKE_MATCH_6}" + PARENT_SCOPE + ) + endif() + else() + message(FATAL_ERROR "Invalid data type config: ${config}") + endif() +endfunction() + +# cmake-lint: disable=R0915 +function(generate_jit_lto_kernels target) + add_library(${target} STATIC) + target_include_directories( + ${target} + PRIVATE "$" + "$" + "$" + ) + set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON) + + set(generated_kernels_dir "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels") + string(TIMESTAMP year "%Y") + + set(capacities 0 1 2 4 8 16 32 64 128 256) + set(ascending_values true false) + set(compute_norm_values true false) + set(data_type_configs "float,float,[1,4],f,f" "__half,__half,[1,8],h,h" + "uint8_t,uint32_t,[1,16],uc,ui" "int8_t,int32_t,[1,16],sc,i" + ) + set(idx_type int64_t) + set(idx_abbrev l) + set(metric_configs euclidean inner_prod) + set(filter_configs filter_none filter_bitset) + set(post_lambda_configs post_identity post_sqrt post_compose) + + foreach(config IN LISTS data_type_configs) + parse_jit_lto_data_type_configs( + "${config}" DATA_TYPE data_type ACC_TYPE acc_type VECLENS veclens TYPE_ABBREV type_abbrev + ACC_ABBREV acc_abbrev + ) + foreach(veclen IN LISTS veclens) + foreach(capacity IN LISTS capacities) + foreach(ascending IN LISTS ascending_values) + foreach(compute_norm IN LISTS 
compute_norm_values) + set(kernel_name + "interleaved_scan_kernel_${capacity}_${veclen}_${ascending}_${compute_norm}_${type_abbrev}_${acc_abbrev}_${idx_abbrev}" + ) + set(filename + "${generated_kernels_dir}/interleaved_scan_kernels/fatbin_${kernel_name}.cu" + ) + configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in" + "${filename}" + @ONLY + ) + embed_jit_lto_fatbin( + FATBIN_TARGET "fatbin_${kernel_name}" + FATBIN_SOURCE "${filename}" + EMBEDDED_TARGET "${target}" + EMBEDDED_HEADER "${generated_kernels_dir}/interleaved_scan_kernels/${kernel_name}.h" + EMBEDDED_ARRAY "embedded_${kernel_name}" + ) + endforeach() + endforeach() + endforeach() + + foreach(metric_name IN LISTS metric_configs) + set(header_file "neighbors/ivf_flat/jit_lto_kernels/metric_${metric_name}.cuh") + + set(kernel_name "metric_${metric_name}_${veclen}_${type_abbrev}_${acc_abbrev}") + set(filename "${generated_kernels_dir}/metric_device_functions/fatbin_${kernel_name}.cu") + configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric.cu.in" + "${filename}" @ONLY + ) + embed_jit_lto_fatbin( + FATBIN_TARGET "fatbin_${kernel_name}" + FATBIN_SOURCE "${filename}" + EMBEDDED_TARGET "${target}" + EMBEDDED_HEADER "${generated_kernels_dir}/metric_device_functions/${kernel_name}.h" + EMBEDDED_ARRAY "embedded_${kernel_name}" + ) + endforeach() + endforeach() + endforeach() + + foreach(filter_name IN LISTS filter_configs) + set(header_file "neighbors/ivf_flat/jit_lto_kernels/${filter_name}.cuh") + + set(kernel_name "${filter_name}") + set(filename "${generated_kernels_dir}/filter_device_functions/fatbin_${kernel_name}.cu") + configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter.cu.in" + "${filename}" @ONLY + ) + embed_jit_lto_fatbin( + FATBIN_TARGET "fatbin_${kernel_name}" + FATBIN_SOURCE "${filename}" + EMBEDDED_TARGET "${target}" + EMBEDDED_HEADER 
"${generated_kernels_dir}/filter_device_functions/${kernel_name}.h" + EMBEDDED_ARRAY "embedded_${kernel_name}" + ) + endforeach() + + foreach(post_lambda_name IN LISTS post_lambda_configs) + set(header_file "neighbors/ivf_flat/jit_lto_kernels/${post_lambda_name}.cuh") + + set(kernel_name "${post_lambda_name}") + set(filename "${generated_kernels_dir}/post_lambda_device_functions/${post_lambda_name}.cu") + configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda.cu.in" + "${filename}" @ONLY + ) + embed_jit_lto_fatbin( + FATBIN_TARGET "fatbin_${kernel_name}" + FATBIN_SOURCE "${filename}" + EMBEDDED_TARGET "${target}" + EMBEDDED_HEADER "${generated_kernels_dir}/post_lambda_device_functions/${kernel_name}.h" + EMBEDDED_ARRAY "embedded_${kernel_name}" + ) + endforeach() +endfunction() diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp new file mode 100644 index 0000000000..7a578a8306 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp @@ -0,0 +1,36 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +struct AlgorithmLauncher { + AlgorithmLauncher() = default; + + AlgorithmLauncher(cudaKernel_t k); + + template + void dispatch(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... 
args) + { + void* kernel_args[] = {const_cast(static_cast(&args))...}; + this->call(stream, grid, block, shared_mem, kernel_args); + } + + cudaKernel_t get_kernel() { return this->kernel; } + + private: + void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); + cudaKernel_t kernel; +}; + +std::unordered_map>& get_cached_launchers(); diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp new file mode 100644 index 0000000000..93f24d0c6c --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +#include "AlgorithmLauncher.hpp" + +struct FragmentEntry; + +struct AlgorithmPlanner { + AlgorithmPlanner(std::string const& n, std::string const& p) : entrypoint(n + "_" + p) {} + + std::shared_ptr get_launcher(); + + std::string entrypoint; + std::vector device_functions; + std::vector fragments; + + private: + void add_entrypoint(); + void add_device_functions(); + std::string get_device_functions_key() const; + std::shared_ptr build(); +}; diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp new file mode 100644 index 0000000000..aeb170d861 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp @@ -0,0 +1,45 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include "FragmentEntry.hpp" +#include "MakeFragmentKey.hpp" + +class FragmentDatabase { + public: + FragmentDatabase(FragmentDatabase const&) = delete; + FragmentDatabase(FragmentDatabase&&) = delete; + + FragmentDatabase& operator=(FragmentDatabase&&) = delete; + FragmentDatabase& operator=(FragmentDatabase const&) = delete; + + FragmentEntry* get_fragment(std::string const& key); + + private: + FragmentDatabase(); + + bool make_cache_entry(std::string const& key); + + friend FragmentDatabase& fragment_database(); + + friend void registerFatbinFragment(std::string const& algo, + std::string const& params, + unsigned char const* blob, + std::size_t size); + + std::unordered_map> cache; +}; + +FragmentDatabase& fragment_database(); + +void registerFatbinFragment(std::string const& algo, + std::string const& params, + unsigned char const* blob, + std::size_t size); diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp new file mode 100644 index 0000000000..a376068425 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -0,0 +1,32 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +#include + +struct FragmentEntry { + FragmentEntry(std::string const& key); + + bool operator==(const FragmentEntry& rhs) const { return compute_key == rhs.compute_key; } + + virtual bool add_to(nvJitLinkHandle& handle) const = 0; + + std::string compute_key{}; +}; + +struct FatbinFragmentEntry final : FragmentEntry { + FatbinFragmentEntry(std::string const& key, unsigned char const* view, std::size_t size); + + virtual bool add_to(nvJitLinkHandle& handle) const; + + std::size_t data_size = 0; + unsigned char const* data_view = nullptr; +}; diff --git a/cpp/include/cuvs/detail/jit_lto/MakeFragmentKey.hpp b/cpp/include/cuvs/detail/jit_lto/MakeFragmentKey.hpp new file mode 100644 index 0000000000..21482d5234 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/MakeFragmentKey.hpp @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace detail { + +template +std::string type_as_string() +{ + if constexpr (std::is_reference_v) { + return std::string(typeid(T).name()) + "&"; + } else { + return std::string(typeid(T).name()); + } +} +} // namespace detail + +template +std::string make_fragment_key() +{ + std::string result; + ((result += detail::type_as_string() + "_"), ...); + if (!result.empty()) { result.pop_back(); } + return result; +} diff --git a/cpp/include/cuvs/detail/jit_lto/RegisterKernelFragment.hpp b/cpp/include/cuvs/detail/jit_lto/RegisterKernelFragment.hpp new file mode 100644 index 0000000000..5643be6523 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/RegisterKernelFragment.hpp @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "MakeFragmentKey.hpp" + +void registerFatbinFragment(std::string const& algo, + std::string const& params, + unsigned char const* blob, + std::size_t size); + +namespace { + +template +void registerAlgorithm(std::string algo, unsigned char const* blob, std::size_t size) +{ + auto key = make_fragment_key(); + registerFatbinFragment(algo, key, blob, size); +} + +} // namespace diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp new file mode 100644 index 0000000000..d9ed7e6b0b --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp @@ -0,0 +1,45 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::ivf_flat::detail { + +// Tag types for data types +struct tag_f {}; +struct tag_h {}; +struct tag_sc {}; +struct tag_uc {}; + +// Tag types for accumulator types +struct tag_acc_f {}; +struct tag_acc_h {}; +struct tag_acc_i {}; +struct tag_acc_ui {}; + +// Tag types for index types +struct tag_idx_l {}; + +// Tag types for filter subtypes +struct tag_filter_bitset_impl {}; +struct tag_filter_none_impl {}; + +// Tag types for sample filter types with full template info +template +struct tag_filter {}; + +// Tag types for distance metrics with full template info +template +struct tag_metric_euclidean {}; + +template +struct tag_metric_inner_product {}; + +// Tag types for post-processing +struct tag_post_identity {}; +struct tag_post_sqrt {}; +struct tag_post_compose {}; + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index b2da48aaa1..39967999ed 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -539,7 +539,8 @@ struct 
ivf_to_sample_filter : public base_filter { const index_t* const* inds_ptrs_; const filter_t next_filter_; - ivf_to_sample_filter(const index_t* const* inds_ptrs, const filter_t next_filter); + _RAFT_HOST_DEVICE ivf_to_sample_filter(const index_t* const* inds_ptrs, + const filter_t next_filter); /** \cond */ /** If the original filter takes three arguments, then don't modify the arguments. diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index b28c01de04..23c6dd4944 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1401,6 +1401,9 @@ void extend(raft::resources const& handle, /** * @brief Search ANN using the constructed index. + * This function JIT compiles the kernel for the very first usage, after which it maintains an + * in-memory and disk-based cache of the compiled kernels. We recommend running a warmup search + * before the actual searches to avoid the first-time JIT compilation overhead. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -1442,6 +1445,9 @@ void search(raft::resources const& handle, /** * @brief Search ANN using the constructed index. + * This function JIT compiles the kernel for the very first usage, after which it maintains an + * in-memory and disk-based cache of the compiled kernels. We recommend running a warmup search + * before the actual searches to avoid the first-time JIT compilation overhead. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -1482,6 +1488,9 @@ void search(raft::resources const& handle, cuvs::neighbors::filtering::none_sample_filter{}); /** * @brief Search ANN using the constructed index. + * This function JIT compiles the kernel for the very first usage, after which it maintains an + * in-memory and disk-based cache of the compiled kernels. 
We recommend running a warmup search + * before the actual searches to avoid the first-time JIT compilation overhead. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * @@ -1523,6 +1532,9 @@ void search(raft::resources const& handle, /** * @brief Search ANN using the constructed index. + * This function JIT compiles the kernel for the very first usage, after which it maintains an + * in-memory and disk-based cache of the compiled kernels. We recommend running a warmup search + * before the actual searches to avoid the first-time JIT compilation overhead. * * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. * diff --git a/cpp/src/detail/jit_lto/AlgorithmLauncher.cu b/cpp/src/detail/jit_lto/AlgorithmLauncher.cu new file mode 100644 index 0000000000..d095689f5d --- /dev/null +++ b/cpp/src/detail/jit_lto/AlgorithmLauncher.cu @@ -0,0 +1,34 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +AlgorithmLauncher::AlgorithmLauncher(cudaKernel_t k) : kernel{k} {} + +void AlgorithmLauncher::call( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args) +{ + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + + cudaLaunchConfig_t config; + config.gridDim = grid; + config.blockDim = block; + config.stream = stream; + config.attrs = attribute; + config.numAttrs = 1; + config.dynamicSmemBytes = shared_mem; + + RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); +} + +std::unordered_map>& get_cached_launchers() +{ + static std::unordered_map> launchers; + return launchers; +} diff --git a/cpp/src/detail/jit_lto/AlgorithmPlanner.cu b/cpp/src/detail/jit_lto/AlgorithmPlanner.cu new file mode 100644 index 0000000000..0983267e04 --- /dev/null +++ 
b/cpp/src/detail/jit_lto/AlgorithmPlanner.cu @@ -0,0 +1,120 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "nvjitlink_checker.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "cuda_runtime.h" +#include "nvJitLink.h" + +#include +#include + +void AlgorithmPlanner::add_entrypoint() +{ + auto entrypoint_fragment = fragment_database().get_fragment(this->entrypoint); + this->fragments.push_back(entrypoint_fragment); +} + +void AlgorithmPlanner::add_device_functions() +{ + for (const auto& device_function_key : this->device_functions) { + auto device_function_fragment = fragment_database().get_fragment(device_function_key); + this->fragments.push_back(device_function_fragment); + } +} + +std::string AlgorithmPlanner::get_device_functions_key() const +{ + std::string key = ""; + for (const auto& device_function : this->device_functions) { + key += device_function; + } + return key; +} + +std::shared_ptr AlgorithmPlanner::get_launcher() +{ + auto& launchers = get_cached_launchers(); + auto launch_key = this->entrypoint + this->get_device_functions_key(); + + static std::mutex cache_mutex; + std::lock_guard lock(cache_mutex); + if (launchers.count(launch_key) == 0) { + add_entrypoint(); + add_device_functions(); + std::string log_message = + "JIT compiling launcher for entrypoint: " + this->entrypoint + " and device functions: "; + for (const auto& device_function : this->device_functions) { + log_message += device_function + ","; + } + log_message.pop_back(); + RAFT_LOG_INFO("%s", log_message.c_str()); + launchers[launch_key] = this->build(); + } + return launchers[launch_key]; +} + +std::shared_ptr AlgorithmPlanner::build() +{ + int device = 0; + int major = 0; + int minor = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + 
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor)); + + // Load the generated LTO IR and link them together + nvJitLinkHandle handle; + const char* lopts[] = {"-lto", archs.c_str()}; + auto result = nvJitLinkCreate(&handle, 2, lopts); + check_nvjitlink_result(handle, result); + + for (auto& frag : this->fragments) { + frag->add_to(handle); + } + + // Call to nvJitLinkComplete causes linker to link together all the LTO-IR + // modules perform any optimizations and generate cubin from it. + result = nvJitLinkComplete(handle); + check_nvjitlink_result(handle, result); + + // get cubin from nvJitLink + size_t cubin_size; + result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size); + check_nvjitlink_result(handle, result); + + std::unique_ptr cubin{new char[cubin_size]}; + result = nvJitLinkGetLinkedCubin(handle, cubin.get()); + check_nvjitlink_result(handle, result); + + result = nvJitLinkDestroy(&handle); + RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed"); + + // cubin is linked, so now load it + // NOTE: cudaLibrary_t does not need to be freed explicitly + cudaLibrary_t library; + RAFT_CUDA_TRY( + cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + + constexpr unsigned int count = 1; + // NOTE: cudaKernel_t does not need to be freed explicitly + std::unique_ptr kernels{new cudaKernel_t[count]}; + RAFT_CUDA_TRY(cudaLibraryEnumerateKernels(kernels.get(), count, library)); + + return std::make_shared(kernels.release()[0]); +} diff --git a/cpp/src/detail/jit_lto/FragmentDatabase.cu b/cpp/src/detail/jit_lto/FragmentDatabase.cu new file mode 100644 index 0000000000..02ea688a0d --- /dev/null +++ b/cpp/src/detail/jit_lto/FragmentDatabase.cu @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +FragmentDatabase::FragmentDatabase() {} + +bool FragmentDatabase::make_cache_entry(std::string const& key) +{ + if (this->cache.count(key) == 0) { + this->cache[key] = std::unique_ptr{}; + return false; + } + return true; +} + +FragmentDatabase& fragment_database() +{ + static FragmentDatabase database; + return database; +} + +FragmentEntry* FragmentDatabase::get_fragment(std::string const& key) +{ + auto& db = fragment_database(); + auto val = db.cache.find(key); + RAFT_EXPECTS(val != db.cache.end(), "FragmentDatabase: Key not found: %s", key.c_str()); + return val->second.get(); +} + +void registerFatbinFragment(std::string const& algo, + std::string const& params, + unsigned char const* blob, + std::size_t size) +{ + auto& planner = fragment_database(); + std::string key = algo; + if (!params.empty()) { key += "_" + params; } + auto entry_exists = planner.make_cache_entry(key); + if (entry_exists) { return; } + planner.cache[key] = std::make_unique(key, blob, size); +} diff --git a/cpp/src/detail/jit_lto/FragmentEntry.cu b/cpp/src/detail/jit_lto/FragmentEntry.cu new file mode 100644 index 0000000000..af1fb90e58 --- /dev/null +++ b/cpp/src/detail/jit_lto/FragmentEntry.cu @@ -0,0 +1,28 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "nvjitlink_checker.hpp" + +#include + +#include + +FragmentEntry::FragmentEntry(std::string const& key) : compute_key(key) {} + +FatbinFragmentEntry::FatbinFragmentEntry(std::string const& key, + unsigned char const* view, + std::size_t size) + : FragmentEntry(key), data_size(size), data_view(view) +{ +} + +bool FatbinFragmentEntry::add_to(nvJitLinkHandle& handle) const +{ + auto result = nvJitLinkAddData( + handle, NVJITLINK_INPUT_ANY, this->data_view, this->data_size, this->compute_key.c_str()); + + check_nvjitlink_result(handle, result); + return true; +} diff --git a/cpp/src/detail/jit_lto/nvjitlink_checker.cpp b/cpp/src/detail/jit_lto/nvjitlink_checker.cpp new file mode 100644 index 0000000000..6f9ae988db --- /dev/null +++ b/cpp/src/detail/jit_lto/nvjitlink_checker.cpp @@ -0,0 +1,27 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "nvjitlink_checker.hpp" + +#include +#include +#include + +#include + +void check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result) +{ + if (result != NVJITLINK_SUCCESS) { + std::string error_msg = "nvJITLink failed with error " + std::to_string(result); + size_t log_size = 0; + result = nvJitLinkGetErrorLogSize(handle, &log_size); + if (result == NVJITLINK_SUCCESS && log_size > 0) { + std::unique_ptr log{new char[log_size]}; + result = nvJitLinkGetErrorLog(handle, log.get()); + if (result == NVJITLINK_SUCCESS) { error_msg += "\n" + std::string(log.get()); } + } + RAFT_FAIL("AlgorithmPlanner nvJITLink error log: %s", error_msg.c_str()); + } +} diff --git a/cpp/src/detail/jit_lto/nvjitlink_checker.hpp b/cpp/src/detail/jit_lto/nvjitlink_checker.hpp new file mode 100644 index 0000000000..12b062d795 --- /dev/null +++ b/cpp/src/detail/jit_lto/nvjitlink_checker.hpp @@ -0,0 +1,11 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +// We can make a better RAII wrapper around nvjitlinkhandle +void check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh index 002286b31f..81833a63b1 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh @@ -1,12 +1,16 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include "../detail/ann_utils.cuh" +#ifdef CUVS_ENABLE_JIT_LTO +#include "ivf_flat_interleaved_scan_jit.cuh" +#else #include "ivf_flat_interleaved_scan.cuh" +#endif #include #include #include diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh new file mode 100644 index 0000000000..be8652dd59 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh @@ -0,0 +1,445 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../ivf_common.cuh" +#include "jit_lto_kernels/interleaved_scan_planner.hpp" +#include +#include +#include +#include + +#include "../detail/ann_utils.cuh" +#include +#include +#include +#include // RAFT_CUDA_TRY +#include + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +static constexpr int kThreadsPerBlock = 128; + +using namespace cuvs::spatial::knn::detail; // NOLINT + +// Constexpr mapping functions from actual types to tags +template +constexpr auto get_data_type_tag() +{ + if constexpr (std::is_same_v) { return tag_f{}; } + if constexpr (std::is_same_v) { return tag_h{}; } + if constexpr (std::is_same_v) { return tag_sc{}; } + if constexpr (std::is_same_v) { return tag_uc{}; } +} + +template +constexpr auto get_acc_type_tag() +{ + if constexpr (std::is_same_v) { return tag_acc_f{}; } + if constexpr (std::is_same_v) { return tag_acc_h{}; } + if constexpr (std::is_same_v) { return tag_acc_i{}; } + if constexpr (std::is_same_v) { return tag_acc_ui{}; } +} + +template +constexpr auto get_idx_type_tag() +{ + if constexpr (std::is_same_v) { return tag_idx_l{}; } +} + +template +constexpr auto get_filter_type_tag() +{ + using namespace cuvs::neighbors::filtering; + + // Determine the filter implementation tag + if constexpr (std::is_same_v) { + return tag_filter{}; + } + if constexpr (std::is_same_v>) { + return tag_filter{}; + } +} + +template +constexpr auto get_metric_name() +{ + if constexpr (std::is_same_v>) { + return "euclidean"; + } + if constexpr (std::is_same_v>) { + return "inner_prod"; + } +} + +template +constexpr auto get_filter_name() +{ + if constexpr (std::is_same_v>) { + return "filter_none"; + } + if constexpr (std::is_same_v>) { + return "filter_bitset"; + } +} + +template +constexpr auto get_post_lambda_name() +{ + if constexpr (std::is_same_v) { return "post_identity"; } + if constexpr (std::is_same_v) { return "post_sqrt"; } + if constexpr (std::is_same_v) 
{ return "post_compose"; } +} + +/** + * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size + */ +inline uint32_t configure_launch_x(uint32_t numQueries, + uint32_t n_probes, + int32_t sMemSize, + cudaKernel_t func) +{ + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int num_sms; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); + + size_t min_grid_size = num_sms * num_blocks_per_sm; + size_t min_grid_x = raft::ceildiv(min_grid_size, numQueries); + return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); +} + +template +void launch_kernel(const index& index, + const T* queries, + const uint32_t* coarse_index, + const uint32_t num_queries, + const uint32_t queries_offset, + const uint32_t n_probes, + const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, + IdxT* const* const inds_ptrs, + cuda::std::optional bitset_ptr, + cuda::std::optional bitset_len, + cuda::std::optional original_nbits, + uint32_t* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + RAFT_EXPECTS(Veclen == index.veclen(), + "Configured Veclen does not match the index interleaving pattern."); + + // Use tag types for the planner to avoid template bloat + auto kernel_planner = InterleavedScanPlanner()), + decltype(get_acc_type_tag()), + decltype(get_idx_type_tag())>( + Capacity, Veclen, Ascending, ComputeNorm); + kernel_planner.template add_metric_device_function()), + decltype(get_acc_type_tag())>( + get_metric_name(), Veclen); + kernel_planner.add_filter_device_function(get_filter_name()); + kernel_planner.add_post_lambda_device_function(get_post_lambda_name()); + auto kernel_launcher = kernel_planner.get_launcher(); + + const int max_query_smem = 16384; + int query_smem_elems = 
std::min(max_query_smem / sizeof(T), + raft::Pow2::roundUp(index.dim())); + int smem_size = query_smem_elems * sizeof(T); + + if constexpr (Capacity > 0) { + constexpr int kSubwarpSize = std::min(Capacity, raft::WarpSize); + auto block_merge_mem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + kThreadsPerBlock / kSubwarpSize, k); + smem_size += std::max(smem_size, block_merge_mem); + } + + // power-of-two less than cuda limit (for better addr alignment) + constexpr uint32_t kMaxGridY = 32768; + + if (grid_dim_x == 0) { + grid_dim_x = configure_launch_x( + std::min(kMaxGridY, num_queries), n_probes, smem_size, kernel_launcher->get_kernel()); + return; + } + + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { + uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); + dim3 grid_dim(grid_dim_x, grid_dim_y, 1); + dim3 block_dim(kThreadsPerBlock); + RAFT_LOG_TRACE( + "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " + "smem_size = %d", + grid_dim.x, + grid_dim.y, + block_dim.x, + n_probes, + smem_size); + kernel_launcher->dispatch(stream, + grid_dim, + block_dim, + smem_size, + query_smem_elems, + queries, + coarse_index, + index.data_ptrs().data_handle(), + index.list_sizes().data_handle(), + queries_offset + query_offset, + n_probes, + k, + max_samples, + chunk_indices, + index.dim(), + // sample_filter, + inds_ptrs, + bitset_ptr.value_or(nullptr), + bitset_len.value_or(0), + original_nbits.value_or(0), + neighbors, + distances); + queries += grid_dim_y * index.dim(); + if constexpr (Capacity > 0) { + neighbors += grid_dim_y * grid_dim_x * k; + distances += grid_dim_y * grid_dim_x * k; + } else { + distances += grid_dim_y * max_samples; + } + chunk_indices += grid_dim_y * n_probes; + coarse_index += grid_dim_y * n_probes; + } +} + +/** Select the distance computation function and forward the rest of the arguments. 
*/ +template +void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... args) +{ + switch (metric) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2Unexpanded: + return launch_kernel, + tag_post_identity>(std::forward(args)...); + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2SqrtUnexpanded: + return launch_kernel, + tag_post_sqrt>(std::forward(args)...); + case cuvs::distance::DistanceType::InnerProduct: + return launch_kernel, + tag_post_identity>(std::forward(args)...); + case cuvs::distance::DistanceType::CosineExpanded: + // NB: "Ascending" is reversed because the post-processing step is done after that sort + return launch_kernel, + tag_post_compose>( + std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when + // adding here a new metric. + default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + } +} + +/** + * Lift the `capacity` and `veclen` parameters to the template level, + * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. + */ +template (1, 16 / sizeof(T))> +struct select_interleaved_scan_kernel { + /** + * Recursively reduce the `Capacity` and `Veclen` parameters until they match the + * corresponding runtime arguments. + * By default, this recursive process starts with maximum possible values of the + * two parameters and ends with both values equal to 1. + */ + template + static inline void run(int k_max, int veclen, bool select_min, Args&&... 
args) + { + if constexpr (Capacity > 0) { + if (k_max == 0 || k_max > Capacity) { + return select_interleaved_scan_kernel::run( + k_max, veclen, select_min, std::forward(args)...); + } + } + if constexpr (Capacity > 1) { + if (k_max * 2 <= Capacity) { + return select_interleaved_scan_kernel::run(k_max, + veclen, + select_min, + std::forward(args)...); + } + } + if constexpr (Veclen > 1) { + if (veclen % Veclen != 0) { + return select_interleaved_scan_kernel::run( + k_max, 1, select_min, std::forward(args)...); + } + } + // NB: this is the limitation of the warpsort structures that use a huge number of + // registers (used in the main kernel here). + RAFT_EXPECTS(Capacity == 0 || k_max == Capacity, + "Capacity must be either 0 or a power-of-two not bigger than the maximum " + "allowed size matrix::detail::select::warpsort::kMaxCapacity (%d).", + raft::matrix::detail::select::warpsort::kMaxCapacity); + RAFT_EXPECTS( + veclen == Veclen, + "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); + if (select_min) { + launch_with_fixed_consts( + std::forward(args)...); + } else { + launch_with_fixed_consts( + std::forward(args)...); + } + } +}; + +/** + * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. + * + * @tparam T value type + * @tparam AccT accumulated type + * @tparam IdxT type of the indices + * + * @param index previously built ivf-flat index + * @param[in] queries device pointer to the query vectors [batch_size, dim] + * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] + * @param n_queries batch size + * @param[in] queries_offset + * An offset of the current query batch. It is used for feeding sample_filter with the + * correct query index. + * @param metric type of the measured distance + * @param n_probes number of nearest clusters to query + * @param k number of nearest neighbors. 
+ * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. + * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given + * metric. + * @param[out] neighbors device pointer to the result indices for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[out] distances device pointer to the result distances for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; + * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) + * @param stream + * @param sample_filter + * A filter that selects samples for a given query. Use an instance of none_sample_filter to + * provide a green light for every sample. + */ +template +void ivfflat_interleaved_scan(const index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const uint32_t queries_offset, + const cuvs::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, + const bool select_min, + IvfSampleFilterT sample_filter, + uint32_t* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + const int capacity = raft::bound_by_power_of_two(k); + + cuda::std::optional bitset_ptr; + cuda::std::optional bitset_len; + cuda::std::optional original_nbits; + + if constexpr (std::is_same_v>) { + bitset_ptr = sample_filter.view().data(); + bitset_len = sample_filter.view().size(); + original_nbits = sample_filter.view().get_original_nbits(); + } + select_interleaved_scan_kernel())>:: + run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + queries_offset, + n_probes, + k, + max_samples, + chunk_indices, + index.inds_ptrs().data_handle(), + bitset_ptr, + bitset_len, + original_nbits, + neighbors, + distances, + grid_dim_x, + stream); +} + +} // 
namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter.cu.in new file mode 100644 index 0000000000..934e36dba7 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter.cu.in @@ -0,0 +1,32 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) @year@, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +// This file is auto-generated. Do not edit manually. + +#ifdef BUILD_KERNEL + +#include <@header_file@> + +namespace cuvs::neighbors::ivf_flat::detail { + +// Instantiate the device function template +template __device__ bool sample_filter(int64_t* const* const, const uint32_t, const uint32_t, const uint32_t, uint32_t*, int64_t, int64_t); + +} // namespace cuvs::neighbors::ivf_flat::detail + +#else + +#include +#include "@filter_name@.h" + +__attribute__((__constructor__)) static void register_@filter_name@() +{ + registerAlgorithm( + "@filter_name@", + embedded_@filter_name@, + sizeof(embedded_@filter_name@)); +} + +#endif diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh new file mode 100644 index 0000000000..07fc4a21f5 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../sample_filter.cuh" + +namespace cuvs::neighbors::ivf_flat::detail { + +template +__device__ bool sample_filter(index_t* const* const inds_ptrs, + const uint32_t query_ix, + const uint32_t cluster_ix, + const uint32_t sample_ix, + uint32_t* bitset_ptr, + index_t bitset_len, + index_t original_nbits) +{ + auto bitset_view = + raft::core::bitset_view{bitset_ptr, bitset_len, original_nbits}; + auto bitset_filter = cuvs::neighbors::filtering::bitset_filter{bitset_view}; + auto ivf_to_sample_filter = cuvs::neighbors::filtering:: + ivf_to_sample_filter>{ + inds_ptrs, bitset_filter}; + return ivf_to_sample_filter(query_ix, cluster_ix, sample_ix); +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh new file mode 100644 index 0000000000..aad15d64bc --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../sample_filter.cuh" + +namespace cuvs::neighbors::ivf_flat::detail { + +template +__device__ constexpr bool sample_filter(index_t* const* const inds_ptrs, + const uint32_t query_ix, + const uint32_t cluster_ix, + const uint32_t sample_ix, + uint32_t* bitset_ptr, + index_t bitset_len, + index_t original_nbits) +{ + return true; +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in new file mode 100644 index 0000000000..5e75253939 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in @@ -0,0 +1,40 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) @year@, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +// This file is auto-generated. Do not edit manually. + +#ifdef BUILD_KERNEL + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +// Instantiate the kernel template +template __global__ void interleaved_scan_kernel<@capacity@, @veclen@, @ascending@, @compute_norm@, @data_type@, @acc_type@, @idx_type@>( + const uint32_t, const @data_type@*, const uint32_t*, const @data_type@* const*, const uint32_t*, + const uint32_t, const uint32_t, const uint32_t, const uint32_t, const uint32_t*, const uint32_t, + @idx_type@* const* const, uint32_t*, @idx_type@, @idx_type@, uint32_t*, float*); + +} // namespace cuvs::neighbors::ivf_flat::detail + +#else + +#include +#include +#include "interleaved_scan_kernel_@capacity@_@veclen@_@ascending@_@compute_norm@_@type_abbrev@_@acc_abbrev@_@idx_abbrev@.h" + +using namespace cuvs::neighbors::ivf_flat::detail; + +__attribute__((__constructor__)) static void register_kernel_@capacity@_@veclen@_@ascending@_@compute_norm@_@type_abbrev@_@acc_abbrev@_@idx_abbrev@() +{ + registerAlgorithm( + 
"interleaved_scan_kernel_@capacity@_@veclen@_@ascending@_@compute_norm@", + embedded_interleaved_scan_kernel_@capacity@_@veclen@_@ascending@_@compute_norm@_@type_abbrev@_@acc_abbrev@_@idx_abbrev@, + sizeof(embedded_interleaved_scan_kernel_@capacity@_@veclen@_@ascending@_@compute_norm@_@type_abbrev@_@acc_abbrev@_@idx_abbrev@)); +} + +#endif diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp new file mode 100644 index 0000000000..792c64f39a --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp @@ -0,0 +1,45 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include + +inline std::string bool_to_string(bool b) { return b ? "true" : "false"; } + +template +struct InterleavedScanPlanner : AlgorithmPlanner { + InterleavedScanPlanner(int Capacity, int Veclen, bool Ascending, bool ComputeNorm) + : AlgorithmPlanner("interleaved_scan_kernel_" + std::to_string(Capacity) + "_" + + std::to_string(Veclen) + "_" + bool_to_string(Ascending) + "_" + + bool_to_string(ComputeNorm), + make_fragment_key()) + { + } + + template + void add_metric_device_function(std::string metric_name, int Veclen) + { + auto key = metric_name + "_" + std::to_string(Veclen); + auto params = make_fragment_key(); + this->device_functions.push_back(key + "_" + params); + } + + void add_filter_device_function(std::string filter_name) + { + auto key = filter_name; + this->device_functions.push_back(key); + } + + void add_post_lambda_device_function(std::string post_lambda_name) + { + auto key = post_lambda_name; + this->device_functions.push_back(key); + } +}; diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh 
new file mode 100644 index 0000000000..3a14fe8afd --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh @@ -0,0 +1,924 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../../ivf_common.cuh" + +#include + +#include +#include +#include +#include + +// This header contains the kernel definition and should only be included +// when compiling JIT-LTO kernel fragments (when BUILD_KERNEL is defined). + +namespace cuvs::neighbors::ivf_flat::detail { + +static constexpr int kThreadsPerBlock = 128; + +// These extern device functions are linked at runtime using JIT-LTO. +template +extern __device__ void compute_dist(AccT& acc, AccT x, AccT y); + +template +extern __device__ bool sample_filter(index_t* const* const inds_ptrs, + const uint32_t query_ix, + const uint32_t cluster_ix, + const uint32_t sample_ix, + uint32_t* bitset_ptr, + index_t bitset_len, + index_t original_nbits); + +template +extern __device__ T post_process(T val); + +/** + * @brief Load a part of a vector from the index and from query, compute the (part of the) distance + * between them, and aggregate it using the provided Lambda; one structure per thread, per query, + * and per index item. 
+ * + * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) + * @tparam Lambda computing the part of the distance for one dimension and aggregating it: + * void (AccT& acc, AccT x, AccT y) + * @tparam Veclen size of the vectorized load + * @tparam T type of the data in the query and the index + * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit + * values) + */ +template +struct loadAndComputeDist { + AccT& dist; + AccT& norm_query; + AccT& norm_data; + + __device__ __forceinline__ loadAndComputeDist(AccT& dist, AccT& norm_query, AccT& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in shared memory. + * Every thread here processes exactly kUnroll * Veclen elements independently of others. + */ + template + __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, + const T* query_shared, + IdxT loadIndex, + IdxT shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + raft::ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); + T queryRegs[Veclen]; + raft::lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query += queryRegs[k] * queryRegs[k]; + norm_data += encV[k] * encV[k]; + } + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in the global memory and is different for every + * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into + * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). 
+ */ + template + __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, + const T* query, + IdxT baseLoadIndex, + const int lane_id) + { + T queryReg = query[baseLoadIndex + lane_id]; + constexpr int stride = kUnroll * Veclen; + constexpr int totalIter = raft::WarpSize / stride; + constexpr int gmemStride = stride * kIndexGroupSize; +#pragma unroll + for (int i = 0; i < totalIter; ++i, data += gmemStride) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + raft::ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); + const int d = (i * kUnroll + j) * Veclen; +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + T q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += encV[k] * encV[k]; + } + } + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. + */ + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) + { + const int loadDim = dimBlocks + lane_id; + T queryReg = loadDim < dim ? 
query[loadDim] : T{0}; + const int loadDataIdx = lane_id * Veclen; + for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { + T enc[Veclen]; + raft::ldg(enc, data + loadDataIdx); +#pragma unroll + for (int k = 0; k < Veclen; k++) { + T q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += enc[k] * enc[k]; + } + } + } + } +}; + +// This handles uint8_t 8, 16 Veclens +template +struct loadAndComputeDist { + uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, + uint32_t& norm_query, + uint32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + loadIndex = loadIndex * veclen_int; +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + raft::ldg( + encV, + reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); + uint32_t queryRegs[veclen_int]; + raft::lds(queryRegs, + reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } + } + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + uint32_t queryReg = + (lane_id < 8) ? 
reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * uint8_veclen; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + raft::ldg( + encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + uint32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen_int = uint8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int + uint32_t queryReg = loadDim < dim ? 
reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; + d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { + uint32_t enc[veclen_int]; + raft::ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + uint32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); + compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc[k], enc[k], norm_data); + } + } + } + } +}; + +// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while +// using above common template of int2/int4 +template +struct loadAndComputeDist { + uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, + uint32_t& norm_query, + uint32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist<4, uint8_t, uint32_t>(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 8) ? 
reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 4; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist<4, uint8_t, uint32_t>(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 4; + const int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); + compute_dist<4, uint8_t, uint32_t>(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } + } + } +}; + +template +struct loadAndComputeDist { + uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, + uint32_t& norm_query, + uint32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + 
compute_dist<2, uint8_t, uint32_t>(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist<2, uint8_t, uint32_t>(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + uint32_t queryReg = loadDim < dim ? 
reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); + compute_dist<2, uint8_t, uint32_t>(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } + } + } +}; + +template +struct loadAndComputeDist { + uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, + uint32_t& norm_query, + uint32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = query_shared[shmemIndex + j]; + compute_dist<1, uint8_t, uint32_t>(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query += queryRegs * queryRegs; + norm_data += encV * encV; + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = query[baseLoadIndex + lane_id]; + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[lane_id + j * kIndexGroupSize]; + uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist<1, uint8_t, uint32_t>(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += encV * encV; + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const 
uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 1; + int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = data[lane_id]; + uint32_t q = raft::shfl(queryReg, d, raft::WarpSize); + compute_dist<1, uint8_t, uint32_t>(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += enc * enc; + } + } + } +}; + +// This device function is for int8 veclens 4, 8 and 16 +template +struct loadAndComputeDist { + int32_t& dist; + int32_t& norm_query; + int32_t& norm_data; + + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, + int32_t& norm_query, + int32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + raft::ldg( + encV, + reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); + int32_t queryRegs[veclen_int]; + raft::lds(queryRegs, + reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + + int32_t queryReg = + (lane_id < 8) ? 
reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * int8_veclen; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + raft::ldg( + encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + int32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen_int = int8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // 4 int8 per int32 + int32_t queryReg = loadDim < dim ? 
reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { + int32_t enc[veclen_int]; + raft::ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + int32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); // 4 int8 per int32 + compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc[k], enc[k], norm_data); + } + } + } + } +}; + +template +struct loadAndComputeDist { + int32_t& dist; + int32_t& norm_query; + int32_t& norm_data; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, + int32_t& norm_query, + int32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist<2, int8_t, int32_t>(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + int32_t queryReg = + (lane_id < 16) ? 
reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist<2, int8_t, int32_t>(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; + int32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); + compute_dist<2, int8_t, int32_t>(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } + } + } +}; + +template +struct loadAndComputeDist { + int32_t& dist; + int32_t& norm_query; + int32_t& norm_data; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, + int32_t& norm_query, + int32_t& norm_data) + : dist(dist), norm_query(norm_query), norm_data(norm_data) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist<1, int8_t, int32_t>( + dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); + if constexpr (ComputeNorm) { + norm_query += 
int32_t{query_shared[shmemIndex + j]} * int32_t{query_shared[shmemIndex + j]}; + norm_data += int32_t{data[loadIndex + j * kIndexGroupSize]} * + int32_t{data[loadIndex + j * kIndexGroupSize]}; + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + int32_t queryReg = query[baseLoadIndex + lane_id]; + +#pragma unroll + for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist<1, int8_t, int32_t>(dist, q, data[lane_id + j * kIndexGroupSize]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += data[lane_id + j * kIndexGroupSize] * data[lane_id + j * kIndexGroupSize]; + } + } + } + } + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 1; + const int loadDim = dimBlocks + lane_id; + int32_t queryReg = loadDim < dim ? 
query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + int32_t q = raft::shfl(queryReg, d, raft::WarpSize); + compute_dist<1, int8_t, int32_t>(dist, q, data[lane_id]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += int32_t{data[lane_id]} * int32_t{data[lane_id]}; + } + } + } +}; + +// Switch to a dummy block sort when Capacity is 0. An explicit dummy type is used +// so that warpsort constants such as ::queue_t::kDummy remain accessible. +template +struct flat_block_sort { + using type = raft::matrix::detail::select::warpsort::block_sort< + raft::matrix::detail::select::warpsort::warp_sort_filtered, + Capacity, + Ascending, + T, + IdxT>; +}; + +template +struct flat_block_sort<0, Ascending, T, IdxT> + : ivf::detail::dummy_block_sort_t { + using type = ivf::detail::dummy_block_sort_t; +}; + +template +using block_sort_t = typename flat_block_sort::type; + +/** + * Scan clusters for nearest neighbors of the query vectors. + * See `ivfflat_interleaved_scan` for more information. + * + * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. + * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is + * calculated, and the top-k nearest neighbors are selected. + * + * @param compute_dist distance function + * @param query_smem_elems number of query vector dimensions to store in the shared memory of a + * block; this number must be a multiple of `WarpSize * Veclen`. 
+ * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] + * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] + * @param[in] inds_ptrs pointers to the list indices; forwarded to sample_filter + * @param[in] list_data_ptrs pointers to the interleaved data of each list + * @param[in] list_sizes index.list_sizes + * @param[in] chunk_indices per-query running totals of the probed list sizes [n_probes] + * @param n_probes + * @param k + * @param dim + * @param sample_filter + * @param[out] neighbors + * @param[out] distances + */ +template +RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) + interleaved_scan_kernel(const uint32_t query_smem_elems, + const T* query, + const uint32_t* coarse_index, + const T* const* list_data_ptrs, + const uint32_t* list_sizes, + const uint32_t queries_offset, + const uint32_t n_probes, + const uint32_t k, + const uint32_t max_samples, + const uint32_t* chunk_indices, + const uint32_t dim, + IdxT* const* const inds_ptrs, + uint32_t* bitset_ptr, + IdxT bitset_len, + IdxT original_nbits, + uint32_t* neighbors, + float* distances) +{ + extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; + constexpr bool kManageLocalTopK = Capacity > 0; + // Using shared memory for (part of) the query. + // This saves global memory bandwidth when the index and query data are read + // at the same time. + // Its size is `query_smem_elems`. 
+ T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); + // Make the query input and output point to this block's shared query + { + const int query_id = blockIdx.y; + query += query_id * dim; + if constexpr (kManageLocalTopK) { + neighbors += query_id * k * gridDim.x + blockIdx.x * k; + distances += query_id * k * gridDim.x + blockIdx.x * k; + } else { + distances += query_id * uint64_t(max_samples); + } + chunk_indices += (n_probes * query_id); + coarse_index += query_id * n_probes; + } + + // Copy a part of the query into shared memory for faster processing + copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); + __syncthreads(); + + using local_topk_t = block_sort_t; + local_topk_t queue(k); + { + using align_warp = raft::Pow2; + const int lane_id = align_warp::mod(threadIdx.x); + + // How many full warps needed to compute the distance (without remainder) + const uint32_t full_warps_along_dim = align_warp::roundDown(dim); + + const uint32_t shm_assisted_dim = + (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; + + // Every CUDA block scans one cluster at a time. + for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) + + // The number of vectors in each cluster(list); [nlist] + const uint32_t list_length = list_sizes[list_id]; + + // The number of interleaved groups to be processed + const uint32_t num_groups = + align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 + + uint32_t sample_offset = 0; + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); + + constexpr int kUnroll = raft::WarpSize / Veclen; + constexpr uint32_t kNumWarps = kThreadsPerBlock / raft::WarpSize; + // Every warp reads WarpSize vectors and computes the distances to them. 
+ // Then, the distances and corresponding ids are distributed among the threads, + // and each thread adds one (id, dist) pair to the filtering queue. + for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; + group_id += kNumWarps) { + AccT dist = 0; + AccT norm_query = 0; + AccT norm_dataset = 0; + // This is where this warp begins reading data (start position of an interleaved group) + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; + + // This is the vector a given lane/thread handles + const uint32_t vec_id = group_id * raft::WarpSize + lane_id; + const bool valid = vec_id < list_length && sample_filter(inds_ptrs, + queries_offset + blockIdx.y, + list_id, + vec_id, + bitset_ptr, + bitset_len, + original_nbits); + + if (valid) { + // Process first shm_assisted_dim dimensions (always using shared memory) + loadAndComputeDist lc( + dist, norm_query, norm_dataset); + for (int pos = 0; pos < shm_assisted_dim; + pos += raft::WarpSize, data += kIndexGroupSize * raft::WarpSize) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + + if (dim > query_smem_elems) { + // The default path - using shfl ops - for dimensions beyond query_smem_elems + loadAndComputeDist lc( + dist, norm_query, norm_dataset); + for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { + lc.runLoadShflAndCompute(data, query, pos, lane_id); + } + lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); + } else { + // when shm_assisted_dim == full_warps_along_dim < dim + loadAndComputeDist<1, Veclen, T, AccT, ComputeNorm> lc(dist, norm_query, norm_dataset); + for (int pos = full_warps_along_dim; pos < dim; + pos += Veclen, data += kIndexGroupSize * Veclen) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + } + + // Enqueue one element per thread + float val = valid ? 
static_cast(dist) : local_topk_t::queue_t::kDummy; + + if constexpr (ComputeNorm) { + if (valid) + val = val / (raft::sqrt(static_cast(norm_query)) * + raft::sqrt(static_cast(norm_dataset))); + } + if constexpr (kManageLocalTopK) { + queue.add(val, sample_offset + vec_id); + } else { + if (vec_id < list_length) distances[sample_offset + vec_id] = val; + } + } + + // fill up unused slots for current query + if constexpr (!kManageLocalTopK) { + if (probe_id + 1 == n_probes) { + for (uint32_t i = threadIdx.x + sample_offset + list_length; i < max_samples; + i += blockDim.x) { + distances[i] = local_topk_t::queue_t::kDummy; + } + } + } + } + } + + // finalize and store selected neighbours + if constexpr (kManageLocalTopK) { + __syncthreads(); + queue.done(interleaved_scan_kernel_smem); + queue.store(distances, neighbors, [](auto val) { return post_process(val); }); + } +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric.cu.in new file mode 100644 index 0000000000..0f6bb904d1 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric.cu.in @@ -0,0 +1,36 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) @year@, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +// This file is auto-generated. Do not edit manually. 
+ +#ifdef BUILD_KERNEL + +#include <@header_file@> + +namespace cuvs::neighbors::ivf_flat::detail { + +// Instantiate the device function template +template __device__ void compute_dist<@veclen@, @data_type@, @acc_type@>(@acc_type@&, @acc_type@, @acc_type@); + +} // namespace cuvs::neighbors::ivf_flat::detail + +#else + +#include +#include +#include "metric_@metric_name@_@veclen@_@type_abbrev@_@acc_abbrev@.h" + +using namespace cuvs::neighbors::ivf_flat::detail; + +__attribute__((__constructor__)) static void register_metric_@metric_name@_@veclen@_@type_abbrev@_@acc_abbrev@() +{ + registerAlgorithm( + "@metric_name@_@veclen@", + embedded_metric_@metric_name@_@veclen@_@type_abbrev@_@acc_abbrev@, + sizeof(embedded_metric_@metric_name@_@veclen@_@type_abbrev@_@acc_abbrev@)); +} + +#endif diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_euclidean.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_euclidean.cuh new file mode 100644 index 0000000000..8fd2a4c04b --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_euclidean.cuh @@ -0,0 +1,59 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + const auto diff = x - y; + acc += diff * diff; + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) + { + if constexpr (Veclen > 1) { + const auto diff = __vabsdiffu4(x, y); + acc = raft::dp4a(diff, diff, acc); + } else { + const auto diff = __usad(x, y, 0u); + acc += diff * diff; + } + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) + { + if constexpr (Veclen > 1) { + // Note that we enforce here that the unsigned version of dp4a is used, because the difference + // between two int8 numbers can be greater than 127 and therefore represented as a negative + // number in int8. Casting from int8 to int32 would yield incorrect results, while casting + // from uint8 to uint32 is correct. + const auto diff = __vabsdiffs4(x, y); + acc = raft::dp4a(diff, diff, static_cast(acc)); + } else { + const auto diff = x - y; + acc += diff * diff; + } + } +}; + +template +__device__ void compute_dist(AccT& acc, AccT x, AccT y) +{ + euclidean_dist{}(acc, x, y); +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_inner_prod.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_inner_prod.cuh new file mode 100644 index 0000000000..afc4c401fa --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_inner_prod.cuh @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +template +struct inner_prod_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { + acc = raft::dp4a(x, y, acc); + } else { + acc += x * y; + } + } +}; + +template +__device__ void compute_dist(AccT& acc, AccT x, AccT y) +{ + inner_prod_dist{}(acc, x, y); +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_compose.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_compose.cuh new file mode 100644 index 0000000000..fe3473ba4a --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_compose.cuh @@ -0,0 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +template +__device__ T post_process(T val) +{ + // This is for cosine distance: compose(add_const(1.0), mul_const(-1.0)) + // which computes: 1.0 + (-1.0 * val) = 1.0 - val + return raft::compose_op(raft::add_const_op{1.0f}, raft::mul_const_op{-1.0f})(val); +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_identity.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_identity.cuh new file mode 100644 index 0000000000..04fd825c92 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_identity.cuh @@ -0,0 +1,18 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +template +__device__ T post_process(T val) +{ + return raft::identity_op{}(val); +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda.cu.in new file mode 100644 index 0000000000..abf156a133 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda.cu.in @@ -0,0 +1,32 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) @year@, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +// This file is auto-generated. Do not edit manually. + +#ifdef BUILD_KERNEL + +#include <@header_file@> + +namespace cuvs::neighbors::ivf_flat::detail { + +// Instantiate the device function template +template __device__ float post_process(float); + +} // namespace cuvs::neighbors::ivf_flat::detail + +#else + +#include +#include "@post_lambda_name@.h" + +__attribute__((__constructor__)) static void register_@post_lambda_name@() +{ + registerAlgorithm( + "@post_lambda_name@", + embedded_@post_lambda_name@, + sizeof(embedded_@post_lambda_name@)); +} + +#endif diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_sqrt.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_sqrt.cuh new file mode 100644 index 0000000000..28009b04e7 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_sqrt.cuh @@ -0,0 +1,18 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::ivf_flat::detail { + +template +__device__ T post_process(T val) +{ + return raft::sqrt_op{}(val); +} + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/sample_filter.cuh b/cpp/src/neighbors/sample_filter.cuh index 132c33e94e..6a45e78780 100644 --- a/cpp/src/neighbors/sample_filter.cuh +++ b/cpp/src/neighbors/sample_filter.cuh @@ -57,8 +57,8 @@ struct takes_three_args< * @tparam filter_t */ template -ivf_to_sample_filter::ivf_to_sample_filter(const index_t* const* inds_ptrs, - const filter_t next_filter) +_RAFT_HOST_DEVICE ivf_to_sample_filter::ivf_to_sample_filter( + const index_t* const* inds_ptrs, const filter_t next_filter) : inds_ptrs_{inds_ptrs}, next_filter_{next_filter} { } diff --git a/dependencies.yaml b/dependencies.yaml index cc9afe3eed..fca48befa0 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -325,6 +325,16 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev + specific: + - output_types: conda + matrices: + - matrix: + cuda: "13.*" + packages: + - libnvjitlink-dev + - matrix: + cuda: "12.*" + packages: cuda_wheels: specific: - output_types: [requirements, pyproject] diff --git a/python/cuvs_bench/cuvs_bench/run/runners.py b/python/cuvs_bench/cuvs_bench/run/runners.py index 39fa92269b..0377ea7f45 100644 --- a/python/cuvs_bench/cuvs_bench/run/runners.py +++ b/python/cuvs_bench/cuvs_bench/run/runners.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 # @@ -137,7 +137,7 @@ def cuvs_bench_cpp( "--benchmark_counters_tabular=true", f"--override_kv=k:{k}", f"--override_kv=n_queries:{batch_size}", - "--benchmark_min_warmup_time=1", + "--benchmark_min_warmup_time=4", "--benchmark_out_format=json", f"--mode={mode}", f"--benchmark_out={os.path.join(search_folder, search_file)}",