diff --git a/c/parallel/include/cccl/c/unique_by_key.h b/c/parallel/include/cccl/c/unique_by_key.h new file mode 100644 index 00000000000..632ceffe584 --- /dev/null +++ b/c/parallel/include/cccl/c/unique_by_key.h @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA Core Compute Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifndef CCCL_C_EXPERIMENTAL +# error "C exposure is experimental and subject to change. Define CCCL_C_EXPERIMENTAL to acknowledge this notice." +#endif // !CCCL_C_EXPERIMENTAL + +#include + +#include +#include + +CCCL_C_EXTERN_C_BEGIN + +typedef struct cccl_device_unique_by_key_build_result_t +{ + int cc; + void* cubin; + size_t cubin_size; + CUlibrary library; + CUkernel compact_init_kernel; + CUkernel sweep_kernel; + size_t description_bytes_per_tile; + size_t payload_bytes_per_tile; +} cccl_device_unique_by_key_build_result_t; + +CCCL_C_API CUresult cccl_device_unique_by_key_build( + cccl_device_unique_by_key_build_result_t* build, + cccl_iterator_t d_keys_in, + cccl_iterator_t d_values_in, + cccl_iterator_t d_keys_out, + cccl_iterator_t d_values_out, + cccl_iterator_t d_num_selected_out, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path) noexcept; + +CCCL_C_API CUresult cccl_device_unique_by_key( + cccl_device_unique_by_key_build_result_t build, + void* d_temp_storage, + size_t* temp_storage_bytes, + cccl_iterator_t d_keys_in, + cccl_iterator_t d_values_in, + cccl_iterator_t d_keys_out, + cccl_iterator_t d_values_out, + 
cccl_iterator_t d_num_selected_out, + cccl_op_t op, + unsigned long long num_items, + CUstream stream) noexcept; + +CCCL_C_API CUresult cccl_device_unique_by_key_cleanup(cccl_device_unique_by_key_build_result_t* bld_ptr) noexcept; + +CCCL_C_EXTERN_C_END diff --git a/c/parallel/src/kernels/iterators.cpp b/c/parallel/src/kernels/iterators.cpp index 4ba95cca973..44e8f577bff 100644 --- a/c/parallel/src/kernels/iterators.cpp +++ b/c/parallel/src/kernels/iterators.cpp @@ -97,28 +97,28 @@ std::string make_kernel_output_iterator( const std::string iter_def = std::format(R"XXX( extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T x); extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset); -struct __align__(OP_ALIGNMENT) output_iterator_state_t {{ +struct __align__(OP_ALIGNMENT) {0}_state_t {{ char data[OP_SIZE]; }}; -struct output_iterator_proxy_t {{ - __device__ output_iterator_proxy_t operator=(VALUE_T x) {{ +struct {0}_proxy_t {{ + __device__ {0}_proxy_t operator=(VALUE_T x) {{ DEREF(&state, x); return *this; }} - output_iterator_state_t state; + {0}_state_t state; }}; struct {0} {{ using iterator_category = cuda::std::random_access_iterator_tag; using difference_type = DIFF_T; using value_type = void; - using pointer = output_iterator_proxy_t*; - using reference = output_iterator_proxy_t; - __device__ output_iterator_proxy_t operator*() const {{ return {{state}}; }} + using pointer = {0}_proxy_t*; + using reference = {0}_proxy_t; + __device__ {0}_proxy_t operator*() const {{ return {{state}}; }} __device__ {0}& operator+=(difference_type diff) {{ ADVANCE(&state, diff); return *this; }} - __device__ output_iterator_proxy_t operator[](difference_type diff) const {{ + __device__ {0}_proxy_t operator[](difference_type diff) const {{ {0} result = *this; result += diff; return {{ result.state }}; @@ -128,7 +128,7 @@ struct {0} {{ result += diff; return result; }} - output_iterator_state_t state; + {0}_state_t state; }}; )XXX", iterator_name); diff --git 
a/c/parallel/src/scan.cu b/c/parallel/src/scan.cu index 3a0c1ad8dcc..cc704d8d373 100644 --- a/c/parallel/src/scan.cu +++ b/c/parallel/src/scan.cu @@ -8,7 +8,6 @@ // //===----------------------------------------------------------------------===// -#include #include #include #include @@ -20,7 +19,6 @@ #include #include #include -#include #include #include @@ -30,6 +28,7 @@ #include "util/context.h" #include "util/errors.h" #include "util/indirect_arg.h" +#include "util/scan_tile_state.h" #include "util/types.h" #include #include @@ -172,74 +171,6 @@ std::string get_scan_kernel_name(cccl_iterator_t input_it, cccl_iterator_t outpu init_t); // 9 } -// TODO: NVRTC doesn't currently support extracting basic type -// information (e.g., type sizes and alignments) from compiled -// LTO-IR. So we separately compile a small PTX file that defines the -// necessary types and constants and grep it for the required -// information. If/when NVRTC adds these features, we can remove this -// extra compilation step and get the information directly from the -// LTO-IR. -static constexpr auto ptx_u64_assignment_regex = R"(\.visible\s+\.global\s+\.align\s+\d+\s+\.u64\s+{}\s*=\s*(\d+);)"; - -std::optional find_size_t(char* ptx, std::string_view name) -{ - std::regex regex(std::format(ptx_u64_assignment_regex, name)); - std::cmatch match; - if (std::regex_search(ptx, match, regex)) - { - auto result = std::stoi(match[1].str()); - return result; - } - return std::nullopt; -} - -struct scan_tile_state -{ - // scan_tile_state implements the same (host) interface as cub::ScanTileStateT, except - // that it accepts the acummulator type as a runtime parameter rather than being - // templated on it. - // - // Both specializations ScanTileStateT and ScanTileStateT - where the - // bool parameter indicates whether `T` is primitive - are combined into a single type. 
- - void* d_tile_status; // d_tile_descriptors - void* d_tile_partial; - void* d_tile_inclusive; - - size_t description_bytes_per_tile; - size_t payload_bytes_per_tile; - - scan_tile_state(size_t description_bytes_per_tile, size_t payload_bytes_per_tile) - : d_tile_status(nullptr) - , d_tile_partial(nullptr) - , d_tile_inclusive(nullptr) - , description_bytes_per_tile(description_bytes_per_tile) - , payload_bytes_per_tile(payload_bytes_per_tile) - {} - - cudaError_t Init(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes) - { - void* allocations[3] = {}; - auto status = cub::detail::tile_state_init( - description_bytes_per_tile, payload_bytes_per_tile, num_tiles, d_temp_storage, temp_storage_bytes, allocations); - if (status != cudaSuccess) - { - return status; - } - d_tile_status = allocations[0]; - d_tile_partial = allocations[1]; - d_tile_inclusive = allocations[2]; - return cudaSuccess; - } - - cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes) const - { - temp_storage_bytes = - cub::detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles); - return cudaSuccess; - } -}; - template struct dynamic_scan_policy_t { @@ -392,43 +323,8 @@ struct device_scan_policy {{ check(cuLibraryGetKernel(&build_ptr->init_kernel, build_ptr->library, init_kernel_lowered_name.c_str())); check(cuLibraryGetKernel(&build_ptr->scan_kernel, build_ptr->library, scan_kernel_lowered_name.c_str())); - constexpr size_t num_ptx_args = 7; - const char* ptx_args[num_ptx_args] = { - arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto"}; - constexpr size_t num_ptx_lto_args = 3; - const char* ptx_lopts[num_ptx_lto_args] = {"-lto", arch.c_str(), "-ptx"}; - - constexpr std::string_view ptx_src_template = R"XXX( -#include -#include -struct __align__({1}) storage_t {{ - char data[{0}]; -}}; -__device__ size_t description_bytes_per_tile = cub::ScanTileState<{2}>::description_bytes_per_tile; -__device__ 
size_t payload_bytes_per_tile = cub::ScanTileState<{2}>::payload_bytes_per_tile; -)XXX"; - - const std::string ptx_src = std::format(ptx_src_template, accum_t.size, accum_t.alignment, accum_cpp); - auto compile_result = - make_nvrtc_command_list() - .add_program(nvrtc_translation_unit{ptx_src.c_str(), "tile_state_info"}) - .compile_program({ptx_args, num_ptx_args}) - .cleanup_program() - .finalize_program(num_ptx_lto_args, ptx_lopts); - auto ptx_code = compile_result.data.get(); - - size_t description_bytes_per_tile; - size_t payload_bytes_per_tile; - auto maybe_description_bytes_per_tile = scan::find_size_t(ptx_code, "description_bytes_per_tile"); - if (maybe_description_bytes_per_tile) - { - description_bytes_per_tile = maybe_description_bytes_per_tile.value(); - } - else - { - throw std::runtime_error("Failed to find description_bytes_per_tile in PTX"); - } - payload_bytes_per_tile = scan::find_size_t(ptx_code, "payload_bytes_per_tile").value_or(0); + auto [description_bytes_per_tile, + payload_bytes_per_tile] = get_tile_state_bytes_per_tile(accum_t, accum_cpp, args, num_args, arch); build_ptr->cc = cc; build_ptr->cubin = (void*) result.data.release(); diff --git a/c/parallel/src/unique_by_key.cu b/c/parallel/src/unique_by_key.cu new file mode 100644 index 00000000000..10b9470371b --- /dev/null +++ b/c/parallel/src/unique_by_key.cu @@ -0,0 +1,538 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "kernels/iterators.h" +#include "kernels/operators.h" +#include "util/context.h" +#include "util/indirect_arg.h" +#include "util/scan_tile_state.h" +#include "util/types.h" +#include +#include +#include + +struct op_wrapper; +struct device_unique_by_key_policy; +using OffsetT = int64_t; +static_assert(std::is_same_v, OffsetT>, "OffsetT must be int64"); + +struct input_keys_iterator_state_t; +struct input_values_iterator_state_t; +struct output_keys_iterator_t; +struct output_values_iterator_t; +struct output_num_selected_iterator_t; + +struct num_selected_storage_t; + +namespace unique_by_key +{ +struct unique_by_key_runtime_tuning_policy +{ + int block_size; + int items_per_thread; + cub::BlockLoadAlgorithm load_algorithm; + cub::CacheLoadModifier load_modifier; + + unique_by_key_runtime_tuning_policy UniqueByKey() const + { + return *this; + } + + using UniqueByKeyPolicyT = unique_by_key_runtime_tuning_policy; +}; + +struct unique_by_key_tuning_t +{ + int cc; + int block_size; + int items_per_thread; +}; + +template +Tuning find_tuning(int cc, const Tuning (&tunings)[N]) +{ + for (const Tuning& tuning : tunings) + { + if (cc >= tuning.cc) + { + return tuning; + } + } + + return tunings[N - 1]; +} + +unique_by_key_runtime_tuning_policy get_policy(int /*cc*/) +{ + // TODO: we should update this once we figure out a way to reuse + // tuning logic from C++. 
Alternately, we should implement + // something better than a hardcoded default: + return {128, 10, cub::BLOCK_LOAD_DIRECT, cub::LOAD_DEFAULT}; +} + +enum class unique_by_key_iterator_t +{ + input_keys = 0, + input_values = 1, + output_keys = 2, + output_values = 3, + num_selected = 4 +}; + +template +std::string get_iterator_name(cccl_iterator_t iterator, unique_by_key_iterator_t which_iterator) +{ + if (iterator.type == cccl_iterator_kind_t::CCCL_POINTER) + { + return cccl_type_enum_to_name(iterator.value_type.type, true); + } + else + { + std::string iterator_t; + switch (which_iterator) + { + case unique_by_key_iterator_t::input_keys: { + check(nvrtcGetTypeName(&iterator_t)); + break; + } + case unique_by_key_iterator_t::input_values: { + check(nvrtcGetTypeName(&iterator_t)); + break; + } + case unique_by_key_iterator_t::output_keys: { + check(nvrtcGetTypeName(&iterator_t)); + break; + } + case unique_by_key_iterator_t::output_values: { + check(nvrtcGetTypeName(&iterator_t)); + break; + } + case unique_by_key_iterator_t::num_selected: { + check(nvrtcGetTypeName(&iterator_t)); + break; + } + } + + return iterator_t; + } +} + +std::string get_compact_init_kernel_name(cccl_iterator_t output_num_selected_it) +{ + std::string offset_t; + check(nvrtcGetTypeName(&offset_t)); + + const std::string num_selected_iterator_t = + get_iterator_name(output_num_selected_it, unique_by_key_iterator_t::num_selected); + + return std::format( + "cub::detail::scan::DeviceCompactInitKernel, {1}>", offset_t, num_selected_iterator_t); +} + +std::string get_sweep_kernel_name( + cccl_iterator_t input_keys_it, + cccl_iterator_t input_values_it, + cccl_iterator_t output_keys_it, + cccl_iterator_t output_values_it, + cccl_iterator_t output_num_selected_it) +{ + std::string chained_policy_t; + check(nvrtcGetTypeName(&chained_policy_t)); + + const std::string input_keys_iterator_t = get_iterator_name(input_keys_it, unique_by_key_iterator_t::input_keys); + const std::string 
input_values_iterator_t = + get_iterator_name(input_values_it, unique_by_key_iterator_t::input_values); + const std::string output_keys_iterator_t = get_iterator_name(output_keys_it, unique_by_key_iterator_t::output_keys); + const std::string output_values_iterator_t = + get_iterator_name(output_values_it, unique_by_key_iterator_t::output_values); + const std::string output_num_selected_iterator_t = + get_iterator_name(output_num_selected_it, unique_by_key_iterator_t::num_selected); + + std::string offset_t; + check(nvrtcGetTypeName(&offset_t)); + + auto tile_state_t = std::format("cub::ScanTileState<{0}>", offset_t); + + std::string equality_op_t; + check(nvrtcGetTypeName(&equality_op_t)); + + return std::format( + "cub::detail::unique_by_key::DeviceUniqueByKeySweepKernel<{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}>", + chained_policy_t, + input_keys_iterator_t, + input_values_iterator_t, + output_keys_iterator_t, + output_values_iterator_t, + output_num_selected_iterator_t, + tile_state_t, + equality_op_t, + offset_t); +} + +template +struct dynamic_unique_by_key_policy_t +{ + using MaxPolicy = dynamic_unique_by_key_policy_t; + + template + cudaError_t Invoke(int device_ptx_version, F& op) + { + return op.template Invoke(GetPolicy(device_ptx_version)); + } +}; + +struct unique_by_key_kernel_source +{ + cccl_device_unique_by_key_build_result_t& build; + + CUkernel UniqueByKeySweepKernel() const + { + return build.sweep_kernel; + } + + CUkernel CompactInitKernel() const + { + return build.compact_init_kernel; + } + + scan_tile_state TileState() + { + return {build.description_bytes_per_tile, build.payload_bytes_per_tile}; + } +}; + +struct dynamic_vsmem_helper_t +{ + template + static int BlockThreads(PolicyT policy) + { + return policy.block_size; + } + + template + static int ItemsPerThread(PolicyT policy) + { + return policy.items_per_thread; + } + + template + static ::cuda::std::size_t VSMemPerBlock(PolicyT /*policy*/) + { + return 0; + } + +private: + 
unique_by_key_runtime_tuning_policy fallback_policy = {64, 1, cub::BLOCK_LOAD_DIRECT, cub::LOAD_DEFAULT}; + bool uses_fallback_policy() const + { + return false; + } +}; + +} // namespace unique_by_key + +CUresult cccl_device_unique_by_key_build( + cccl_device_unique_by_key_build_result_t* build, + cccl_iterator_t input_keys_it, + cccl_iterator_t input_values_it, + cccl_iterator_t output_keys_it, + cccl_iterator_t output_values_it, + cccl_iterator_t output_num_selected_it, + cccl_op_t op, + int cc_major, + int cc_minor, + const char* cub_path, + const char* thrust_path, + const char* libcudacxx_path, + const char* ctk_path) noexcept +{ + CUresult error = CUDA_SUCCESS; + + try + { + const char* name = "test"; + + const int cc = cc_major * 10 + cc_minor; + const auto policy = unique_by_key::get_policy(cc); + + const auto input_keys_it_value_t = cccl_type_enum_to_name(input_keys_it.value_type.type); + const auto input_values_it_value_t = cccl_type_enum_to_name(input_values_it.value_type.type); + const auto output_keys_it_value_t = cccl_type_enum_to_name(output_keys_it.value_type.type); + const auto output_values_it_value_t = cccl_type_enum_to_name(output_values_it.value_type.type); + const auto output_num_selected_it_value_t = cccl_type_enum_to_name(output_num_selected_it.value_type.type); + const auto offset_cpp = cccl_type_enum_to_name(cccl_type_enum::CCCL_INT64); + const cccl_type_info offset_t{sizeof(int64_t), alignof(int64_t), cccl_type_enum::CCCL_INT64}; + + const std::string input_keys_iterator_src = make_kernel_input_iterator( + offset_cpp, + get_iterator_name(input_keys_it, unique_by_key::unique_by_key_iterator_t::input_keys), + input_keys_it_value_t, + input_keys_it); + const std::string input_values_iterator_src = make_kernel_input_iterator( + offset_cpp, + get_iterator_name(input_values_it, unique_by_key::unique_by_key_iterator_t::input_values), + input_values_it_value_t, + input_values_it); + const std::string output_keys_iterator_src = 
make_kernel_output_iterator( + offset_cpp, + get_iterator_name(output_keys_it, unique_by_key::unique_by_key_iterator_t::output_keys), + output_keys_it_value_t, + output_keys_it); + const std::string output_values_iterator_src = make_kernel_output_iterator( + offset_cpp, + get_iterator_name(output_values_it, unique_by_key::unique_by_key_iterator_t::output_values), + output_values_it_value_t, + output_values_it); + const std::string output_num_selected_iterator_src = make_kernel_output_iterator( + offset_cpp, + get_iterator_name(output_num_selected_it, unique_by_key::unique_by_key_iterator_t::num_selected), + output_num_selected_it_value_t, + output_num_selected_it); + + const std::string op_src = make_kernel_user_comparison_operator(input_keys_it_value_t, op); + + constexpr std::string_view src_template = R"XXX( +#include +#include +#include +struct __align__({1}) storage_t {{ + char data[{0}]; +}}; +struct __align__({3}) items_storage_t {{ + char data[{2}]; +}}; +struct __align__({5}) num_out_storage_t {{ + char data[{4}]; +}}; +{8} +{9} +{10} +{11} +{12} +struct agent_policy_t {{ + static constexpr int ITEMS_PER_THREAD = {7}; + static constexpr int BLOCK_THREADS = {6}; + static constexpr cub::BlockLoadAlgorithm LOAD_ALGORITHM = cub::BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr cub::CacheLoadModifier LOAD_MODIFIER = cub::LOAD_LDG; + static constexpr cub::BlockScanAlgorithm SCAN_ALGORITHM = cub::BLOCK_SCAN_WARP_SCANS; + struct detail {{ + using delay_constructor_t = cub::detail::default_delay_constructor_t; + }}; +}}; +struct device_unique_by_key_policy {{ + struct ActivePolicy {{ + using UniqueByKeyPolicyT = agent_policy_t; + }}; +}}; +{13} +)XXX"; + + const std::string src = std::format( + src_template, + input_keys_it.value_type.size, // 0 + input_keys_it.value_type.alignment, // 1 + input_values_it.value_type.size, // 2 + input_values_it.value_type.alignment, // 3 + output_values_it.value_type.size, // 4 + output_values_it.value_type.alignment, // 5 + 
policy.block_size, // 6 + policy.items_per_thread, // 7 + input_keys_iterator_src, // 8 + input_values_iterator_src, // 9 + output_keys_iterator_src, // 10 + output_values_iterator_src, // 11 + output_num_selected_iterator_src, // 12 + op_src); // 13 + +#if false // CCCL_DEBUGGING_SWITCH + fflush(stderr); + printf("\nCODE4NVRTC BEGIN\n%sCODE4NVRTC END\n", src.c_str()); + fflush(stdout); +#endif + + std::string compact_init_kernel_name = unique_by_key::get_compact_init_kernel_name(output_num_selected_it); + std::string sweep_kernel_name = unique_by_key::get_sweep_kernel_name( + input_keys_it, input_values_it, output_keys_it, output_values_it, output_num_selected_it); + std::string compact_init_kernel_lowered_name; + std::string sweep_kernel_lowered_name; + + const std::string arch = std::format("-arch=sm_{0}{1}", cc_major, cc_minor); + + constexpr size_t num_args = 7; + const char* args[num_args] = {arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto"}; + + constexpr size_t num_lto_args = 2; + const char* lopts[num_lto_args] = {"-lto", arch.c_str()}; + + // Collect all LTO-IRs to be linked. 
+ nvrtc_ltoir_list ltoir_list; + auto ltoir_list_append = [&ltoir_list](nvrtc_ltoir lto) { + if (lto.ltsz) + { + ltoir_list.push_back(std::move(lto)); + } + }; + ltoir_list_append({op.ltoir, op.ltoir_size}); + if (cccl_iterator_kind_t::CCCL_ITERATOR == input_keys_it.type) + { + ltoir_list_append({input_keys_it.advance.ltoir, input_keys_it.advance.ltoir_size}); + ltoir_list_append({input_keys_it.dereference.ltoir, input_keys_it.dereference.ltoir_size}); + } + if (cccl_iterator_kind_t::CCCL_ITERATOR == input_values_it.type) + { + ltoir_list_append({input_values_it.advance.ltoir, input_values_it.advance.ltoir_size}); + ltoir_list_append({input_values_it.dereference.ltoir, input_values_it.dereference.ltoir_size}); + } + if (cccl_iterator_kind_t::CCCL_ITERATOR == output_keys_it.type) + { + ltoir_list_append({output_keys_it.advance.ltoir, output_keys_it.advance.ltoir_size}); + ltoir_list_append({output_keys_it.dereference.ltoir, output_keys_it.dereference.ltoir_size}); + } + if (cccl_iterator_kind_t::CCCL_ITERATOR == output_values_it.type) + { + ltoir_list_append({output_values_it.advance.ltoir, output_values_it.advance.ltoir_size}); + ltoir_list_append({output_values_it.dereference.ltoir, output_values_it.dereference.ltoir_size}); + } + if (cccl_iterator_kind_t::CCCL_ITERATOR == output_num_selected_it.type) + { + ltoir_list_append({output_num_selected_it.advance.ltoir, output_num_selected_it.advance.ltoir_size}); + ltoir_list_append({output_num_selected_it.dereference.ltoir, output_num_selected_it.dereference.ltoir_size}); + } + + nvrtc_link_result result = + make_nvrtc_command_list() + .add_program(nvrtc_translation_unit{src.c_str(), name}) + .add_expression({compact_init_kernel_name}) + .add_expression({sweep_kernel_name}) + .compile_program({args, num_args}) + .get_name({compact_init_kernel_name, compact_init_kernel_lowered_name}) + .get_name({sweep_kernel_name, sweep_kernel_lowered_name}) + .cleanup_program() + .add_link_list(ltoir_list) + .finalize_program(num_lto_args, lopts); + 
check(cuLibraryLoadData(&build->library, result.data.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + check(cuLibraryGetKernel(&build->compact_init_kernel, build->library, compact_init_kernel_lowered_name.c_str())); + check(cuLibraryGetKernel(&build->sweep_kernel, build->library, sweep_kernel_lowered_name.c_str())); + + auto [description_bytes_per_tile, + payload_bytes_per_tile] = get_tile_state_bytes_per_tile(offset_t, offset_cpp, args, num_args, arch); + + build->cc = cc; + build->cubin = (void*) result.data.release(); + build->cubin_size = result.size; + build->description_bytes_per_tile = description_bytes_per_tile; + build->payload_bytes_per_tile = payload_bytes_per_tile; + } + catch (const std::exception& exc) + { + fflush(stderr); + printf("\nEXCEPTION in cccl_device_unique_by_key_build(): %s\n", exc.what()); + fflush(stdout); + error = CUDA_ERROR_UNKNOWN; + } + + return error; +} + +CUresult cccl_device_unique_by_key( + cccl_device_unique_by_key_build_result_t build, + void* d_temp_storage, + size_t* temp_storage_bytes, + cccl_iterator_t d_keys_in, + cccl_iterator_t d_values_in, + cccl_iterator_t d_keys_out, + cccl_iterator_t d_values_out, + cccl_iterator_t d_num_selected_out, + cccl_op_t op, + unsigned long long num_items, + CUstream stream) noexcept +{ + CUresult error = CUDA_SUCCESS; + bool pushed = false; + try + { + pushed = try_push_context(); + + CUdevice cu_device; + check(cuCtxGetDevice(&cu_device)); + + cub::DispatchUniqueByKey< + indirect_arg_t, + indirect_arg_t, + indirect_arg_t, + indirect_arg_t, + indirect_arg_t, + indirect_arg_t, + ::cuda::std::size_t, + unique_by_key::dynamic_unique_by_key_policy_t<&unique_by_key::get_policy>, + unique_by_key::unique_by_key_kernel_source, + cub::detail::CudaDriverLauncherFactory, + unique_by_key::dynamic_vsmem_helper_t, + indirect_arg_t, + indirect_arg_t>::Dispatch(d_temp_storage, + *temp_storage_bytes, + d_keys_in, + d_values_in, + d_keys_out, + d_values_out, + d_num_selected_out, + op, + num_items, + stream, 
+ {build}, + cub::detail::CudaDriverLauncherFactory{cu_device, build.cc}, + {}); + } + catch (const std::exception& exc) + { + fflush(stderr); + printf("\nEXCEPTION in cccl_device_unique_by_key(): %s\n", exc.what()); + fflush(stdout); + error = CUDA_ERROR_UNKNOWN; + } + + if (pushed) + { + CUcontext dummy; + cuCtxPopCurrent(&dummy); + } + + return error; +} + +CUresult cccl_device_unique_by_key_cleanup(cccl_device_unique_by_key_build_result_t* build_ptr) noexcept +{ + try + { + if (build_ptr == nullptr) + { + return CUDA_ERROR_INVALID_VALUE; + } + + std::unique_ptr cubin(reinterpret_cast(build_ptr->cubin)); + check(cuLibraryUnload(build_ptr->library)); + } + catch (const std::exception& exc) + { + fflush(stderr); + printf("\nEXCEPTION in cccl_device_unique_by_key_cleanup(): %s\n", exc.what()); + fflush(stdout); + return CUDA_ERROR_UNKNOWN; + } + + return CUDA_SUCCESS; +} diff --git a/c/parallel/src/util/scan_tile_state.cu b/c/parallel/src/util/scan_tile_state.cu new file mode 100644 index 00000000000..2dadab05d34 --- /dev/null +++ b/c/parallel/src/util/scan_tile_state.cu @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "scan_tile_state.h" + +// TODO: NVRTC doesn't currently support extracting basic type +// information (e.g., type sizes and alignments) from compiled +// LTO-IR. So we separately compile a small PTX file that defines the +// necessary types and constants and grep it for the required +// information. 
If/when NVRTC adds these features, we can remove this +// extra compilation step and get the information directly from the +// LTO-IR. +static constexpr auto ptx_u64_assignment_regex = R"(\.visible\s+\.global\s+\.align\s+\d+\s+\.u64\s+{}\s*=\s*(\d+);)"; + +std::optional find_size_t(char* ptx, std::string_view name) +{ + std::regex regex(std::format(ptx_u64_assignment_regex, name)); + std::cmatch match; + if (std::regex_search(ptx, match, regex)) + { + auto result = std::stoi(match[1].str()); + return result; + } + return std::nullopt; +} + +std::pair get_tile_state_bytes_per_tile( + cccl_type_info accum_t, + const std::string& accum_cpp, + const char** ptx_args, + size_t num_ptx_args, + const std::string& arch) +{ + constexpr size_t num_ptx_lto_args = 3; + const char* ptx_lopts[num_ptx_lto_args] = {"-lto", arch.c_str(), "-ptx"}; + + constexpr std::string_view ptx_src_template = R"XXX( + #include + #include + struct __align__({1}) storage_t {{ + char data[{0}]; + }}; + __device__ size_t description_bytes_per_tile = cub::ScanTileState<{2}>::description_bytes_per_tile; + __device__ size_t payload_bytes_per_tile = cub::ScanTileState<{2}>::payload_bytes_per_tile; + )XXX"; + + const std::string ptx_src = std::format(ptx_src_template, accum_t.size, accum_t.alignment, accum_cpp); + auto compile_result = + make_nvrtc_command_list() + .add_program(nvrtc_translation_unit{ptx_src.c_str(), "tile_state_info"}) + .compile_program({ptx_args, num_ptx_args}) + .cleanup_program() + .finalize_program(num_ptx_lto_args, ptx_lopts); + auto ptx_code = compile_result.data.get(); + + size_t description_bytes_per_tile; + size_t payload_bytes_per_tile; + auto maybe_description_bytes_per_tile = find_size_t(ptx_code, "description_bytes_per_tile"); + if (maybe_description_bytes_per_tile) + { + description_bytes_per_tile = maybe_description_bytes_per_tile.value(); + } + else + { + throw std::runtime_error("Failed to find description_bytes_per_tile in PTX"); + } + payload_bytes_per_tile = 
find_size_t(ptx_code, "payload_bytes_per_tile").value_or(0); + + return {description_bytes_per_tile, payload_bytes_per_tile}; +} diff --git a/c/parallel/src/util/scan_tile_state.h b/c/parallel/src/util/scan_tile_state.h new file mode 100644 index 00000000000..1b8332f197b --- /dev/null +++ b/c/parallel/src/util/scan_tile_state.h @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "cccl/c/types.h" +#include + +struct scan_tile_state +{ + // scan_tile_state implements the same (host) interface as cub::ScanTileStateT, except + // that it accepts the acummulator type as a runtime parameter rather than being + // templated on it. + // + // Both specializations ScanTileStateT and ScanTileStateT - where the + // bool parameter indicates whether `T` is primitive - are combined into a single type. 
+ + void* d_tile_status; // d_tile_descriptors + void* d_tile_partial; + void* d_tile_inclusive; + + size_t description_bytes_per_tile; + size_t payload_bytes_per_tile; + + scan_tile_state(size_t description_bytes_per_tile, size_t payload_bytes_per_tile) + : d_tile_status(nullptr) + , d_tile_partial(nullptr) + , d_tile_inclusive(nullptr) + , description_bytes_per_tile(description_bytes_per_tile) + , payload_bytes_per_tile(payload_bytes_per_tile) + {} + + cudaError_t Init(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes) + { + void* allocations[3] = {}; + auto status = cub::detail::tile_state_init( + description_bytes_per_tile, payload_bytes_per_tile, num_tiles, d_temp_storage, temp_storage_bytes, allocations); + if (status != cudaSuccess) + { + return status; + } + d_tile_status = allocations[0]; + d_tile_partial = allocations[1]; + d_tile_inclusive = allocations[2]; + return cudaSuccess; + } + + cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes) const + { + temp_storage_bytes = + cub::detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles); + return cudaSuccess; + } +}; + +std::pair get_tile_state_bytes_per_tile( + cccl_type_info accum_t, + const std::string& accum_cpp, + const char** ptx_args, + size_t num_ptx_args, + const std::string& arch); diff --git a/c/parallel/test/test_unique_by_key.cpp b/c/parallel/test/test_unique_by_key.cpp new file mode 100644 index 00000000000..673aad26bf1 --- /dev/null +++ b/c/parallel/test/test_unique_by_key.cpp @@ -0,0 +1,371 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
+// +//===----------------------------------------------------------------------===// + +#include + +#include +#include + +#include "test_util.h" +#include +#include +#include + +using key_types = std::tuple; +using item_t = int32_t; + +void unique_by_key( + cccl_iterator_t input_keys, + cccl_iterator_t input_values, + cccl_iterator_t output_keys, + cccl_iterator_t output_values, + cccl_iterator_t output_num_selected, + cccl_op_t op, + unsigned long long num_items) +{ + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + + const int cc_major = deviceProp.major; + const int cc_minor = deviceProp.minor; + + const char* cub_path = TEST_CUB_PATH; + const char* thrust_path = TEST_THRUST_PATH; + const char* libcudacxx_path = TEST_LIBCUDACXX_PATH; + const char* ctk_path = TEST_CTK_PATH; + + cccl_device_unique_by_key_build_result_t build; + REQUIRE( + CUDA_SUCCESS + == cccl_device_unique_by_key_build( + &build, + input_keys, + input_values, + output_keys, + output_values, + output_num_selected, + op, + cc_major, + cc_minor, + cub_path, + thrust_path, + libcudacxx_path, + ctk_path)); + + const std::string sass = inspect_sass(build.cubin, build.cubin_size); + REQUIRE(sass.find("LDL") == std::string::npos); + REQUIRE(sass.find("STL") == std::string::npos); + + size_t temp_storage_bytes = 0; + REQUIRE( + CUDA_SUCCESS + == cccl_device_unique_by_key( + build, + nullptr, + &temp_storage_bytes, + input_keys, + input_values, + output_keys, + output_values, + output_num_selected, + op, + num_items, + 0)); + + pointer_t temp_storage(temp_storage_bytes); + + REQUIRE( + CUDA_SUCCESS + == cccl_device_unique_by_key( + build, + temp_storage.ptr, + &temp_storage_bytes, + input_keys, + input_values, + output_keys, + output_values, + output_num_selected, + op, + num_items, + 0)); + REQUIRE(CUDA_SUCCESS == cccl_device_unique_by_key_cleanup(&build)); +} + +TEMPLATE_LIST_TEST_CASE("DeviceSelect::UniqueByKey can run with empty input", "[unique_by_key]", key_types) +{ + 
constexpr int num_items = 0; + + operation_t op = make_operation("op", get_unique_by_key_op(get_type_info().type)); + std::vector input_keys(num_items); + + pointer_t input_keys_it(input_keys); + pointer_t output_num_selected_it(1); + + unique_by_key(input_keys_it, input_keys_it, input_keys_it, input_keys_it, output_num_selected_it, op, num_items); + + REQUIRE(0 == std::vector(output_num_selected_it)[0]); +} + +TEMPLATE_LIST_TEST_CASE("DeviceSelect::UniqueByKey works", "[unique_by_key]", key_types) +{ + const int num_items = GENERATE_COPY(take(2, random(1, 1000000))); + + operation_t op = make_operation("op", get_unique_by_key_op(get_type_info().type)); + std::vector input_keys = generate(num_items); + std::vector input_values = generate(num_items); + + pointer_t input_keys_it(input_keys); + pointer_t input_values_it(input_values); + pointer_t output_keys_it(num_items); + pointer_t output_values_it(num_items); + pointer_t output_num_selected_it(1); + + unique_by_key(input_keys_it, input_values_it, output_keys_it, output_values_it, output_num_selected_it, op, num_items); + + std::vector> input_pairs; + for (size_t i = 0; i < input_keys.size(); ++i) + { + input_pairs.emplace_back(input_keys[i], input_values[i]); + } + const auto boundary = std::unique(input_pairs.begin(), input_pairs.end(), [](const auto& a, const auto& b) { + return a.first == b.first; + }); + + int num_selected = std::vector(output_num_selected_it)[0]; + + REQUIRE((boundary - input_pairs.begin()) == num_selected); + + input_pairs.resize(num_selected); + + std::vector host_output_keys(output_keys_it); + std::vector host_output_values(output_values_it); + std::vector> output_pairs; + for (int i = 0; i < num_selected; ++i) + { + output_pairs.emplace_back(host_output_keys[i], host_output_values[i]); + } + + REQUIRE(input_pairs == output_pairs); +} + +TEMPLATE_LIST_TEST_CASE("DeviceSelect::UniqueByKey handles none equal", "[device][select_unique_by_key]", key_types) +{ + const int num_items = 250; // to 
ensure that we get none equal for smaller data types + + operation_t op = make_operation("op", get_unique_by_key_op(get_type_info().type)); + std::vector input_keys = make_shuffled_sequence(num_items); + std::vector input_values = generate(num_items); + + pointer_t input_keys_it(input_keys); + pointer_t input_values_it(input_values); + pointer_t output_keys_it(num_items); + pointer_t output_values_it(num_items); + pointer_t output_num_selected_it(1); + + unique_by_key(input_keys_it, input_values_it, output_keys_it, output_values_it, output_num_selected_it, op, num_items); + + REQUIRE(num_items == std::vector(output_num_selected_it)[0]); + REQUIRE(input_keys == std::vector(output_keys_it)); + REQUIRE(input_values == std::vector(output_values_it)); +} + +TEMPLATE_LIST_TEST_CASE("DeviceSelect::UniqueByKey handles all equal", "[device][select_unique_by_key]", key_types) +{ + const int num_items = GENERATE_COPY(take(2, random(1, 1000000))); + + operation_t op = make_operation("op", get_unique_by_key_op(get_type_info().type)); + std::vector input_keys(num_items, static_cast(1)); + std::vector input_values = generate(num_items); + + pointer_t input_keys_it(input_keys); + pointer_t input_values_it(input_values); + pointer_t output_keys_it(1); + pointer_t output_values_it(1); + pointer_t output_num_selected_it(1); + + unique_by_key(input_keys_it, input_values_it, output_keys_it, output_values_it, output_num_selected_it, op, num_items); + + REQUIRE(1 == std::vector(output_num_selected_it)[0]); + REQUIRE(input_keys[0] == std::vector(output_keys_it)[0]); + REQUIRE(input_values[0] == std::vector(output_values_it)[0]); +} + +struct key_pair +{ + short a; + size_t b; + + bool operator==(const key_pair& other) const + { + return a == other.a && b == other.b; + } +}; + +TEST_CASE("DeviceSelect::UniqueByKey works with custom types", "[device][select_unique_by_key]") +{ + const int num_items = GENERATE_COPY(take(2, random(1, 1000000))); + + operation_t op = make_operation( + "op", + 
"struct key_pair { short a; size_t b; };\n" + "extern \"C\" __device__ bool op(key_pair lhs, key_pair rhs) {\n" + " return lhs.a == rhs.a && lhs.b == rhs.b;\n" + "}"); + const std::vector a = generate(num_items); + const std::vector b = generate(num_items); + std::vector input_keys(num_items); + std::vector input_values = generate(num_items); + for (int i = 0; i < num_items; ++i) + { + input_keys[i] = key_pair{a[i], b[i]}; + } + + pointer_t input_keys_it(input_keys); + pointer_t input_values_it(input_values); + pointer_t output_keys_it(num_items); + pointer_t output_values_it(num_items); + pointer_t output_num_selected_it(1); + + unique_by_key(input_keys_it, input_values_it, output_keys_it, output_values_it, output_num_selected_it, op, num_items); + + std::vector> input_pairs; + for (size_t i = 0; i < input_keys.size(); ++i) + { + input_pairs.emplace_back(input_keys[i], input_values[i]); + } + + const auto boundary = std::unique(input_pairs.begin(), input_pairs.end(), [](const auto& a, const auto& b) { + return a.first == b.first; + }); + + int num_selected = std::vector(output_num_selected_it)[0]; + + REQUIRE((boundary - input_pairs.begin()) == num_selected); + + input_pairs.resize(num_selected); + + std::vector host_output_keys(output_keys_it); + std::vector host_output_values(output_values_it); + std::vector> output_pairs; + for (int i = 0; i < num_selected; ++i) + { + output_pairs.emplace_back(host_output_keys[i], host_output_values[i]); + } + + REQUIRE(input_pairs == output_pairs); +} + +struct random_access_iterator_state_t +{ + int* d_input; +}; + +struct value_random_access_iterator_state_t +{ + int* d_input; +}; + +TEST_CASE("DeviceMergeSort::SortPairs works with input and output iterators", "[merge_sort]") +{ + using TestType = int; + + const int num_items = GENERATE_COPY(take(2, random(1, 1000000))); + + operation_t op = make_operation("op", get_unique_by_key_op(get_type_info().type)); + iterator_t input_keys_it = + make_iterator( + "struct 
random_access_iterator_state_t { int* d_input; };\n", + {"key_advance", + "extern \"C\" __device__ void key_advance(random_access_iterator_state_t* state, unsigned long long offset) {\n" + " state->d_input += offset;\n" + "}"}, + {"key_dereference", + "extern \"C\" __device__ int key_dereference(random_access_iterator_state_t* state) {\n" + " return *state->d_input;\n" + "}"}); + iterator_t input_values_it = + make_iterator( + "struct value_random_access_iterator_state_t { int* d_input; };\n", + {"value_advance", + "extern \"C\" __device__ void value_advance(value_random_access_iterator_state_t* state, unsigned long long " + "offset) {\n" + " state->d_input += offset;\n" + "}"}, + {"value_dereference", + "extern \"C\" __device__ int value_dereference(value_random_access_iterator_state_t* state) {\n" + " return *state->d_input;\n" + "}"}); + iterator_t output_keys_it = + make_iterator( + "struct random_access_iterator_state_t { int* d_input; };\n", + {"key_advance_out", + "extern \"C\" __device__ void key_advance_out(random_access_iterator_state_t* state, unsigned long long offset) " + "{\n" + " state->d_input += offset;\n" + "}"}, + {"key_dereference_out", + "extern \"C\" __device__ void key_dereference_out(random_access_iterator_state_t* state, int x) {\n" + " *state->d_input = x;\n" + "}"}); + iterator_t output_values_it = + make_iterator( + "struct value_random_access_iterator_state_t { int* d_input; };\n", + {"value_advance_out", + "extern \"C\" __device__ void value_advance_out(value_random_access_iterator_state_t* state, unsigned long long " + "offset) {\n" + " state->d_input += offset;\n" + "}"}, + {"value_dereference_out", + "extern \"C\" __device__ void value_dereference_out(value_random_access_iterator_state_t* state, int x) {\n" + " *state->d_input = x;\n" + "}"}); + + std::vector input_keys = generate(num_items); + std::vector input_values = generate(num_items); + + pointer_t input_keys_ptr(input_keys); + input_keys_it.state.d_input = 
input_keys_ptr.ptr; + pointer_t input_values_ptr(input_values); + input_values_it.state.d_input = input_values_ptr.ptr; + + pointer_t output_keys_ptr(num_items); + output_keys_it.state.d_input = output_keys_ptr.ptr; + pointer_t output_values_ptr(num_items); + output_values_it.state.d_input = output_values_ptr.ptr; + + pointer_t output_num_selected_it(1); + + unique_by_key(input_keys_it, input_values_it, output_keys_it, output_values_it, output_num_selected_it, op, num_items); + + std::vector> input_pairs; + for (size_t i = 0; i < input_keys.size(); ++i) + { + input_pairs.emplace_back(input_keys[i], input_values[i]); + } + const auto boundary = std::unique(input_pairs.begin(), input_pairs.end(), [](const auto& a, const auto& b) { + return a.first == b.first; + }); + + int num_selected = std::vector(output_num_selected_it)[0]; + + REQUIRE((boundary - input_pairs.begin()) == num_selected); + + input_pairs.resize(num_selected); + + std::vector host_output_keys(output_keys_ptr); + std::vector host_output_values(output_values_ptr); + std::vector> output_pairs; + for (int i = 0; i < num_selected; ++i) + { + output_pairs.emplace_back(host_output_keys[i], host_output_values[i]); + } + + REQUIRE(input_pairs == output_pairs); +} diff --git a/c/parallel/test/test_util.h b/c/parallel/test/test_util.h index 0f53a10786c..3833c2e298e 100644 --- a/c/parallel/test/test_util.h +++ b/c/parallel/test/test_util.h @@ -278,6 +278,37 @@ static std::string get_merge_sort_op(cccl_type_enum t) return ""; } +static std::string get_unique_by_key_op(cccl_type_enum t) +{ + switch (t) + { + case cccl_type_enum::CCCL_INT8: + return "extern \"C\" __device__ bool op(char lhs, char rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_UINT8: + return "extern \"C\" __device__ bool op(unsigned char lhs, unsigned char rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_INT16: + return "extern \"C\" __device__ bool op(short lhs, short rhs) { return lhs == rhs; }"; + case 
cccl_type_enum::CCCL_UINT16: + return "extern \"C\" __device__ bool op(unsigned short lhs, unsigned short rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_INT32: + return "extern \"C\" __device__ bool op(int lhs, int rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_UINT32: + return "extern \"C\" __device__ bool op(unsigned int lhs, unsigned int rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_INT64: + return "extern \"C\" __device__ bool op(long long lhs, long long rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_UINT64: + return "extern \"C\" __device__ bool op(unsigned long long lhs, unsigned long long rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_FLOAT32: + return "extern \"C\" __device__ bool op(float lhs, float rhs) { return lhs == rhs; }"; + case cccl_type_enum::CCCL_FLOAT64: + return "extern \"C\" __device__ bool op(double lhs, double rhs) { return lhs == rhs; }"; + + default: + throw std::runtime_error("Unsupported type"); + } + return ""; +} + template struct pointer_t { diff --git a/cub/cub/device/dispatch/dispatch_unique_by_key.cuh b/cub/cub/device/dispatch/dispatch_unique_by_key.cuh index 9120319c49d..f1b13e9a8bd 100644 --- a/cub/cub/device/dispatch/dispatch_unique_by_key.cuh +++ b/cub/cub/device/dispatch/dispatch_unique_by_key.cuh @@ -43,6 +43,7 @@ #endif // no system header #include +#include #include #include #include @@ -51,6 +52,43 @@ CUB_NAMESPACE_BEGIN +namespace detail::unique_by_key +{ +template + +struct DeviceUniqueByKeyKernelSource +{ + CUB_DEFINE_KERNEL_GETTER(CompactInitKernel, + detail::scan::DeviceCompactInitKernel); + + CUB_DEFINE_KERNEL_GETTER( + UniqueByKeySweepKernel, + DeviceUniqueByKeySweepKernel< + MaxPolicyT, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + NumSelectedIteratorT, + ScanTileStateT, + EqualityOpT, + OffsetT>); + + CUB_RUNTIME_FUNCTION ScanTileStateT TileState() + { + return ScanTileStateT(); + } +}; +} // namespace 
detail::unique_by_key + /****************************************************************************** * Dispatch ******************************************************************************/ @@ -88,7 +126,21 @@ template < typename EqualityOpT, typename OffsetT, typename PolicyHub = - detail::unique_by_key::policy_hub, detail::it_value_t>> + detail::unique_by_key::policy_hub, detail::it_value_t>, + typename KernelSource = detail::unique_by_key::DeviceUniqueByKeyKernelSource< + typename PolicyHub::MaxPolicy, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + NumSelectedIteratorT, + ScanTileState, + EqualityOpT, + OffsetT>, + typename KernelLauncherFactory = detail::TripleChevronFactory, + typename VSMemHelperT = detail::unique_by_key::VSMemHelper, + typename KeyT = detail::it_value_t, + typename ValueT = detail::it_value_t> struct DispatchUniqueByKey { /****************************************************************************** @@ -100,13 +152,6 @@ struct DispatchUniqueByKey INIT_KERNEL_THREADS = 128, }; - // The input key and value type - using KeyT = detail::it_value_t; - using ValueT = detail::it_value_t; - - // Tile status descriptor interface type - using ScanTileStateT = ScanTileState; - /// Device-accessible allocation of temporary storage. When nullptr, the required allocation size /// is written to `temp_storage_bytes` and no work is done. void* d_temp_storage; @@ -139,6 +184,10 @@ struct DispatchUniqueByKey /// **[optional]** CUDA stream to launch kernels within. Default is stream0. cudaStream_t stream; + KernelSource kernel_source; + + KernelLauncherFactory launcher_factory; + /** * @param[in] d_temp_storage * Device-accessible allocation of temporary storage. 
@@ -184,7 +233,9 @@ struct DispatchUniqueByKey NumSelectedIteratorT d_num_selected_out, EqualityOpT equality_op, OffsetT num_items, - cudaStream_t stream) + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}) : d_temp_storage(d_temp_storage) , temp_storage_bytes(temp_storage_bytes) , d_keys_in(d_keys_in) @@ -195,27 +246,18 @@ struct DispatchUniqueByKey , equality_op(equality_op) , num_items(num_items) , stream(stream) + , kernel_source(kernel_source) + , launcher_factory(launcher_factory) {} /****************************************************************************** * Dispatch entrypoints ******************************************************************************/ - template - CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel) + template + CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t + Invoke(InitKernelT init_kernel, UniqueByKeySweepKernelT sweep_kernel, ActivePolicyT policy = {}) { - using Policy = typename ActivePolicyT::UniqueByKeyPolicyT; - - using VsmemHelperT = cub::detail::vsmem_helper_default_fallback_policy_t< - Policy, - detail::unique_by_key::AgentUniqueByKey, - KeyInputIteratorT, - ValueInputIteratorT, - KeyOutputIteratorT, - ValueOutputIteratorT, - EqualityOpT, - OffsetT>; - cudaError error = cudaSuccess; do { @@ -228,17 +270,42 @@ struct DispatchUniqueByKey } // Number of input tiles - constexpr auto block_threads = VsmemHelperT::agent_policy_t::BLOCK_THREADS; - constexpr auto items_per_thread = VsmemHelperT::agent_policy_t::ITEMS_PER_THREAD; - int tile_size = block_threads * items_per_thread; - int num_tiles = static_cast(::cuda::ceil_div(num_items, tile_size)); - const auto vsmem_size = num_tiles * VsmemHelperT::vsmem_per_block; + const auto block_threads = VSMemHelperT::template BlockThreads< + typename ActivePolicyT::UniqueByKeyPolicyT, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + 
ValueOutputIteratorT, + EqualityOpT, + OffsetT>(policy.UniqueByKey()); + const auto items_per_thread = VSMemHelperT::template ItemsPerThread< + typename ActivePolicyT::UniqueByKeyPolicyT, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + EqualityOpT, + OffsetT>(policy.UniqueByKey()); + int tile_size = block_threads * items_per_thread; + int num_tiles = static_cast(::cuda::ceil_div(num_items, tile_size)); + const auto vsmem_size = + num_tiles + * VSMemHelperT::template VSMemPerBlock< + typename ActivePolicyT::UniqueByKeyPolicyT, + KeyInputIteratorT, + ValueInputIteratorT, + KeyOutputIteratorT, + ValueOutputIteratorT, + EqualityOpT, + OffsetT>(policy.UniqueByKey()); // Specify temporary storage allocation requirements size_t allocation_sizes[2] = {0, vsmem_size}; + auto tile_state = kernel_source.TileState(); + // Bytes needed for tile status descriptors - error = CubDebug(ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0])); + error = CubDebug(tile_state.AllocationSize(num_tiles, allocation_sizes[0])); if (cudaSuccess != error) { break; @@ -259,8 +326,6 @@ struct DispatchUniqueByKey break; } - // Construct the tile status interface - ScanTileStateT tile_state; error = CubDebug(tile_state.Init(num_tiles, allocations[0], allocation_sizes[0])); if (cudaSuccess != error) { @@ -276,7 +341,7 @@ struct DispatchUniqueByKey #endif // CUB_DEBUG_LOG // Invoke init_kernel to initialize tile descriptors - THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(init_grid_size, INIT_KERNEL_THREADS, 0, stream) + launcher_factory(init_grid_size, INIT_KERNEL_THREADS, 0, stream) .doit(init_kernel, tile_state, num_tiles, d_num_selected_out); // Check for failure to launch @@ -313,13 +378,13 @@ struct DispatchUniqueByKey scan_grid_size.y = ::cuda::ceil_div(num_tiles, max_dim_x); scan_grid_size.x = CUB_MIN(num_tiles, max_dim_x); -// Log select_if_kernel configuration + // Log select_if_kernel configuration #ifdef CUB_DEBUG_LOG { // Get 
SM occupancy for unique_by_key_kernel - int scan_sm_occupancy; - error = CubDebug(MaxSmOccupancy(scan_sm_occupancy, // out - scan_kernel, + int sweep_sm_occupancy; + error = CubDebug(MaxSmOccupancy(sweep_sm_occupancy, // out + sweep_kernel, block_threads)); if (cudaSuccess != error) { @@ -334,14 +399,14 @@ struct DispatchUniqueByKey block_threads, (long long) stream, items_per_thread, - scan_sm_occupancy); + sweep_sm_occupancy); } #endif // CUB_DEBUG_LOG // Invoke select_if_kernel error = - THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(scan_grid_size, block_threads, 0, stream) - .doit(scan_kernel, + launcher_factory(scan_grid_size, block_threads, 0, stream) + .doit(sweep_kernel, d_keys_in, d_values_in, d_keys_out, @@ -372,21 +437,11 @@ struct DispatchUniqueByKey } template - CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke() + CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT active_policy = {}) { - // Ensure kernels are instantiated. - return Invoke( - detail::scan::DeviceCompactInitKernel, - detail::unique_by_key::DeviceUniqueByKeySweepKernel< - typename PolicyHub::MaxPolicy, - KeyInputIteratorT, - ValueInputIteratorT, - KeyOutputIteratorT, - ValueOutputIteratorT, - NumSelectedIteratorT, - ScanTileStateT, - EqualityOpT, - OffsetT>); + auto wrapped_policy = detail::unique_by_key::MakeUniqueByKeyPolicyWrapper(active_policy); + + return Invoke(kernel_source.CompactInitKernel(), kernel_source.UniqueByKeySweepKernel(), wrapped_policy); } /** @@ -426,6 +481,7 @@ struct DispatchUniqueByKey * **[optional]** CUDA stream to launch kernels within. * Default is stream0. 
*/ + template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch( void* d_temp_storage, size_t& temp_storage_bytes, @@ -436,7 +492,10 @@ struct DispatchUniqueByKey NumSelectedIteratorT d_num_selected_out, EqualityOpT equality_op, OffsetT num_items, - cudaStream_t stream) + cudaStream_t stream, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {}) { cudaError_t error; do @@ -460,10 +519,12 @@ struct DispatchUniqueByKey d_num_selected_out, equality_op, num_items, - stream); + stream, + kernel_source, + launcher_factory); // Dispatch to chained policy - error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch)); + error = CubDebug(max_policy.Invoke(ptx_version, dispatch)); if (cudaSuccess != error) { break; diff --git a/cub/cub/device/dispatch/kernels/unique_by_key.cuh b/cub/cub/device/dispatch/kernels/unique_by_key.cuh index 2a8cd5858a4..3a9b31e83d7 100644 --- a/cub/cub/device/dispatch/kernels/unique_by_key.cuh +++ b/cub/cub/device/dispatch/kernels/unique_by_key.cuh @@ -26,6 +26,35 @@ CUB_NAMESPACE_BEGIN namespace detail::unique_by_key { + +// TODO: this class should be templated on `typename... Ts` to avoid repetition, +// but due to an issue with NVCC 12.0 we currently template each member function +// individually instead. 
+struct VSMemHelper +{ + template + using VSMemHelperDefaultFallbackPolicyT = + vsmem_helper_default_fallback_policy_t; + + template + _CCCL_HOST_DEVICE static constexpr int BlockThreads(ActivePolicyT /*policy*/) + { + return VSMemHelperDefaultFallbackPolicyT::agent_policy_t::BLOCK_THREADS; + } + + template + _CCCL_HOST_DEVICE static constexpr int ItemsPerThread(ActivePolicyT /*policy*/) + { + return VSMemHelperDefaultFallbackPolicyT::agent_policy_t::ITEMS_PER_THREAD; + } + + template + _CCCL_HOST_DEVICE static constexpr ::cuda::std::size_t VSMemPerBlock(ActivePolicyT /*policy*/) + { + return VSMemHelperDefaultFallbackPolicyT::vsmem_per_block; + } +}; + /** * @brief Unique by key kernel entry point (multi-block) * @@ -93,11 +122,11 @@ template + typename OffsetT, + typename VSMemHelperT = VSMemHelper> __launch_bounds__(int( - vsmem_helper_default_fallback_policy_t< + VSMemHelperT::template VSMemHelperDefaultFallbackPolicyT< typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT, - AgentUniqueByKey, KeyInputIteratorT, ValueInputIteratorT, KeyOutputIteratorT, @@ -116,9 +145,8 @@ __launch_bounds__(int( int num_tiles, vsmem_t vsmem) { - using VsmemHelperT = vsmem_helper_default_fallback_policy_t< + using VsmemHelperT = typename VSMemHelperT::template VSMemHelperDefaultFallbackPolicyT< typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT, - AgentUniqueByKey, KeyInputIteratorT, ValueInputIteratorT, KeyOutputIteratorT, diff --git a/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh b/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh index 093a17207e2..c3e146f1136 100644 --- a/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh @@ -47,10 +47,7 @@ CUB_NAMESPACE_BEGIN -namespace detail -{ - -namespace unique_by_key +namespace detail::unique_by_key { enum class primitive_key @@ -770,6 +767,35 @@ struct sm100_tuning +struct UniqueByKeyPolicyWrapper : PolicyT +{ + CUB_RUNTIME_FUNCTION 
UniqueByKeyPolicyWrapper(PolicyT base) + : PolicyT(base) + {} +}; + +template +struct UniqueByKeyPolicyWrapper> + : StaticPolicyT +{ + CUB_RUNTIME_FUNCTION UniqueByKeyPolicyWrapper(StaticPolicyT base) + : StaticPolicyT(base) + {} + + CUB_RUNTIME_FUNCTION static constexpr PolicyWrapper UniqueByKey() + { + return cub::detail::MakePolicyWrapper(typename StaticPolicyT::UniqueByKeyPolicyT()); + } +}; + +template +CUB_RUNTIME_FUNCTION UniqueByKeyPolicyWrapper MakeUniqueByKeyPolicyWrapper(PolicyT policy) +{ + return UniqueByKeyPolicyWrapper{policy}; +} + template struct policy_hub { @@ -843,8 +869,7 @@ struct policy_hub using MaxPolicy = Policy1000; }; -} // namespace unique_by_key -} // namespace detail +} // namespace detail::unique_by_key template using DeviceUniqueByKeyPolicy CCCL_DEPRECATED_BECAUSE("This class is considered an implementation detail and it will "